自定义注解 + Redis 实现对接口 POST 请求的限流
gitee源码网址(点击直接跳转)
https://gitee.com/young-mou/request-limit-redis
/**
 * Demo controller exposing a single POST endpoint that is guarded by {@link RequestLimit}.
 */
@RestController
@RequestMapping("/user")
public class UserController {

    /**
     * Returns a fixed sample user.
     *
     * @param user request payload; it is not used to build the response
     * @return a user with age 23 and name "张三"
     */
    @PostMapping("/getUser")
    @RequestLimit
    public User getUser(@RequestBody User user) {
        User sample = User.builder()
                .age(23)
                .name("张三")
                .build();
        return sample;
    }
}
重写RequestBodyAdviceAdapter
/**
 * Request-body advice that rejects repeated calls to handler methods annotated with
 * {@link RequestLimit} while a Redis marker key for the same request target is still alive.
 *
 * <p>The marker key is {@code requestLimit:<requestURI>[?<queryString>]}. It is shared by all
 * clients of an endpoint; extend {@link #buildKey(HttpServletRequest)} if a per-client or
 * per-payload key is required.
 */
@ControllerAdvice
@Order(999)
public class RequestLimitAdvice extends RequestBodyAdviceAdapter {

    private static final String REDIS_KEY_PREFIX = "requestLimit:";

    @Resource
    private Redisson redisson;

    @Resource
    private RequestLimitConfig requestLimitConfig;

    /**
     * Applies this advice only when limiting is switched on and the handler method carries
     * {@link RequestLimit}.
     *
     * @return {@code true} if the request body of this handler should be rate limited
     */
    @Override
    public boolean supports(MethodParameter methodParameter, Type targetType,
                            Class<? extends HttpMessageConverter<?>> converterType) {
        // hasMethodAnnotation is null-safe, unlike dereferencing getMethod() directly.
        return Boolean.parseBoolean(requestLimitConfig.getEnable())
                && methodParameter.hasMethodAnnotation(RequestLimit.class);
    }

    /**
     * Claims the Redis marker for the current request target, or rejects the request when the
     * marker already exists.
     *
     * @return whatever the parent adapter returns for {@code body}
     * @throws RequestLimitException if the marker for this request target has not expired yet
     * @throws RuntimeException      if no servlet request is bound to the current thread
     */
    @Override
    public Object afterBodyRead(Object body, HttpInputMessage inputMessage, MethodParameter parameter,
                                Type targetType, Class<? extends HttpMessageConverter<?>> converterType) {
        long windowMillis = resolveLimitMillis(parameter);
        // A non-positive window means "no limiting"; Redis also rejects a zero/negative TTL on SET.
        if (windowMillis > 0) {
            String key = buildKey(getCurrentRequest());
            // Set-if-absent and TTL go to Redis as ONE operation, so a crash or a failed
            // follow-up EXPIRE can never leave behind a marker without expiry.
            boolean firstInWindow = redisson.<String>getBucket(key)
                    .trySet("1", windowMillis, java.util.concurrent.TimeUnit.MILLISECONDS);
            if (!firstInWindow) {
                throw new RequestLimitException("重复请求");
            }
        }
        return super.afterBodyRead(body, inputMessage, parameter, targetType, converterType);
    }

    /**
     * Builds the marker key from the request URI plus the query string when one is present.
     * The query string alone is {@code null} for a typical POST, which would make every
     * endpoint share a single key.
     */
    private String buildKey(HttpServletRequest request) {
        StringBuilder key = new StringBuilder(REDIS_KEY_PREFIX).append(request.getRequestURI());
        String query = request.getQueryString();
        if (query != null) {
            key.append('?').append(query);
        }
        return key.toString();
    }

    /**
     * Resolves the limit window in milliseconds: the annotation value when present, otherwise
     * the configured default, otherwise {@code 0} (no limiting).
     */
    private long resolveLimitMillis(MethodParameter parameter) {
        RequestLimit limit = parameter.getMethodAnnotation(RequestLimit.class);
        if (limit != null) {
            return limit.limitTime();
        }
        Long fallback = requestLimitConfig.getDefaultLimitTime();
        return fallback != null ? fallback : 0L;
    }

    /**
     * Looks up the servlet request bound to the current thread.
     *
     * @throws RuntimeException if no request attributes are bound
     */
    private HttpServletRequest getCurrentRequest() {
        RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();
        if (requestAttributes == null) {
            throw new RuntimeException(ResponseEnum.EXCEPTION.getMsg());
        }
        return ((ServletRequestAttributes) requestAttributes).getRequest();
    }
}
自定义注解
/**
 * Marks a handler method whose requests should be rate limited.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RequestLimit {

    /**
     * Length of the limit window; the advice in this project interprets it as milliseconds.
     *
     * @return the window length, 1000 by default
     */
    long limitTime() default 1000L;
}
限流异常拦截
/**
 * Translates {@link RequestLimitException} into a failure response. The response-body hook
 * itself is a pass-through.
 */
@RestControllerAdvice
@Order(998)
public class ExceptionHanderAdvice implements ResponseBodyAdvice<Object> {

    /** Applies to every return type and converter. */
    @Override
    public boolean supports(MethodParameter returnType,
                            Class<? extends HttpMessageConverter<?>> converterType) {
        return true;
    }

    /** Hands the body back unchanged. */
    @Override
    public Object beforeBodyWrite(Object body, MethodParameter returnType, MediaType selectedContentType,
                                  Class<? extends HttpMessageConverter<?>> selectedConverterType,
                                  ServerHttpRequest request, ServerHttpResponse response) {
        return body;
    }

    /**
     * Builds the failure response for a rejected request.
     *
     * @param e the rate-limit exception
     * @return a failure entity carrying the exception message, or the generic message when the
     *         exception has none
     */
    @ExceptionHandler(RequestLimitException.class)
    public ServerResponseEntity<String> handleRequestLimitException(RequestLimitException e) {
        String detail = e.getMessage();
        String message = detail != null ? detail : ResponseEnum.EXCEPTION.getMsg();
        return ServerResponseEntity.fail(ResponseEnum.EXCEPTION.value(), message);
    }
}
redis序列化配置(防止redis存储乱码)
/**
 * Registers a {@link RedisTemplate} with String keys and JSON-serialized values.
 */
@Configuration
public class RedisConfig {

    /**
     * Creates the template on top of the supplied connection factory.
     *
     * @param redisConnectionFactory connection factory the template should use
     * @return a template using String serialization for keys / hash keys and JSON serialization
     *         for values / hash values
     */
    @Bean
    public RedisTemplate<String, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
        // Serializers: plain strings for keys, generic Jackson JSON for values.
        RedisSerializer<String> keySerializer = RedisSerializer.string();
        GenericJackson2JsonRedisSerializer jsonSerializer = new GenericJackson2JsonRedisSerializer();

        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(redisConnectionFactory);

        // Keys and hash keys
        redisTemplate.setKeySerializer(keySerializer);
        redisTemplate.setHashKeySerializer(keySerializer);

        // Values and hash values
        redisTemplate.setValueSerializer(jsonSerializer);
        redisTemplate.setHashValueSerializer(jsonSerializer);
        return redisTemplate;
    }
}
限流开关及相关配置
/**
 * Settings for the request limiter, bound from properties under the {@code request.limit}
 * prefix.
 *
 * <p>Setters are declared explicitly because JavaBean property binding needs a setter to
 * replace a scalar default; with getters only, the values below could not be overridden from
 * configuration.
 */
@Configuration
@ConfigurationProperties(prefix = "request.limit")
@Getter
public class RequestLimitConfig implements Serializable {
    /**
     * On/off switch, kept as text; the default is {@code "true"}.
     */
    private String enable = "true";
    /**
     * Default limit window; the default is 1000.
     */
    private Long defaultLimitTime = 1000L;

    /**
     * Sets the on/off switch.
     *
     * @param enable textual boolean
     */
    public void setEnable(String enable) {
        this.enable = enable;
    }

    /**
     * Sets the default limit window.
     *
     * @param defaultLimitTime the window length
     */
    public void setDefaultLimitTime(Long defaultLimitTime) {
        this.defaultLimitTime = defaultLimitTime;
    }
}