|
@@ -0,0 +1,266 @@
|
|
|
+package com.hfln.device.application.service.impl;
|
|
|
+
|
|
|
+import com.hfln.device.domain.service.DeviceManagerService;
|
|
|
+import com.hfln.device.application.service.PoseAnalysisService;
|
|
|
+import com.hfln.device.domain.constant.DeviceConstants;
|
|
|
+import com.hfln.device.domain.entity.Device;
|
|
|
+import com.hfln.device.domain.gateway.MqttGateway;
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.beans.factory.annotation.Value;
|
|
|
+import org.springframework.stereotype.Service;
|
|
|
+import org.springframework.web.client.RestTemplate;
|
|
|
+
|
|
|
+import javax.annotation.PostConstruct;
|
|
|
+import javax.annotation.PreDestroy;
|
|
|
+import java.util.*;
|
|
|
+import java.util.concurrent.atomic.AtomicBoolean;
|
|
|
+
|
|
|
+/**
|
|
|
+ * 姿态分析服务实现
|
|
|
+ * 对应Python版本的post_process.py线程功能
|
|
|
+ */
|
|
|
+@Service
|
|
|
+@Slf4j
|
|
|
+public class PoseAnalysisServiceImpl implements PoseAnalysisService {
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ DeviceManagerService deviceManagerService;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ MqttGateway mqttGateway;
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ RestTemplate restTemplate;
|
|
|
+
|
|
|
+ @Value("${pose.analysis.server.url:http://43.137.10.199:5000/predict}")
|
|
|
+ String poseAnalysisServerUrl;
|
|
|
+
|
|
|
+ @Value("${pose.analysis.model.type:LIBO}")
|
|
|
+ String modelType;
|
|
|
+
|
|
|
+ @Value("${pose.analysis.pose.class:POSE_CLASS_4}")
|
|
|
+ String poseClass;
|
|
|
+
|
|
|
+ Thread poseAnalysisThread;
|
|
|
+ final AtomicBoolean running = new AtomicBoolean(false);
|
|
|
+
|
|
|
+ // 统计信息
|
|
|
+ int count = 0;
|
|
|
+ double totalInterval = 0;
|
|
|
+ double averageInterval = 0;
|
|
|
+
|
|
|
+ @PostConstruct
|
|
|
+ public void init() {
|
|
|
+ startPoseAnalysisThread();
|
|
|
+ }
|
|
|
+
|
|
|
+ @PreDestroy
|
|
|
+ public void destroy() {
|
|
|
+ stopPoseAnalysisThread();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void startPoseAnalysisThread() {
|
|
|
+ if (running.get()) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+
|
|
|
+ running.set(true);
|
|
|
+ poseAnalysisThread = new Thread(this::poseAnalysisProcess, "PoseAnalysisThread");
|
|
|
+ poseAnalysisThread.setDaemon(true);
|
|
|
+ poseAnalysisThread.start();
|
|
|
+
|
|
|
+ log.info("姿态分析后台线程已启动 - 对应Python版本的post_process_ex线程");
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void stopPoseAnalysisThread() {
|
|
|
+ running.set(false);
|
|
|
+ if (poseAnalysisThread != null) {
|
|
|
+ poseAnalysisThread.interrupt();
|
|
|
+ }
|
|
|
+ log.info("姿态分析后台线程已停止");
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 姿态分析处理主循环
|
|
|
+ * 对应Python版本的post_process_ex()函数
|
|
|
+ */
|
|
|
+ private void poseAnalysisProcess() {
|
|
|
+ log.info("开始姿态分析处理循环");
|
|
|
+
|
|
|
+ while (running.get() && !Thread.currentThread().isInterrupted()) {
|
|
|
+ try {
|
|
|
+ Thread.sleep(100); // 对应Python版本的time.sleep(0.1)
|
|
|
+
|
|
|
+ // 分别处理每个设备积累的点云数据
|
|
|
+ Device device = null;
|
|
|
+ List<List<Float>> cloudPoints = null;
|
|
|
+ String deviceId = null;
|
|
|
+
|
|
|
+ // 对应Python: for dev_id, device in g_dev_map.items():
|
|
|
+ Collection<Device> deviceCollection = deviceManagerService.getAllDevicesFromCache();
|
|
|
+ for (Device dev : deviceCollection) {
|
|
|
+ cloudPoints = dev.getMaxLenCloudPoints();
|
|
|
+ // 对应Python: if (cloud_points == None or len(cloud_points) <= 20): continue
|
|
|
+ if (cloudPoints == null || cloudPoints.size() <= 20) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+ device = dev;
|
|
|
+ deviceId = dev.getDevId();
|
|
|
+ break; // 对应Python的break
|
|
|
+ }
|
|
|
+
|
|
|
+ if (device == null || cloudPoints == null) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ int rawPointsLen = cloudPoints.size();
|
|
|
+
|
|
|
+ try {
|
|
|
+ // 处理点云数据
|
|
|
+ Map<String, Object> postData = dealPostData(cloudPoints);
|
|
|
+
|
|
|
+ // 向姿态算法服务发起请求
|
|
|
+ long reqTimestamp = System.currentTimeMillis();
|
|
|
+ Map<String, Object> response = restTemplate.postForObject(poseAnalysisServerUrl, postData, Map.class);
|
|
|
+ double interval = (System.currentTimeMillis() - reqTimestamp) / 1000.0;
|
|
|
+
|
|
|
+ if (response == null) {
|
|
|
+ log.debug("post error: invalid response, response is null, [{}]", deviceId);
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 检查姿态结果
|
|
|
+ int pose = checkPoseFromResponse(response);
|
|
|
+
|
|
|
+ // 计算耗时统计
|
|
|
+ count++;
|
|
|
+ totalInterval += interval;
|
|
|
+ averageInterval = totalInterval / count;
|
|
|
+
|
|
|
+ // 更新目标实时姿态
|
|
|
+ device.updatePose(pose);
|
|
|
+ List<Integer> realtimePose = device.getRealtimePose();
|
|
|
+
|
|
|
+ String poseText = getPoseText(realtimePose.get(0));
|
|
|
+ String currentPoseText = getPoseText(pose);
|
|
|
+ String logText = String.format("姿态:%s (%s) \t[%d]\t[%.3f %.3f]\t[%s]",
|
|
|
+ poseText, currentPoseText, rawPointsLen, interval, averageInterval, deviceId);
|
|
|
+ log.info(logText);
|
|
|
+
|
|
|
+ // 若姿态为躺,上报跌倒事件
|
|
|
+ long now = System.currentTimeMillis();
|
|
|
+ if (device.getAlarmAck() ||
|
|
|
+ (device.getLastReportFallTime() != null &&
|
|
|
+ now - device.getLastReportFallTime() < device.getAlarmInterval())) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (realtimePose.isEmpty() || realtimePose.get(0) != 0) { // POSE_0 = 0 (躺)
|
|
|
+ continue;
|
|
|
+ }
|
|
|
+
|
|
|
+ // 获取目标追踪数据并发送跌倒事件
|
|
|
+ List<List<Float>> targets = getTrackerTargets(cloudPoints);
|
|
|
+ String event = "fall_detected";
|
|
|
+ mqttGateway.sendEventMessage(deviceId, cloudPoints, realtimePose.get(0), targets, event);
|
|
|
+ device.setLastFallTime("fall_detected", now);
|
|
|
+
|
|
|
+ log.debug("算法检测跌倒事件: 摔, dev_id:{}", deviceId);
|
|
|
+
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("处理设备{}的姿态分析时发生错误", deviceId, e);
|
|
|
+ }
|
|
|
+
|
|
|
+ } catch (InterruptedException e) {
|
|
|
+ Thread.currentThread().interrupt();
|
|
|
+ break;
|
|
|
+ } catch (Exception e) {
|
|
|
+ log.error("姿态分析处理循环发生错误", e);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ log.info("姿态分析处理循环结束");
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Map<String, Object> dealPostData(List<List<Float>> rawPoints) {
|
|
|
+ // 对应Python: RawPoints = [sublist[0:3] for sublist in RawPoints]
|
|
|
+ List<List<Float>> processedPoints = new ArrayList<>();
|
|
|
+ for (List<Float> point : rawPoints) {
|
|
|
+ if (point.size() >= 3) {
|
|
|
+ processedPoints.add(Arrays.asList(point.get(0), point.get(1), point.get(2)));
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ Map<String, Object> pointCloudData = new HashMap<>();
|
|
|
+
|
|
|
+ if ("LIBO".equals(modelType)) {
|
|
|
+ // 李博模型
|
|
|
+ pointCloudData.put("point_cloud", processedPoints);
|
|
|
+ } else if ("ANDA".equals(modelType)) {
|
|
|
+ // 安大模型
|
|
|
+ pointCloudData.put("ID", "JSON_DATA");
|
|
|
+ Map<String, Object> payload = new HashMap<>();
|
|
|
+ payload.put("raw_points", processedPoints);
|
|
|
+ pointCloudData.put("Payload", payload);
|
|
|
+ pointCloudData.put("Type", "POINT");
|
|
|
+ }
|
|
|
+
|
|
|
+ return pointCloudData;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int checkPose(int predictedClass) {
|
|
|
+ if ("LIBO".equals(modelType)) {
|
|
|
+ if ("POSE_CLASS_3".equals(poseClass)) {
|
|
|
+ if (predictedClass == 2) {
|
|
|
+ return 4; // POSE_4
|
|
|
+ } else {
|
|
|
+ return predictedClass;
|
|
|
+ }
|
|
|
+ } else if ("POSE_CLASS_4".equals(poseClass) || "POSE_CLASS_5".equals(poseClass)) {
|
|
|
+ return predictedClass;
|
|
|
+ }
|
|
|
+ } else if ("ANDA".equals(modelType)) {
|
|
|
+ return predictedClass;
|
|
|
+ }
|
|
|
+
|
|
|
+ return predictedClass;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int checkPoseFromResponse(Map<String, Object> response) {
|
|
|
+ Object predictedClassObj = response.get("predicted_class");
|
|
|
+ if (predictedClassObj instanceof Number) {
|
|
|
+ int predictedClass = ((Number) predictedClassObj).intValue();
|
|
|
+ return checkPose(predictedClass);
|
|
|
+ }
|
|
|
+ return DeviceConstants.PoseEnum.POSE_INVALID.getCode();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取目标追踪数据
|
|
|
+ * 对应Python版本的get_tracker_targets方法
|
|
|
+ */
|
|
|
+ private List<List<Float>> getTrackerTargets(List<List<Float>> cloudPoints) {
|
|
|
+ // 简化实现,实际应该调用点云处理服务
|
|
|
+ return new ArrayList<>();
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * 获取姿态文本描述
|
|
|
+ */
|
|
|
+ private String getPoseText(int pose) {
|
|
|
+ switch (pose) {
|
|
|
+ case 0: return "躺";
|
|
|
+ case 1: return "坐";
|
|
|
+ case 2: return "蹲";
|
|
|
+ case 3: return "弯腰";
|
|
|
+ case 4: return "站";
|
|
|
+ default: return "无效";
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|