|
@@ -0,0 +1,159 @@
|
|
|
|
+package com.hfln.device.application.service.impl;
|
|
|
|
+
|
|
|
|
+import com.hfln.device.application.service.PoseAnalysisService;
|
|
|
|
+import com.hfln.device.domain.entity.Device;
|
|
|
|
+import com.hfln.device.domain.gateway.MqttGateway;
|
|
|
|
+import com.hfln.device.domain.service.DeviceManagerService;
|
|
|
|
+import org.junit.jupiter.api.AfterEach;
|
|
|
|
+import org.junit.jupiter.api.BeforeEach;
|
|
|
|
+import org.junit.jupiter.api.Test;
|
|
|
|
+import org.mockito.ArgumentCaptor;
|
|
|
|
+import org.mockito.InjectMocks;
|
|
|
|
+import org.mockito.Mock;
|
|
|
|
+import org.mockito.MockitoAnnotations;
|
|
|
|
+import org.springframework.web.client.RestTemplate;
|
|
|
|
+
|
|
|
|
+import java.util.*;
|
|
|
|
+
|
|
|
|
+import static org.junit.jupiter.api.Assertions.*;
|
|
|
|
+import static org.mockito.Mockito.*;
|
|
|
|
+
|
|
|
|
+class PoseAnalysisServiceImplTest {
|
|
|
|
+
|
|
|
|
+ @Mock
|
|
|
|
+ private DeviceManagerService deviceManagerService;
|
|
|
|
+ @Mock
|
|
|
|
+ private MqttGateway mqttGateway;
|
|
|
|
+ @Mock
|
|
|
|
+ private RestTemplate restTemplate;
|
|
|
|
+
|
|
|
|
+ @InjectMocks
|
|
|
|
+ private PoseAnalysisServiceImpl poseAnalysisService;
|
|
|
|
+
|
|
|
|
+ private AutoCloseable closeable;
|
|
|
|
+
|
|
|
|
+ @BeforeEach
|
|
|
|
+ void setUp() {
|
|
|
|
+ closeable = MockitoAnnotations.openMocks(this);
|
|
|
|
+ poseAnalysisService = new PoseAnalysisServiceImpl();
|
|
|
|
+ poseAnalysisService.deviceManagerService = deviceManagerService;
|
|
|
|
+ poseAnalysisService.mqttGateway = mqttGateway;
|
|
|
|
+ poseAnalysisService.restTemplate = restTemplate;
|
|
|
|
+ // 设置默认参数
|
|
|
|
+ poseAnalysisService.modelType = "LIBO";
|
|
|
|
+ poseAnalysisService.poseClass = "POSE_CLASS_4";
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @AfterEach
|
|
|
|
+ void tearDown() throws Exception {
|
|
|
|
+ closeable.close();
|
|
|
|
+ poseAnalysisService.stopPoseAnalysisThread();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testDealPostData_LiboModel() {
|
|
|
|
+ List<List<Float>> rawPoints = Arrays.asList(
|
|
|
|
+ Arrays.asList(1.1f, 2.2f, 3.3f, 4.4f),
|
|
|
|
+ Arrays.asList(5.5f, 6.6f, 7.7f, 8.8f)
|
|
|
|
+ );
|
|
|
|
+ Map<String, Object> result = poseAnalysisService.dealPostData(rawPoints);
|
|
|
|
+ assertTrue(result.containsKey("point_cloud"));
|
|
|
|
+ List<List<Float>> pc = (List<List<Float>>) result.get("point_cloud");
|
|
|
|
+ assertEquals(2, pc.size());
|
|
|
|
+ assertEquals(Arrays.asList(1.1f, 2.2f, 3.3f), pc.get(0));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testDealPostData_AndaModel() {
|
|
|
|
+ poseAnalysisService.modelType = "ANDA";
|
|
|
|
+ List<List<Float>> rawPoints = Arrays.asList(
|
|
|
|
+ Arrays.asList(1.1f, 2.2f, 3.3f, 4.4f)
|
|
|
|
+ );
|
|
|
|
+ Map<String, Object> result = poseAnalysisService.dealPostData(rawPoints);
|
|
|
|
+ assertTrue(result.containsKey("Payload"));
|
|
|
|
+ Map<String, Object> payload = (Map<String, Object>) result.get("Payload");
|
|
|
|
+ assertTrue(payload.containsKey("raw_points"));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testCheckPose_LiboClass4() {
|
|
|
|
+ poseAnalysisService.modelType = "LIBO";
|
|
|
|
+ poseAnalysisService.poseClass = "POSE_CLASS_4";
|
|
|
|
+ assertEquals(2, poseAnalysisService.checkPose(2));
|
|
|
|
+ assertEquals(4, poseAnalysisService.checkPose(4));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testCheckPose_LiboClass3() {
|
|
|
|
+ poseAnalysisService.modelType = "LIBO";
|
|
|
|
+ poseAnalysisService.poseClass = "POSE_CLASS_3";
|
|
|
|
+ assertEquals(4, poseAnalysisService.checkPose(2)); // 2->4
|
|
|
|
+ assertEquals(1, poseAnalysisService.checkPose(1));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testCheckPose_Anda() {
|
|
|
|
+ poseAnalysisService.modelType = "ANDA";
|
|
|
|
+ assertEquals(3, poseAnalysisService.checkPose(3));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testCheckPoseFromResponse() {
|
|
|
|
+ Map<String, Object> resp = new HashMap<>();
|
|
|
|
+ resp.put("predicted_class", 2);
|
|
|
|
+ poseAnalysisService.modelType = "LIBO";
|
|
|
|
+ poseAnalysisService.poseClass = "POSE_CLASS_4";
|
|
|
|
+ assertEquals(2, poseAnalysisService.checkPoseFromResponse(resp));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testStartAndStopPoseAnalysisThread() throws InterruptedException {
|
|
|
|
+ poseAnalysisService.startPoseAnalysisThread();
|
|
|
|
+ assertTrue(poseAnalysisService.running.get());
|
|
|
|
+ poseAnalysisService.stopPoseAnalysisThread();
|
|
|
|
+ Thread.sleep(100); // 等待线程停止
|
|
|
|
+ assertFalse(poseAnalysisService.running.get());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ void testPoseAnalysisProcess_MockDeviceAndRest() throws InterruptedException {
|
|
|
|
+ // 构造一个带点云的设备
|
|
|
|
+ Device device = mock(Device.class);
|
|
|
|
+ when(device.getMaxLenCloudPoints()).thenReturn(Arrays.asList(
|
|
|
|
+ Arrays.asList(1.0f, 2.0f, 3.0f),
|
|
|
|
+ Arrays.asList(4.0f, 5.0f, 6.0f)
|
|
|
|
+ ));
|
|
|
|
+ when(device.getDevId()).thenReturn("dev001");
|
|
|
|
+ when(device.getRealtimePose()).thenReturn(Arrays.asList(0));
|
|
|
|
+ when(device.getAlarmAck()).thenReturn(false);
|
|
|
|
+ when(device.getLastReportFallTime()).thenReturn(null);
|
|
|
|
+ when(device.getAlarmInterval()).thenReturn(0L);
|
|
|
|
+ doNothing().when(device).updatePose(anyInt());
|
|
|
|
+ doNothing().when(device).setLastFallTime(anyString(), anyLong());
|
|
|
|
+
|
|
|
|
+ Collection<Device> devices = Collections.singletonList(device);
|
|
|
|
+ when(deviceManagerService.getAllDevicesFromCache()).thenReturn(devices);
|
|
|
|
+
|
|
|
|
+ // mock restTemplate返回AI响应
|
|
|
|
+ Map<String, Object> aiResp = new HashMap<>();
|
|
|
|
+ aiResp.put("predicted_class", 0);
|
|
|
|
+ when(restTemplate.postForObject(anyString(), any(), eq(Map.class))).thenReturn(aiResp);
|
|
|
|
+
|
|
|
|
+ // mock mqttGateway
|
|
|
|
+ doNothing().when(mqttGateway).sendEventMessage(anyString(), any(), anyInt(), any(), anyString());
|
|
|
|
+
|
|
|
|
+ // 启动线程,运行一次循环
|
|
|
|
+ poseAnalysisService.running.set(true);
|
|
|
|
+ Thread t = new Thread(() -> {
|
|
|
|
+ poseAnalysisService.poseAnalysisProcess();
|
|
|
|
+ });
|
|
|
|
+ t.start();
|
|
|
|
+ Thread.sleep(300); // 让线程跑一会
|
|
|
|
+ poseAnalysisService.running.set(false);
|
|
|
|
+ t.interrupt();
|
|
|
|
+ t.join(500);
|
|
|
|
+ // 验证AI服务和消息发送被调用
|
|
|
|
+ verify(restTemplate, atLeastOnce()).postForObject(anyString(), any(), eq(Map.class));
|
|
|
|
+ verify(mqttGateway, atLeastOnce()).sendEventMessage(anyString(), any(), anyInt(), any(), anyString());
|
|
|
|
+ }
|
|
|
|
+}
|