Sfoglia il codice sorgente

feat: 更新jenkinsfile 完善点云算法

yangliu 3 mesi fa
parent
commit
1cdae60414

BIN
README.md


+ 1 - 1
device-service-application/src/main/java/com/hfln/device/application/event/EventHandlerImpl.java

@@ -39,7 +39,7 @@ public class EventHandlerImpl implements EventHandler {
         try {
             Device device = deviceService.getDeviceById(deviceId);
             if (device == null) {
-                log.warn("设备不存在: {}", deviceId);
+                log.debug("设备不存在: {}", deviceId);
                 return;
             }
             device.setLastReportFallTime(event.getTimestamp());

+ 6 - 6
device-service-application/src/main/java/com/hfln/device/application/service/impl/DeviceEventServiceImpl.java

@@ -1472,12 +1472,12 @@ public class DeviceEventServiceImpl implements DeviceEventService {
             Integer y2 = ((Number) base.get("y_cm_stop")).intValue();     // y2 = base["y_cm_stop"]
             Integer z2 = ((Number) base.get("z_cm_stop")).intValue();     // z2 = base["z_cm_stop"]
             
-            Number presenceEnterDurationSec = (Number) extRegion.get("presenceEnterDuration_sec");  // presenceEnterDuration_sec = ext_region["presenceEnterDuration_sec"]
-            Number presenceExitDurationSec = (Number) extRegion.get("presenceExitDuration_sec");    // presenceExitDuration_sec = ext_region["presenceExitDuration_sec"]
-            
-            // 解析其他字段
-            Map<String, Object> fallingStateMachineDurations = (Map<String, Object>) messageData.get("fallingStateMachineDurations");  // fallingStateMachineDurations = payload["fallingStateMachineDurations"]
-            int is45Degree = (int) messageData.get("is45Degree");                 // is45Degree = payload["is45Degree"]
+//            Number presenceEnterDurationSec = (Number) extRegion.get("presenceEnterDuration_sec");  // presenceEnterDuration_sec = ext_region["presenceEnterDuration_sec"]
+//            Number presenceExitDurationSec = (Number) extRegion.get("presenceExitDuration_sec");    // presenceExitDuration_sec = ext_region["presenceExitDuration_sec"]
+//
+//            // 解析其他字段
+//            Map<String, Object> fallingStateMachineDurations = (Map<String, Object>) messageData.get("fallingStateMachineDurations");  // fallingStateMachineDurations = payload["fallingStateMachineDurations"]
+//            int is45Degree = (int) messageData.get("is45Degree");                 // is45Degree = payload["is45Degree"]
             int isCeiling = (int) messageData.get("isCeiling");                   // isCeiling = payload["isCeiling"]
             
             // === 构造安装参数对象 (对应Python: install_param = InstallParam(mount_plain="", height=sensor_height, tracking_region=TrackingRegion(x1,y1,z1,x2,y2,z2))) ===

+ 36 - 15
device-service-common/src/main/java/com/hfln/device/common/util/PointCloudUtil.java

@@ -14,46 +14,67 @@ public class PointCloudUtil {
     }
     
     /**
-     * 获取目标点
-     * 对应Python版本中的get_tracker_targets函数
+     * 获取目标点 (对应Python版本的get_tracker_targets函数)
      * 
-     * @param pointCloud 点云数据列表
-     * @return 目标点列表
+     * Python版本实现:
+     * def get_tracker_targets(point_cloud:list):
+     *     target_point = numpy.mean(point_cloud, axis=0).tolist()
+     *     tracker_targets = []
+     *     tracker_targets.append(target_point)
+     *     return tracker_targets
+     * 
+     * @param pointCloud 点云数据列表,格式:List<List<Float>>,每个点包含[x, y, z]坐标
+     * @return 跟踪目标列表,格式:List<List<Float>>,每个目标包含[x, y, z]
      */
     public static List<List<Float>> getTrackerTargets(List<List<Float>> pointCloud) {
         if (pointCloud == null || pointCloud.isEmpty()) {
             return new ArrayList<>();
         }
         
-        // 计算点云的平均值
+        // 对应Python: target_point = numpy.mean(point_cloud, axis=0).tolist()
         List<Float> targetPoint = calculateMeanPoint(pointCloud);
         
-        // 创建目标点列表
+        // 对应Python: tracker_targets = []
         List<List<Float>> trackerTargets = new ArrayList<>();
+        
+        // 对应Python: tracker_targets.append(target_point)
         trackerTargets.add(targetPoint);
         
+        // 对应Python: return tracker_targets
         return trackerTargets;
     }
     
     /**
-     * 获取多个目标点
-     * 对应Python版本中的get_tracker_targets_mult函数
+     * 获取多个目标点 (对应Python版本的get_tracker_targets_mult函数)
      * 
-     * @param pointCloud 点云数据列表
-     * @return 多个目标点列表
+     * Python版本实现:
+     * def get_tracker_targets_mult(point_cloud:list):
+     *     target_point = numpy.mean(point_cloud, axis=0).tolist()
+     *     tracker_targets = []
+     *     tracker_targets.append(target_point)
+     *     return tracker_targets
+     * 
+     * 注意:在Python版本中,get_tracker_targets_mult与get_tracker_targets实现完全相同
+     * 
+     * @param pointCloud 点云数据列表,格式:List<List<Float>>,每个点包含[x, y, z]坐标
+     * @return 跟踪目标列表,格式:List<List<Float>>,每个目标包含[x, y, z]
      */
     public static List<List<Float>> getTrackerTargetsMult(List<List<Float>> pointCloud) {
-        // 在当前版本中,和getTrackerTargets实现相同
+        // 在Python版本中,get_tracker_targets_mult和get_tracker_targets实现相同
         return getTrackerTargets(pointCloud);
     }
     
     /**
-     * 计算点云的平均点
+     * 计算点云的平均点 (对应Python的numpy.mean(point_cloud, axis=0))
      * 
      * @param pointCloud 点云数据列表
-     * @return 平均点
+     * @return 平均点坐标 [x, y, z]
      */
     private static List<Float> calculateMeanPoint(List<List<Float>> pointCloud) {
+        if (pointCloud.isEmpty()) {
+            return new ArrayList<>();
+        }
+        
         int dimensions = pointCloud.get(0).size();
         List<Float> mean = new ArrayList<>(dimensions);
         
@@ -62,14 +83,14 @@ public class PointCloudUtil {
             mean.add(0.0f);
         }
         
-        // 累加所有点的坐标
+        // 累加所有点的坐标 (对应numpy.mean的求和步骤)
         for (List<Float> point : pointCloud) {
             for (int i = 0; i < dimensions; i++) {
                 mean.set(i, mean.get(i) + point.get(i));
             }
         }
         
-        // 计算平均值
+        // 计算平均值 (对应numpy.mean的除法步骤)
         for (int i = 0; i < dimensions; i++) {
             mean.set(i, mean.get(i) / pointCloud.size());
         }

+ 145 - 11
device-service-domain/src/main/java/com/hfln/device/domain/entity/Device.java

@@ -123,6 +123,88 @@ public class Device {
     private Map<String, Object> alarmSchedule;
     
     /**
+     * 目标稳定器 - 对应Python版本的TargetStabilizer类
+     * 动态阈值调整的目标稳定算法
+     */
+    private static class TargetStabilizer {
+        private final java.util.Deque<List<Float>> targetList;
+        private final float baseThreshold;
+        
+        public TargetStabilizer(int queueLength, float baseThreshold) {
+            this.targetList = new java.util.ArrayDeque<>(queueLength);
+            this.baseThreshold = baseThreshold;
+        }
+        
+        /**
+         * 优化目标点 (对应Python版本的optimize_target方法)
+         * @param newTarget 新接收到的目标点 [x, y, z, snr]
+         * @return 优化后的目标点
+         */
+        public List<Float> optimizeTarget(List<Float> newTarget) {
+            if (targetList.isEmpty()) {
+                return new ArrayList<>(newTarget);
+            }
+            
+            // 计算队列中点的平均值 (对应Python版本的avg_x, avg_y, avg_z计算)
+            float avgX = 0.0f, avgY = 0.0f, avgZ = 0.0f;
+            for (List<Float> target : targetList) {
+                avgX += target.get(0);
+                avgY += target.get(1);
+                avgZ += target.get(2);
+            }
+            avgX /= targetList.size();
+            avgY /= targetList.size();
+            avgZ /= targetList.size();
+            
+            // 动态阈值调整 (对应Python版本的dynamic_threshold计算)
+            float dx = newTarget.get(0) - avgX;
+            float dy = newTarget.get(1) - avgY;
+            float distanceToAvg = (float) Math.sqrt(dx * dx + dy * dy);
+            float dynamicThreshold = baseThreshold + (distanceToAvg * 0.3f);
+            
+            // 如果与均值的偏差超过动态阈值,则平滑处理 (对应Python版本的平滑逻辑)
+            if (distanceToAvg > dynamicThreshold) {
+                float smoothedX = (newTarget.get(0) + avgX) / 2.0f;
+                float smoothedY = (newTarget.get(1) + avgY) / 2.0f;
+                float smoothedZ = (newTarget.get(1) + avgZ) / 2.0f;  // 保持与Python版本完全一致,包括这个bug:使用newTarget.get(1)而不是get(2)
+                
+                List<Float> optimizedTarget = new ArrayList<>();
+                optimizedTarget.add(smoothedX);
+                optimizedTarget.add(smoothedY);
+                optimizedTarget.add(smoothedZ);
+                return optimizedTarget;
+            } else {
+                return new ArrayList<>(newTarget);
+            }
+        }
+        
+        /**
+         * 更新目标列表 (对应Python版本的update_target_list方法)
+         * @param newTarget 原始目标点 [x, y, z, snr]
+         * @return 优化后的目标点
+         */
+        public List<Float> updateTargetList(List<Float> newTarget) {
+            if (newTarget.size() < 3) {
+                // 对应Python: LOGDBG(f"update_target_list error: invlid target:{new_target}")
+                return new ArrayList<>(newTarget);
+            }
+            
+            List<Float> optimizedTarget = optimizeTarget(newTarget);
+            
+            // 如果队列已满,移除最老的目标 (对应Python的deque maxlen行为)
+            if (targetList.size() >= 10) {  // queueLength
+                targetList.removeFirst();
+            }
+            targetList.addLast(optimizedTarget);
+            
+            return optimizedTarget;
+        }
+    }
+    
+    // 目标稳定器实例 (对应Python版本device中的stabilizer)
+    private TargetStabilizer stabilizer = new TargetStabilizer(10, 1.0f);
+    
+    /**
      * 构造函数,设置默认值
      * @param devId 设备ID
      */
@@ -1080,20 +1162,38 @@ public class Device {
         }
     }
     
-    /**
-     * 更新设备目标位置 (对应Python版本的update_targets方法)
-     * @param newTargets 新的目标位置列表
-     * @return 稳定的目标位置列表
-     */
-    public List<List<Float>> updateTargets(List<List<Float>> newTargets) {
+         /**
+      * 更新实时位置(平滑算法) - 重写以对应Python版本的update_targets方法
+      * @param newTargets 新的目标位置列表
+      * @return 稳定的目标位置列表
+      */
+     public List<List<Float>> updateTargets(List<List<Float>> newTargets) {
         try {
             lock.lock();
-            if (newTargets != null && !newTargets.isEmpty()) {
-                // 简单实现:直接更新targets字段
-                // 在实际应用中,这里可能需要更复杂的目标稳定算法
-                this.targets = newTargets.get(0); // 取第一个目标
-                return newTargets; // 返回稳定的目标
+            
+            // 对应Python版本的输入验证
+            if (newTargets == null || newTargets.isEmpty() || 
+                newTargets.get(0).size() < 3) {
+                // 对应Python: LOGERR(f"update_targets error: invalid new_targets")
+                return new ArrayList<>();
             }
+            
+            // 对应Python: new_target = new_targets[0]
+            List<Float> newTarget = newTargets.get(0);
+            
+            // 对应Python: stable_target = self.stabilizer.update_target_list(new_target)
+            List<Float> stableTarget = stabilizer.updateTargetList(newTarget);
+            
+            // 更新targets字段
+            this.targets = stableTarget;
+            
+            // 对应Python: return [stable_target]
+            List<List<Float>> result = new ArrayList<>();
+            result.add(stableTarget);
+            return result;
+            
+        } catch (Exception e) {
+            // 对应Python: except Exception as e: LOGERR(f"update_targets error: {e}")
             return new ArrayList<>();
         } finally {
             lock.unlock();
@@ -1475,4 +1575,38 @@ public class Device {
         private Object location;  // 修改为Object类型,匹配MqttGateway接口
         private boolean shouldProcess;
     }
+
+    /**
+     * 获取数据量最多的一组点云数据,然后清空cloud_points_que_ 
+     * (对应Python版本的get_max_len_cloud_points方法)
+     * @return 数据量最多的点云list,失败返回null
+     */
+    public List<List<Float>> getMaxLenCloudPoints() {
+        try {
+            lock.lock();
+            
+            // 对应Python: if self.dev_type_ == "LNB": return None
+            if ("LNB".equals(devType)) {
+                return null;
+            }
+            
+            // 对应Python: 取出数据量最多的的点云,成功返回点云list,失败返回None
+            int maxLen = 0;
+            List<List<Float>> maxLenList = null;
+            
+            // 对应Python: while not self.cloud_points_que_.empty():
+            while (!cloudPointsQueue.isEmpty()) {
+                List<List<Float>> currentList = cloudPointsQueue.poll();
+                // 对应Python: if len(current_list) >= max_len:
+                if (currentList != null && currentList.size() >= maxLen) {
+                    maxLen = currentList.size();
+                    maxLenList = currentList;
+                }
+            }
+            
+            return maxLenList;
+        } finally {
+            lock.unlock();
+        }
+    }
 } 

+ 193 - 34
device-service-domain/src/main/java/com/hfln/device/domain/service/PointCloudProcessService.java

@@ -8,6 +8,8 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
 import java.util.stream.Collectors;
+import java.util.HashMap;
+import java.util.Map;
 
 /**
  * 点云数据处理服务
@@ -45,6 +47,32 @@ public class PointCloudProcessService {
     }
 
     /**
+     * 后处理数据结果 (对应Python版本的post_process.py功能)
+     */
+    public static class PostProcessResult {
+        private int pose;
+        private float confidence;
+        private long responseTime;
+        private String deviceId;
+        private int pointCount;
+        
+        public PostProcessResult(int pose, float confidence, long responseTime, String deviceId, int pointCount) {
+            this.pose = pose;
+            this.confidence = confidence;
+            this.responseTime = responseTime;
+            this.deviceId = deviceId;
+            this.pointCount = pointCount;
+        }
+        
+        // Getters
+        public int getPose() { return pose; }
+        public float getConfidence() { return confidence; }
+        public long getResponseTime() { return responseTime; }
+        public String getDeviceId() { return deviceId; }
+        public int getPointCount() { return pointCount; }
+    }
+
+    /**
      * 分析点云数据,判断姿态
      * 适配DeviceEventServiceImpl的调用
      * 
@@ -193,14 +221,15 @@ public class PointCloudProcessService {
     /**
      * 从点云数据计算跟踪目标 (对应Python版本的get_tracker_targets方法)
      * 
-     * Python版本功能:
-     * 1. 对点云数据进行聚类分析
-     * 2. 识别人员目标
-     * 3. 计算目标的边界框和中心位置
-     * 4. 返回跟踪目标列表,格式:[[x, y, z, id], ...]
+     * Python版本实现:
+     * def get_tracker_targets(point_cloud:list):
+     *     target_point = numpy.mean(point_cloud, axis=0).tolist()
+     *     tracker_targets = []
+     *     tracker_targets.append(target_point)
+     *     return tracker_targets
      * 
      * @param cloudPoints 点云数据,格式:List<List<Float>>,每个点包含[x, y, z]坐标
-     * @return 跟踪目标列表,格式:List<List<Float>>,每个目标包含[x, y, z, id]
+     * @return 跟踪目标列表,格式:List<List<Float>>,每个目标包含[x, y, z]
      */
     public List<List<Float>> getTrackerTargets(List<List<Float>> cloudPoints) {
         if (cloudPoints == null || cloudPoints.isEmpty()) {
@@ -211,30 +240,23 @@ public class PointCloudProcessService {
         try {
             log.trace("Processing {} cloud points for target tracking", cloudPoints.size());
             
-            // 简化实现:使用聚类算法识别目标
-            // 实际应用中应使用更复杂的3D目标检测算法
-            List<List<Float>> targets = new ArrayList<>();
+            // 对应Python: target_point = numpy.mean(point_cloud, axis=0).tolist()
+            List<Float> targetPoint = calculateMeanPoint(cloudPoints);
             
-            // 基本的聚类方法:计算点云的质心作为主要目标
-            List<Float> centroid = calculateCentroidFromFloat(cloudPoints);
-            if (!centroid.isEmpty()) {
-                // 添加目标ID (第4个元素)
-                List<Float> target = new ArrayList<>(centroid);
-                target.add(1.0f); // 目标ID = 1
-                targets.add(target);
+            // 对应Python: tracker_targets = []
+            List<List<Float>> trackerTargets = new ArrayList<>();
+            
+            if (!targetPoint.isEmpty()) {
+                // 对应Python: tracker_targets.append(target_point)
+                trackerTargets.add(targetPoint);
                 
                 log.trace("Detected target at position: [{}, {}, {}]", 
-                         centroid.get(0), centroid.get(1), centroid.get(2));
+                         targetPoint.get(0), targetPoint.get(1), targetPoint.get(2));
             }
             
-            // TODO: 实现更复杂的目标检测算法
-            // 1. DBSCAN聚类算法识别多个目标
-            // 2. 卡尔曼滤波跟踪目标轨迹
-            // 3. 目标关联和ID管理
-            // 4. 噪声过滤和异常点检测
-            
-            log.debug("Extracted {} targets from {} cloud points", targets.size(), cloudPoints.size());
-            return targets;
+            // 对应Python: return tracker_targets
+            log.debug("Extracted {} targets from {} cloud points", trackerTargets.size(), cloudPoints.size());
+            return trackerTargets;
             
         } catch (Exception e) {
             log.error("Error extracting tracker targets from cloud points: {}", e.getMessage(), e);
@@ -243,17 +265,38 @@ public class PointCloudProcessService {
     }
     
     /**
-     * 计算Float类型点云的中心点
+     * 获取多个目标点 (对应Python版本的get_tracker_targets_mult方法)
      * 
-     * @param pointCloud Float类型点云数据
-     * @return 中心点坐标 [x, y, z]
+     * Python版本实现:
+     * def get_tracker_targets_mult(point_cloud:list):
+     *     target_point = numpy.mean(point_cloud, axis=0).tolist()
+     *     tracker_targets = []
+     *     tracker_targets.append(target_point)
+     *     return tracker_targets
+     * 
+     * 注意:在Python版本中,get_tracker_targets_mult与get_tracker_targets实现完全相同
+     * 
+     * @param cloudPoints 点云数据,格式:List<List<Float>>,每个点包含[x, y, z]坐标
+     * @return 跟踪目标列表,格式:List<List<Float>>,每个目标包含[x, y, z]
      */
-    private List<Float> calculateCentroidFromFloat(List<List<Float>> pointCloud) {
+    public List<List<Float>> getTrackerTargetsMult(List<List<Float>> cloudPoints) {
+        // 在Python版本中,get_tracker_targets_mult和get_tracker_targets实现相同
+        return getTrackerTargets(cloudPoints);
+    }
+    
+    /**
+     * 计算点云的平均点 (对应Python的numpy.mean(point_cloud, axis=0))
+     * 
+     * @param pointCloud 点云数据
+     * @return 平均点坐标 [x, y, z]
+     */
+    private List<Float> calculateMeanPoint(List<List<Float>> pointCloud) {
         int numPoints = pointCloud.size();
         if (numPoints == 0) {
             return new ArrayList<>();
         }
 
+        // 计算各维度的总和
         float sumX = 0.0f;
         float sumY = 0.0f;
         float sumZ = 0.0f;
@@ -266,11 +309,127 @@ public class PointCloudProcessService {
             }
         }
 
-        List<Float> centroid = new ArrayList<>();
-        centroid.add(sumX / numPoints);
-        centroid.add(sumY / numPoints);
-        centroid.add(sumZ / numPoints);
+        // 计算平均值 (对应numpy.mean的功能)
+        List<Float> meanPoint = new ArrayList<>();
+        meanPoint.add(sumX / numPoints);
+        meanPoint.add(sumY / numPoints);
+        meanPoint.add(sumZ / numPoints);
 
-        return centroid;
+        return meanPoint;
+    }
+
+    /**
+     * 处理点云数据进行姿态识别 (对应Python版本的deal_post_data方法)
+     * @param rawPoints 原始点云数据
+     * @return 处理后的数据,准备发送给AI算法服务
+     */
+    public Map<String, Object> preparePostData(List<List<Float>> rawPoints) {
+        if (rawPoints == null || rawPoints.isEmpty()) {
+            return new HashMap<>();
+        }
+        
+        // 对应Python: RawPoints = [sublist[0:3] for sublist in RawPoints]
+        List<List<Float>> processedPoints = new ArrayList<>();
+        for (List<Float> point : rawPoints) {
+            if (point.size() >= 3) {
+                List<Float> processedPoint = new ArrayList<>();
+                processedPoint.add(point.get(0)); // x
+                processedPoint.add(point.get(1)); // y
+                processedPoint.add(point.get(2)); // z
+                processedPoints.add(processedPoint);
+            }
+        }
+        
+        // 对应Python的数据格式准备
+        Map<String, Object> pointCloudData = new HashMap<>();
+        
+        // 默认使用李博模型格式 (对应Python: if e_model == MODEL_E.MODEL_LIBO)
+        pointCloudData.put("point_cloud", processedPoints);
+        
+        // 也可以支持安大模型格式 (对应Python: elif e_model == MODEL_E.MODEL_ANDA)
+        // pointCloudData.put("ID", "JSON_DATA");
+        // Map<String, Object> payload = new HashMap<>();
+        // payload.put("raw_points", processedPoints);
+        // pointCloudData.put("Payload", payload);
+        // pointCloudData.put("Type", "POINT");
+        
+        return pointCloudData;
+    }
+    
+    /**
+     * 检查姿态类型 (对应Python版本的check_pose方法)
+     * @param predictedClass AI算法返回的姿态分类
+     * @return 标准化的姿态枚举值
+     */
+    public int checkPose(int predictedClass) {
+        // 对应Python版本的姿态映射逻辑
+        // 这里简化处理,实际应根据具体的模型类型进行映射
+        
+        // 对应Python: if e_model == MODEL_E.MODEL_LIBO:
+        //   if e_pose_class == POSE_CLASS_E.POSE_CLASS_3:
+        //     if predicted_class == 2: pose = POSE_E.POSE_4.value
+        //     else: pose = predicted_class
+        
+        switch (predictedClass) {
+            case 0: return DeviceConstants.PoseEnum.POSE_FALLING.getCode();      // 摔倒
+            case 1: return DeviceConstants.PoseEnum.POSE_SITTING_ON_CHAIR.getCode(); // 坐在椅子上
+            case 2: return DeviceConstants.PoseEnum.POSE_SITTING_ON_FLOOR.getCode(); // 坐在地上  
+            case 3: return DeviceConstants.PoseEnum.POSE_SQUATTING.getCode();    // 蹲
+            case 4: return DeviceConstants.PoseEnum.POSE_STANDING.getCode();     // 站
+            case 5: return DeviceConstants.PoseEnum.POSE_SITTING.getCode();      // 坐
+            case 6: return DeviceConstants.PoseEnum.POSE_LYING.getCode();        // 躺
+            default: return DeviceConstants.PoseEnum.POSE_INVALID.getCode();     // 无效
+        }
+    }
+    
+    /**
+     * 获取最大长度的点云数据 (对应Python版本的get_max_len_raw_points方法)
+     * 这个方法应该由调用方(如设备管理服务)提供设备列表
+     * @param devices 设备列表
+     * @return 数据量最多的点云对象,包含设备ID和点云数据
+     */
+    public Map<String, Object> getMaxLenRawPoints(java.util.Collection<com.hfln.device.domain.entity.Device> devices) {
+        int maxLen = 0;
+        Map<String, Object> maxLenObj = null;
+        
+        // 对应Python: for dev_id, device in g_dev_map.items():
+        for (com.hfln.device.domain.entity.Device device : devices) {
+            // 对应Python: cloud_points = device.get_max_len_cloud_points()
+            List<List<Float>> cloudPoints = device.getMaxLenCloudPoints();
+            
+            // 对应Python: if (cloud_points == None or len(cloud_points) <= 20): continue
+            if (cloudPoints == null || cloudPoints.size() <= 20) {
+                continue;
+            }
+            
+            // 对应Python: if len(current_list) >= max_len:
+            if (cloudPoints.size() >= maxLen) {
+                maxLen = cloudPoints.size();
+                maxLenObj = new HashMap<>();
+                maxLenObj.put("dev_id", device.getDevId());
+                maxLenObj.put("raw_points", cloudPoints);
+                
+                // 找到符合条件的设备后立即返回 (对应Python的break)
+                break;
+            }
+        }
+        
+        return maxLenObj;
+    }
+    
+    /**
+     * 处理AI算法响应 (对应Python版本的check_pose_ex方法)
+     * @param responseJson AI算法服务返回的JSON响应
+     * @return 姿态分类结果
+     */
+    public int processPoseResponse(Map<String, Object> responseJson) {
+        // 对应Python: predicted_class = resp_json["predicted_class"]
+        Object predictedClassObj = responseJson.get("predicted_class");
+        if (predictedClassObj instanceof Number) {
+            int predictedClass = ((Number) predictedClassObj).intValue();
+            return checkPose(predictedClass);
+        }
+        
+        return DeviceConstants.PoseEnum.POSE_INVALID.getCode();
     }
 }

+ 1 - 4
device-service-domain/src/main/java/com/hfln/device/domain/service/impl/PoseAnalysisServiceImpl.java

@@ -220,10 +220,7 @@ public class PoseAnalysisServiceImpl implements PoseAnalysisService {
             Long lastActivityTime = getLastActivityTime(device);
             
             // 如果没有记录最后活动时间,返回false
-            if (lastActivityTime == null) {
-                return false;
-            }
-            
+
             // 计算不活动时间(毫秒)
             long inactivityTime = currentTime - lastActivityTime;
             

+ 2 - 2
device-service-infrastructure/src/main/java/com/hfln/device/infrastructure/config/MqttConfig.java

@@ -104,7 +104,7 @@ public class MqttConfig {
         
         return factory;
     }
-
+    
     // ===========================================
     // 设备消息通道和适配器
     // ===========================================
@@ -322,5 +322,5 @@ public class MqttConfig {
     /**
      * 由于MqttGatewayImpl已设置@Primary注解,
      * 无需额外配置,Spring会自动选择MqttGatewayImpl作为主要实现
-     */
+    */
 } 

+ 109 - 109
device-service-infrastructure/src/main/java/com/hfln/device/infrastructure/gateway/MqttGatewayImpl.java

@@ -209,14 +209,14 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendRealtimePoseMessage(String deviceId, int pose, Object targetPoint) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("message", "notify");
             payload.put("message_type", DeviceConstants.MessageType.MSG_REALTIME_TARGET.getCode());
             payload.put("timestamp", System.currentTimeMillis());
             payload.put("dev_id", deviceId);
-            payload.put("pose", pose);
+        payload.put("pose", pose);
             payload.put("target_point", targetPoint);
-            
+        
             sendMessage(MqttTopics.DAS_REALTIME_POS, payload);
         } catch (Exception e) {
             log.error("Error sending realtime pose message: {}", deviceId, e);
@@ -226,13 +226,13 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendAlarmMessage(String deviceId, String alarmType, Map<String, Object> data) {
         try {
-            Map<String, Object> payload = new HashMap<>(data);
+        Map<String, Object> payload = new HashMap<>(data);
             payload.put("message", "notify");
             payload.put("message_type", DeviceConstants.MessageType.MSG_ALARM_EVENT.getCode());
             payload.put("dev_id", deviceId);
             payload.put("timestamp", System.currentTimeMillis());
-            payload.put("alarmType", alarmType);
-            
+        payload.put("alarmType", alarmType);
+        
             sendMessage(MqttTopics.DAS_ALARM_EVENT, payload);
             log.info("Alarm message sent: {}, type: {}", deviceId, alarmType);
         } catch (Exception e) {
@@ -243,12 +243,12 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendBehaviorAnalysisResult(String deviceId, Object behaviorPattern) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("message", "notify");
             payload.put("dev_id", deviceId);
-            payload.put("behaviorPattern", behaviorPattern);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("behaviorPattern", behaviorPattern);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             sendMessage(MqttTopics.DAS_BEHAVIOR_ANALYSIS, payload);
             log.debug("Behavior analysis result sent: {}", deviceId);
         } catch (Exception e) {
@@ -259,11 +259,11 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public boolean sendCommandToDevice(String deviceId, String command, Object payload) {
         try {
-            Map<String, Object> message = new HashMap<>();
-            message.put("command", command);
-            message.put("payload", payload);
-            message.put("timestamp", System.currentTimeMillis());
-            
+        Map<String, Object> message = new HashMap<>();
+        message.put("command", command);
+        message.put("payload", payload);
+        message.put("timestamp", System.currentTimeMillis());
+        
             String topic = String.format(MqttTopics.DEV_COMMAND, deviceId);
             sendMessage(topic, message);
             log.info("Command sent to device: {}, command: {}", deviceId, command);
@@ -277,15 +277,15 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendFallAlarmMessage(String deviceId, int pose, List<Float> targetPoint) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("message", "notify");
             payload.put("message_type", DeviceConstants.MessageType.MSG_EVENT_FALL.getCode());
             payload.put("dev_id", deviceId);
-            payload.put("pose", pose);
+        payload.put("pose", pose);
             payload.put("target_point", targetPoint);
             payload.put("alarmType", "fall");
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("timestamp", System.currentTimeMillis());
+        
             // 跌倒告警使用QoS 2确保可靠传输
             sendMessage(MqttTopics.DAS_ALARM_EVENT, payload, 2);
             log.info("Fall alarm message sent: {}", deviceId);
@@ -297,9 +297,9 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceRebootCommand(String deviceId) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/reboot";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Device reboot command sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending device reboot command: {}", deviceId, e);
@@ -309,9 +309,9 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceResetCommand(String deviceId) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/reset";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Device reset command sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending device reset command: {}", deviceId, e);
@@ -321,13 +321,13 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceCommand(String deviceId, String command, Map<String, Object> params) {
         try {
-            Map<String, Object> payload = new HashMap<>();
-            payload.put("command", command);
-            payload.put("params", params);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        Map<String, Object> payload = new HashMap<>();
+        payload.put("command", command);
+        payload.put("params", params);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/" + command;
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Device command sent: {}, command: {}", deviceId, command);
         } catch (Exception e) {
             log.error("Error sending device command: {}, command: {}", deviceId, command, e);
@@ -337,11 +337,11 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceKeepAliveResponse(String deviceId, int status) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("code", status);
-            
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DAS_PREFIX + deviceId + "/keepalive";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.debug("Device keepalive response sent: {}, code: {}", deviceId, status);
         } catch (Exception e) {
             log.error("Error sending device keepalive response: {}", deviceId, e);
@@ -351,12 +351,12 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceNotFoundResponse(String deviceId) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("code", 404);
-            payload.put("message", "Device not found");
-            
+        payload.put("message", "Device not found");
+        
             String topic = MqttTopics.APP_DEVICE_INFO_RESPONSE;
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.debug("Device not found response sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending device not found response: {}", deviceId, e);
@@ -366,11 +366,11 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendAlarmAckMessage(String deviceId, Long eventId) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("dev_id", deviceId);
             payload.put("event_id", eventId);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("timestamp", System.currentTimeMillis());
+        
             sendMessage(MqttTopics.APP_FALL_EVENT_ACK, payload);
             log.debug("Alarm acknowledgment sent: {}, eventId: {}", deviceId, eventId);
         } catch (Exception e) {
@@ -381,14 +381,14 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceParamSetCommand(String deviceId, String paramType, String paramName, float value) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("param_type", paramType);
             payload.put("param_name", paramName);
-            payload.put("value", value);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("value", value);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/set_param";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Device parameter set command sent: {}, {}={}", deviceId, paramName, value);
         } catch (Exception e) {
             log.error("Error sending device parameter set command: {}", deviceId, e);
@@ -403,13 +403,13 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendUpdateNetworkCommand(String deviceId, String ssid, String password) {
         try {
-            Map<String, Object> payload = new HashMap<>();
-            payload.put("ssid", ssid);
-            payload.put("password", password);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        Map<String, Object> payload = new HashMap<>();
+        payload.put("ssid", ssid);
+        payload.put("password", password);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/network";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Network update command sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending network update command: {}", deviceId, e);
@@ -419,12 +419,12 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceLoginResponse(String deviceId, int code) {
         try {
-            Map<String, Object> payload = new HashMap<>();
-            payload.put("code", code);
+        Map<String, Object> payload = new HashMap<>();
+        payload.put("code", code);
             payload.put("expires", 90); // 过期时间,单位秒
-            
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DAS_PREFIX + deviceId + "/login";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.debug("Device login response sent: {}, code: {}", deviceId, code);
         } catch (Exception e) {
             log.error("Error sending device login response: {}", deviceId, e);
@@ -434,12 +434,12 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendEventMessage(String deviceId, List<List<Float>> rawPoints, int pose, List<List<Float>> targets, String event) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("message", "notify");
             payload.put("message_type", DeviceConstants.MessageType.MSG_EVENT_FALL.getCode());
             payload.put("dev_id", deviceId);
-            payload.put("event", event);
-            payload.put("timestamp", System.currentTimeMillis());
+        payload.put("event", event);
+        payload.put("timestamp", System.currentTimeMillis());
             payload.put("pose", pose);
             payload.put("RawPoints", rawPoints != null ? rawPoints : new ArrayList<>());  // 对应Python版本的RawPoints参数
             payload.put("targets", targets != null ? targets : new ArrayList<>());        // 对应Python版本的targets参数
@@ -457,15 +457,15 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendAlarmEventMessage(String deviceId, String description, String table, int tableId) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("message", "notify");
             payload.put("message_type", DeviceConstants.MessageType.MSG_ALARM_EVENT.getCode());
             payload.put("dev_id", deviceId);
             payload.put("timestamp", System.currentTimeMillis());
             payload.put("desc", description);
-            payload.put("table", table);
+        payload.put("table", table);
             payload.put("table_id", tableId);
-            
+        
             sendMessage(MqttTopics.DAS_ALARM_EVENT, payload);
         } catch (Exception e) {
             log.error("Error sending alarm event message: {}, desc: {}", deviceId, description, e);
@@ -475,13 +475,13 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendExistenceMessage(String deviceId, String event) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("message", "notify");
             payload.put("message_type", DeviceConstants.MessageType.MSG_EVENT_EXIST.getCode());
             payload.put("dev_id", deviceId);
-            payload.put("event", event);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("event", event);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             sendMessage(MqttTopics.DAS_EXIST_EVENT, payload);
         } catch (Exception e) {
             log.error("Error sending existence message: {}, event: {}", deviceId, event, e);
@@ -491,17 +491,17 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendNetworkConfigUpdate(String deviceId, Device.NetworkInfo networkInfo) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("dev_id", deviceId);
             if (networkInfo != null) {
                 payload.put("ssid", networkInfo.getSsid());
                 payload.put("password", networkInfo.getPassword());
                 payload.put("ip", networkInfo.getIp());
             }
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("timestamp", System.currentTimeMillis());
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/network_config";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Network config update sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending network config update: {}", deviceId, e);
@@ -511,7 +511,7 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendInstallParamUpdate(String deviceId, Device.InstallParam installParam) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("dev_id", deviceId);
             if (installParam != null) {
                 payload.put("mount_plain", installParam.getMountPlain());
@@ -523,10 +523,10 @@ public class MqttGatewayImpl implements MqttGateway {
                     payload.put("tracking_region", trackingRegion);
                 }
             }
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("timestamp", System.currentTimeMillis());
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/install_param";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Install parameter update sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending install parameter update: {}", deviceId, e);
@@ -536,16 +536,16 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendTrackingRegionUpdate(String deviceId, Device.TrackingRegion trackingRegion) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("dev_id", deviceId);
             if (trackingRegion != null) {
                 Map<String, Object> regionMap = getStringObjectMap(trackingRegion);
                 payload.put("tracking_region", regionMap);
             }
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("timestamp", System.currentTimeMillis());
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/tracking_region";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Tracking region update sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending tracking region update: {}", deviceId, e);
@@ -555,13 +555,13 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendAlarmScheduleUpdate(String deviceId, Map<String, Object> alarmSchedule) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("dev_id", deviceId);
             payload.put("alarm_schedule", alarmSchedule);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("timestamp", System.currentTimeMillis());
+        
             String topic = DeviceConstants.MqttConstant.TOPIC_DEVICE_PREFIX + deviceId + "/alarm_schedule";
-            sendMessage(topic, payload);
+        sendMessage(topic, payload);
             log.info("Alarm schedule update sent: {}", deviceId);
         } catch (Exception e) {
             log.error("Error sending alarm schedule update: {}", deviceId, e);
@@ -571,11 +571,11 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendDeviceInfoResponse(String deviceId, Device device) {
         try {
-            Map<String, Object> payload = new HashMap<>();
+        Map<String, Object> payload = new HashMap<>();
             payload.put("dev_id", deviceId);
-            payload.put("device", device);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("device", device);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             sendMessage(MqttTopics.APP_DEVICE_INFO_RESPONSE, payload);
             log.debug("Device info response sent: {}", deviceId);
         } catch (Exception e) {
@@ -586,11 +586,11 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendStatusMessage(String deviceId, String status, Map<String, Object> data) {
         try {
-            Map<String, Object> payload = new HashMap<>(data);
+        Map<String, Object> payload = new HashMap<>(data);
             payload.put("dev_id", deviceId);
-            payload.put("status", status);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("status", status);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             sendMessage(MqttTopics.DAS_STATUS, payload);
             log.debug("Device status message sent: {}, status: {}", deviceId, status);
         } catch (Exception e) {
@@ -601,11 +601,11 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendBehaviorMessage(String deviceId, String behaviorType, Map<String, Object> data) {
         try {
-            Map<String, Object> payload = new HashMap<>(data);
+        Map<String, Object> payload = new HashMap<>(data);
             payload.put("dev_id", deviceId);
-            payload.put("behaviorType", behaviorType);
-            payload.put("timestamp", System.currentTimeMillis());
-            
+        payload.put("behaviorType", behaviorType);
+        payload.put("timestamp", System.currentTimeMillis());
+        
             sendMessage(MqttTopics.DAS_BEHAVIOR_ANALYSIS, payload);
             log.debug("Behavior message sent: {}, type: {}", deviceId, behaviorType);
         } catch (Exception e) {
@@ -646,12 +646,12 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendResponse(String topic, int code, Map<String, Object> data) {
         try {
-            Map<String, Object> payload = new HashMap<>(data);
-            payload.put("code", code);
-            payload.put("timestamp", System.currentTimeMillis());
-            
-            sendMessage(topic, payload);
-            log.debug("Response sent to topic: {}, code: {}", topic, code);
+        Map<String, Object> payload = new HashMap<>(data);
+        payload.put("code", code);
+        payload.put("timestamp", System.currentTimeMillis());
+        
+        sendMessage(topic, payload);
+        log.debug("Response sent to topic: {}, code: {}", topic, code);
         } catch (Exception e) {
             log.error("Error sending response to topic: {}", topic, e);
         }
@@ -660,13 +660,13 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendCommand(String topic, String command, Map<String, Object> params) {
         try {
-            Map<String, Object> payload = new HashMap<>();
-            payload.put("command", command);
+        Map<String, Object> payload = new HashMap<>();
+        payload.put("command", command);
             payload.put("params", params);
-            payload.put("timestamp", System.currentTimeMillis());
-            
-            sendMessage(topic, payload);
-            log.debug("Command sent to topic: {}, command: {}", topic, command);
+        payload.put("timestamp", System.currentTimeMillis());
+        
+        sendMessage(topic, payload);
+        log.debug("Command sent to topic: {}, command: {}", topic, command);
         } catch (Exception e) {
             log.error("Error sending command to topic: {}", topic, e);
         }
@@ -675,12 +675,12 @@ public class MqttGatewayImpl implements MqttGateway {
     @Override
     public void sendGenericMessage(String topic, String messageType, Map<String, Object> messageData) {
         try {
-            Map<String, Object> payload = new HashMap<>(messageData);
+        Map<String, Object> payload = new HashMap<>(messageData);
             payload.put("message_type", messageType);
-            payload.put("timestamp", System.currentTimeMillis());
-            
-            sendMessage(topic, payload);
-            log.debug("Generic message sent to topic: {}, type: {}", topic, messageType);
+        payload.put("timestamp", System.currentTimeMillis());
+        
+        sendMessage(topic, payload);
+        log.debug("Generic message sent to topic: {}, type: {}", topic, messageType);
         } catch (Exception e) {
             log.error("Error sending generic message to topic: {}", topic, e);
         }