// 限流注解 (rate-limiting annotation)
/**
 * Marks a method whose invocations are rate limited per client.
 *
 * <p>{@code RateLimiterAspect} reads only {@link #time()} and {@link #count()}. The other
 * elements are not read by that aspect; presumably other code consumes them (verify before
 * relying on them).
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface RateLimiter {

    /** Cache key prefix. NOTE(review): not used by RateLimiterAspect, which builds its own key. */
    String key() default Constants.RATE_LIMIT_KEY;

    /** Length of the counting window, in seconds (applied with {@code TimeUnit.SECONDS}). */
    int time() default 10;

    /** Maximum number of calls allowed within one {@link #time()} window. */
    int count() default 1;

    /** Limiter type. NOTE(review): semantics unknown from here; not read by RateLimiterAspect. */
    String type() default "";

    /**
     * Message intended for rejected callers.
     * NOTE(review): RateLimiterAspect currently throws its own fixed message instead of this one.
     */
    String msg() default "访问太过频繁,请稍后重试!";
}
/**
 * Spring AOP aspect that enforces {@link RateLimiter} on annotated methods.
 *
 * <p>Before each call it atomically increments a Redis counter keyed by the client IP, the
 * declaring class and the method name. The first hit of a window gives the counter a TTL of
 * {@link RateLimiter#time()} seconds. A hit that pushes the counter above
 * {@link RateLimiter#count()} is rejected with {@code RateLimiterException}.
 */
@Aspect
@Component
@RequiredArgsConstructor
public class RateLimiterAspect {
    public static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class);

    /** Value Redis reports as the TTL of a key that exists but has no expiry. */
    private static final long NO_EXPIRY = -1L;

    @Resource
    RedisTemplate redisTemplate;

    /** Matches methods annotated with {@link RateLimiter} and binds the annotation instance. */
    @Pointcut(value = "@annotation(rateLimiter)")
    public void pointCut(RateLimiter rateLimiter) {
    }

    /**
     * Applies the rate-limit check before the annotated method runs.
     *
     * @param joinPoint   the intercepted call
     * @param rateLimiter the annotation present on the intercepted method
     * @throws RateLimiterException if the caller exceeded {@code count()} hits in the current window
     * @throws IllegalStateException if there is no current HTTP servlet request
     * @throws Exception wrapping the original cause if the key could not be built or Redis failed
     */
    @Before("pointCut(rateLimiter)")
    public void doBefore(JoinPoint joinPoint, RateLimiter rateLimiter) throws Exception {
        this.handle(joinPoint, rateLimiter);
    }

    /** Builds the per-client/per-method key, records the hit and rejects it if over the limit. */
    private void handle(JoinPoint joinPoint, RateLimiter rateLimiter) throws Exception {
        HttpServletRequest request = currentRequest();
        String key = null;
        long hits;
        try {
            Method method = ((MethodSignature) joinPoint.getSignature()).getMethod();
            // Include the declaring class so same-named methods in different classes do not
            // share a single counter.
            key = RATELIMITER + IpUtils.getIp(request) + ":"
                    + method.getDeclaringClass().getName() + ":" + method.getName();
            hits = countHit(key, rateLimiter);
        } catch (Exception e) {
            // Keep the cause (and its stack trace) instead of discarding it.
            logger.error("Rate limiter check failed for key {}", key, e);
            throw new Exception("Rate limiter check failed for key " + key, e);
        }
        if (hits > rateLimiter.count()) {
            throw new RateLimiterException("网络繁忙,请稍候再试");
        }
    }

    /**
     * Atomically increments the hit counter at {@code key}, starting a new
     * {@code rateLimiter.time()}-second window on the first hit.
     *
     * @return the counter value after this hit, or 0 if Redis returned no value
     */
    private long countHit(String key, RateLimiter rateLimiter) {
        // INCR is atomic and creates the key at 1 when it is absent. This avoids the
        // hasKey/increment/get race, where a key expiring mid-check was recreated without a TTL.
        Long hits = redisTemplate.opsForValue().increment(key);
        if (hits == null) {
            // Spring Data Redis returns null when the template runs in pipeline/transaction mode.
            return 0L;
        }
        if (hits == 1L) {
            redisTemplate.expire(key, rateLimiter.time(), TimeUnit.SECONDS);
        } else if (hits > rateLimiter.count()) {
            // Self-heal: if the EXPIRE after a first hit was lost (e.g. crash right after INCR),
            // the key would never expire and this client would stay blocked forever.
            Long ttl = redisTemplate.getExpire(key);
            if (ttl != null && ttl == NO_EXPIRY) {
                redisTemplate.expire(key, rateLimiter.time(), TimeUnit.SECONDS);
            }
        }
        return hits;
    }

    /** Returns the current servlet request, failing clearly when called outside one. */
    private static HttpServletRequest currentRequest() {
        RequestAttributes attributes = RequestContextHolder.getRequestAttributes();
        Object request = attributes == null
                ? null
                : attributes.resolveReference(RequestAttributes.REFERENCE_REQUEST);
        if (!(request instanceof HttpServletRequest)) {
            throw new IllegalStateException("@RateLimiter requires an active HTTP servlet request");
        }
        return (HttpServletRequest) request;
    }
}