|
|
@@ -0,0 +1,153 @@
|
|
|
+package cn.hfln.framework.mqtt.handler;
|
|
|
+
|
|
|
+import cn.hfln.framework.mqtt.annotation.MqttSubscriber;
|
|
|
+import org.eclipse.paho.client.mqttv3.MqttClient;
|
|
|
+import org.junit.jupiter.api.BeforeEach;
|
|
|
+import org.junit.jupiter.api.Test;
|
|
|
+import org.mockito.Mock;
|
|
|
+import org.mockito.MockitoAnnotations;
|
|
|
+import org.springframework.context.ApplicationContext;
|
|
|
+import org.springframework.integration.mqtt.inbound.MqttPahoMessageDrivenChannelAdapter;
|
|
|
+import org.springframework.messaging.Message;
|
|
|
+import org.springframework.messaging.MessageHeaders;
|
|
|
+
|
|
|
+import java.lang.reflect.Field;
|
|
|
+import java.lang.reflect.Method;
|
|
|
+import java.util.HashMap;
|
|
|
+import java.util.Map;
|
|
|
+
|
|
|
+import static org.junit.jupiter.api.Assertions.*;
|
|
|
+import static org.mockito.Mockito.*;
|
|
|
+
|
|
|
+class MqttSubscriberProcessorTest {
|
|
|
+ @Mock
|
|
|
+ private MqttPahoMessageDrivenChannelAdapter mqttInbound;
|
|
|
+ @Mock
|
|
|
+ private MqttClient mqttClient;
|
|
|
+ @Mock
|
|
|
+ private ApplicationContext applicationContext;
|
|
|
+
|
|
|
+ private MqttSubscriberProcessor processor;
|
|
|
+
|
|
|
+ @BeforeEach
|
|
|
+ void setUp() throws Exception {
|
|
|
+ MockitoAnnotations.openMocks(this);
|
|
|
+ processor = new MqttSubscriberProcessor(mqttInbound, mqttClient);
|
|
|
+ Field ctxField = MqttSubscriberProcessor.class.getDeclaredField("applicationContext");
|
|
|
+ ctxField.setAccessible(true);
|
|
|
+ ctxField.set(processor, applicationContext);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void testAfterSingletonsInstantiated_registersSubscribers() {
|
|
|
+ when(applicationContext.getBeanDefinitionNames()).thenReturn(new String[]{"testBean"});
|
|
|
+ when(applicationContext.getBean("testBean")).thenReturn(new TestSubscriber());
|
|
|
+ processor.afterSingletonsInstantiated();
|
|
|
+ // 验证topic注册
|
|
|
+ verify(mqttInbound).addTopic("/test/topic", 1);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void testProcessMessage_exactMatch() throws Exception {
|
|
|
+ // 注册订阅方法
|
|
|
+ TestSubscriber bean = new TestSubscriber();
|
|
|
+ Method method = TestSubscriber.class.getMethod("handle", String.class);
|
|
|
+ MqttSubscriber annotation = method.getAnnotation(MqttSubscriber.class);
|
|
|
+ Object subscriberMethod = createSubscriberMethod(bean, method, annotation);
|
|
|
+ putTopicSubscriber("/test/topic", subscriberMethod);
|
|
|
+ // 构造消息
|
|
|
+ Message<String> message = mock(Message.class);
|
|
|
+ Map<String, Object> headers1 = new HashMap<>();
|
|
|
+ headers1.put("mqtt_receivedTopic", "/test/topic");
|
|
|
+ when(message.getHeaders()).thenReturn(new MessageHeaders(headers1));
|
|
|
+ when(message.getPayload()).thenReturn("payload");
|
|
|
+ // 调用
|
|
|
+ invokeProcessMessage(message);
|
|
|
+ assertEquals("payload", bean.lastMsg);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void testProcessMessage_wildcardMatch() throws Exception {
|
|
|
+ TestSubscriber bean = new TestSubscriber();
|
|
|
+ Method method = TestSubscriber.class.getMethod("handle", String.class);
|
|
|
+ MqttSubscriber annotation = method.getAnnotation(MqttSubscriber.class);
|
|
|
+ Object subscriberMethod = createSubscriberMethod(bean, method, annotation);
|
|
|
+ putTopicSubscriber("/test/+", subscriberMethod);
|
|
|
+ Message<String> message = mock(Message.class);
|
|
|
+ Map<String, Object> headers2 = new HashMap<>();
|
|
|
+ headers2.put("mqtt_receivedTopic", "/test/abc");
|
|
|
+ when(message.getHeaders()).thenReturn(new MessageHeaders(headers2));
|
|
|
+ when(message.getPayload()).thenReturn("wild");
|
|
|
+ invokeProcessMessage(message);
|
|
|
+ assertEquals("wild", bean.lastMsg);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void testProcessMessage_noSubscriber() {
|
|
|
+ Message<String> message = mock(Message.class);
|
|
|
+ Map<String, Object> headers3 = new HashMap<>();
|
|
|
+ headers3.put("mqtt_receivedTopic", "/not/exist");
|
|
|
+ when(message.getHeaders()).thenReturn(new MessageHeaders(headers3));
|
|
|
+ when(message.getPayload()).thenReturn("none");
|
|
|
+ // 不抛异常
|
|
|
+ invokeProcessMessage(message);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void testProcessMessage_nullTopic() {
|
|
|
+ Message<String> message = mock(Message.class);
|
|
|
+ when(message.getHeaders()).thenReturn(new MessageHeaders(new HashMap<>()));
|
|
|
+ when(message.getPayload()).thenReturn("none");
|
|
|
+ // 不抛异常
|
|
|
+ invokeProcessMessage(message);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ void testInvokeSubscriberMethod_invalidParamCount() throws Exception {
|
|
|
+ TestInvalidSubscriber bean = new TestInvalidSubscriber();
|
|
|
+ Method method = TestInvalidSubscriber.class.getMethod("invalid", String.class, String.class, String.class);
|
|
|
+ MqttSubscriber annotation = method.getAnnotation(MqttSubscriber.class);
|
|
|
+ Method regMethod = MqttSubscriberProcessor.class.getDeclaredMethod("registerSubscriberMethod", Object.class, Method.class, MqttSubscriber.class);
|
|
|
+ regMethod.setAccessible(true);
|
|
|
+ regMethod.invoke(processor, bean, method, annotation);
|
|
|
+ // 不会注册
|
|
|
+ Field mapField = MqttSubscriberProcessor.class.getDeclaredField("topicSubscriberMap");
|
|
|
+ mapField.setAccessible(true);
|
|
|
+ Map<String, ?> map = (Map<String, ?>) mapField.get(processor);
|
|
|
+ assertFalse(map.containsKey("/invalid/topic"));
|
|
|
+ }
|
|
|
+
|
|
|
+ // 工具方法:反射注入topicSubscriberMap
|
|
|
+ private void putTopicSubscriber(String topic, Object subscriberMethod) throws Exception {
|
|
|
+ Field mapField = MqttSubscriberProcessor.class.getDeclaredField("topicSubscriberMap");
|
|
|
+ mapField.setAccessible(true);
|
|
|
+ Map<String, Object> map = (Map<String, Object>) mapField.get(processor);
|
|
|
+ map.put(topic, subscriberMethod);
|
|
|
+ }
|
|
|
+ // 工具方法:反射创建SubscriberMethod
|
|
|
+ private Object createSubscriberMethod(Object bean, Method method, MqttSubscriber annotation) throws Exception {
|
|
|
+ Class<?> smClass = Class.forName("cn.hfln.framework.mqtt.handler.MqttSubscriberProcessor$SubscriberMethod");
|
|
|
+ return smClass.getConstructor(Object.class, Method.class, MqttSubscriber.class)
|
|
|
+ .newInstance(bean, method, annotation);
|
|
|
+ }
|
|
|
+ // 工具方法:反射调用processMessage
|
|
|
+ private void invokeProcessMessage(Message<?> message) {
|
|
|
+ try {
|
|
|
+ Method m = MqttSubscriberProcessor.class.getDeclaredMethod("processMessage", Message.class);
|
|
|
+ m.setAccessible(true);
|
|
|
+ m.invoke(processor, message);
|
|
|
+ } catch (Exception e) {
|
|
|
+ fail("Exception in processMessage: " + e.getMessage());
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ static class TestSubscriber {
|
|
|
+ String lastMsg;
|
|
|
+ @MqttSubscriber(topic = "/test/topic", qos = 1)
|
|
|
+ public void handle(String msg) { lastMsg = msg; }
|
|
|
+ }
|
|
|
+ static class TestInvalidSubscriber {
|
|
|
+ @MqttSubscriber(topic = "/invalid/topic", qos = 1)
|
|
|
+ public void invalid(String a, String b, String c) {}
|
|
|
+ }
|
|
|
+}
|