// net.dreamlu.iot.mqtt.spring.server.MqttServerFunctionDetector Maven / Gradle / Ivy
// The newest version!
/*
* Copyright (c) 2019-2029, Dreamlu 卢春梦 ([email protected] & dreamlu.net).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.dreamlu.iot.mqtt.spring.server;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.dreamlu.iot.mqtt.codec.MqttPublishMessage;
import net.dreamlu.iot.mqtt.core.server.func.MqttFunctionManager;
import net.dreamlu.iot.mqtt.core.util.TopicUtil;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.lang.NonNull;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.tio.core.ChannelContext;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map;
/**
* Mqtt 服务端消息处理
*
* @author L.cm
*/
@Slf4j
@RequiredArgsConstructor
public class MqttServerFunctionDetector implements BeanPostProcessor {
private final ApplicationContext applicationContext;
private final MqttFunctionManager functionManager;
@Override
public Object postProcessAfterInitialization(@NonNull Object bean, String beanName) throws BeansException {
Class> userClass = ClassUtils.getUserClass(bean);
// 1. 查找类上的 MqttServerFunction 注解
if (bean instanceof MqttServerFunctionMessageListener) {
MqttServerFunction subscribe = AnnotationUtils.findAnnotation(userClass, MqttServerFunction.class);
if (subscribe != null) {
String[] topicFilters = getTopicFilters(applicationContext, subscribe.value());
functionManager.register(topicFilters, (MqttServerFunctionMessageListener) bean);
}
} else {
// 2. 查找方法上的 MqttServerFunction 注解
ReflectionUtils.doWithMethods(userClass, method -> {
MqttServerFunction subscribe = AnnotationUtils.findAnnotation(method, MqttServerFunction.class);
if (subscribe != null) {
// 1. 校验必须为 public 和非 static 的方法
int modifiers = method.getModifiers();
if (Modifier.isStatic(modifiers)) {
throw new IllegalArgumentException("@MqttServerFunction on method " + method + " must not static.");
}
if (!Modifier.isPublic(modifiers)) {
throw new IllegalArgumentException("@MqttServerFunction on method " + method + " must public.");
}
// 2. 校验 method 入参数必须等于2
int paramCount = method.getParameterCount();
if (paramCount < 2 || paramCount > 6) {
throw new IllegalArgumentException("@MqttServerFunction on method " + method + " parameter count must 2 ~ 6.");
}
// 3. 校验 method 入参类型必须为 String、Map、MqttPublishMessage、ByteBuffer
Class>[] parameterTypes = method.getParameterTypes();
checkParameterTypes(method, parameterTypes);
String[] topicTemplates = subscribe.value();
String[] topicFilters = getTopicFilters(applicationContext, topicTemplates);
// 4. 监听器
MqttServerFunctionMessageListener messageListener = new MqttServerFunctionMessageListener(
bean, method, topicTemplates, topicFilters);
// 5. 注册监听器
functionManager.register(topicFilters, messageListener);
}
}, ReflectionUtils.USER_DECLARED_METHODS);
}
return bean;
}
/**
* 解析 Spring boot env 变量
*
* @param applicationContext ApplicationContext
* @param values values
* @return topic array
*/
private static String[] getTopicFilters(ApplicationContext applicationContext, String[] values) {
// 1. 替换 Spring boot env 变量
// 2. 替换订阅中的其他变量
return Arrays.stream(values)
.map(applicationContext.getEnvironment()::resolvePlaceholders)
.map(TopicUtil::getTopicFilter)
.toArray(String[]::new);
}
/**
* 校验方法参数
*
* @param method Method
* @param parameterTypes parameterTypes
*/
private static void checkParameterTypes(Method method, Class>[] parameterTypes) {
for (Class> parameterType : parameterTypes) {
if (String.class != parameterType &&
ChannelContext.class != parameterType &&
Map.class != parameterType &&
MqttPublishMessage.class != parameterType &&
byte[].class != parameterType &&
ByteBuffer.class != parameterType) {
throw new IllegalArgumentException("@MqttServerFunction on method " + method + " parameter type must String topic, Map topicVars, MqttPublishMessage message, byte[] payload or ByteBuffer payload.");
}
}
}
}