package com.gateway.filter.factory;
import lombok.Data;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.Arrays;
import java.util.List;
/**
 * Route-level gateway filter that enforces token authentication based on path patterns.
 *
 * <p>Each request path is classified against the configured Ant-style patterns; first match wins:
 * <ol>
 *   <li>{@code whiteList} - forwarded without looking at the token;</li>
 *   <li>{@code mustVerifyList} - an {@code Authorization} header is required and must parse;</li>
 *   <li>{@code optionalVerifyList} - the token is validated only when one is present;</li>
 *   <li>any other path - handled like {@code mustVerifyList}.</li>
 * </ol>
 * A missing or unparseable token yields {@code 401 Unauthorized}. A valid token is parsed with
 * {@code TokenUtil.parseToken} and its user id, user name and tenant id are forwarded downstream
 * as {@code X-User-Id}, {@code X-User-Name} and {@code X-Tenant-Id}.
 *
 * <p>NOTE(review): on white-listed paths, and on optional paths without a token, client-supplied
 * {@code X-User-Id}/{@code X-User-Name}/{@code X-Tenant-Id} headers are forwarded unchanged by this
 * filter; make sure they are stripped before reaching downstream services.
 */
@Component
public class TokenAuthFilterFactory extends AbstractGatewayFilterFactory<TokenAuthFilterFactory.Config> {

    private final AntPathMatcher pathMatcher = new AntPathMatcher();

    public TokenAuthFilterFactory() {
        super(Config.class);
    }

    @Override
    public GatewayFilter apply(Config config) {
        return (exchange, chain) -> {
            ServerHttpRequest request = exchange.getRequest();
            String path = request.getPath().value();
            String token = request.getHeaders().getFirst("Authorization");
            // White-listed paths skip authentication entirely.
            if (matchesAny(path, config.getWhiteList())) {
                return chain.filter(exchange);
            }
            // On every non-white-listed path a supplied token is always validated.
            if (token != null) {
                return validateTokenAndContinue(exchange, chain, token);
            }
            // Without a token, only paths that are optional (and not also mandatory) may pass;
            // mandatory and unlisted paths are rejected.
            boolean tokenOptional = !matchesAny(path, config.getMustVerifyList())
                    && matchesAny(path, config.getOptionalVerifyList());
            return tokenOptional ? chain.filter(exchange) : unauthorized(exchange);
        };
    }

    /**
     * Tests whether {@code path} matches at least one Ant-style pattern.
     *
     * @param path     the request path
     * @param patterns the patterns to test; {@code null} means "no patterns"
     * @return {@code true} if any pattern matches
     */
    private boolean matchesAny(String path, List<String> patterns) {
        return patterns != null && patterns.stream()
                .anyMatch(pattern -> pathMatcher.match(pattern, path));
    }

    /**
     * Parses the token, forwards the user identity as request headers and continues the chain.
     * Any failure while parsing the token or building the identity headers rejects the request
     * with {@code 401}. Exceptions raised by the downstream chain are NOT converted into 401.
     */
    private Mono<Void> validateTokenAndContinue(ServerWebExchange exchange,
                                                GatewayFilterChain chain,
                                                String token) {
        ServerHttpRequest newRequest;
        try {
            TokenInfo tokenInfo = TokenUtil.parseToken(token);
            newRequest = exchange.getRequest().mutate()
                    .header("X-User-Id", tokenInfo.getUserId())
                    .header("X-User-Name", tokenInfo.getUserName())
                    .header("X-Tenant-Id", tokenInfo.getTenantId())
                    .build();
        } catch (Exception e) {
            // TokenUtil's exception contract is not visible here, so any failure (including a
            // null TokenInfo) is treated as an invalid token.
            return unauthorized(exchange);
        }
        // Kept outside the try so downstream failures surface as real errors, not as 401.
        return chain.filter(exchange.mutate().request(newRequest).build());
    }

    /** Completes the exchange with {@code 401 Unauthorized} and an empty body. */
    private static Mono<Void> unauthorized(ServerWebExchange exchange) {
        exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED);
        return exchange.getResponse().setComplete();
    }

    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList("whiteList", "mustVerifyList", "optionalVerifyList");
    }

    /** Filter configuration; every list holds Ant-style path patterns and may be {@code null}. */
    @Data
    public static class Config {
        /** Paths that are never authenticated. */
        private List<String> whiteList;
        /** Paths that require a valid token. */
        private List<String> mustVerifyList;
        /** Paths where a token is validated only if one is supplied. */
        private List<String> optionalVerifyList;
    }
}
package com.gateway.filter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
@Component
public class IpRateLimitFilter implements GlobalFilter, Ordered {
@Autowired
private RedisTemplate<String, String> redisTemplate;
private static final int LIMIT_PER_MINUTE = 100;
private static final int BAN_MINUTES = 20;
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
String ip = exchange.getRequest().getRemoteAddress().getAddress().getHostAddress();
String key = "ip_limit:" + ip;
// 检查是否被禁止访问
if (Boolean.TRUE.equals(redisTemplate.hasKey("ip_banned:" + ip))) {
return exchange.getResponse().setComplete();
}
// 增加计数
Long count = redisTemplate.opsForValue().increment(key);
if (count == 1) {
redisTemplate.expire(key, 1, TimeUnit.MINUTES);
}
// 检查限制
if (count > LIMIT_PER_MINUTE) {
redisTemplate.opsForValue().set("ip_banned:" + ip, "1", BAN_MINUTES, TimeUnit.MINUTES);
return exchange.getResponse().setComplete();
}
return chain.filter(exchange);
}
@Override
public int getOrder() {
return -100;
}
}
package com.gateway.filter;
import com.gateway.service.TokenService;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
@Component
public class AuthenticationFilter implements GlobalFilter, Ordered {
@Autowired
private TokenService tokenService;
@Autowired
private RouteRepository routeRepository;
@Override
public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
ServerHttpRequest request = exchange.getRequest();
String path = request.getPath().value();
RouteDefinition route = routeRepository.findByPath(path);
if (route == null) {
return chain.filter(exchange);
}
String token = request.getHeaders().getFirst("Authorization");
// 处理不同的认证类型
switch (route.getAuthType()) {
case 0: // 无需验证
return chain.filter(exchange);
case 1: // 可选验证
if (token != null) {
return processToken(exchange, chain, token);
}
return chain.filter(exchange);
case 2: // 必须验证
if (token == null) {
return exchange.getResponse().setComplete();
}
return processToken(exchange, chain, token);
default:
return chain.filter(exchange);
}
}
private Mono<Void> processToken(ServerWebExchange exchange, GatewayFilterChain chain, String token) {
TokenInfo tokenInfo = tokenService.parseToken(token);
if (tokenInfo == null) {
return exchange.getResponse().setComplete();
}
// 添加用户信息到header
ServerHttpRequest newRequest = exchange.getRequest().mutate()
.header("X-User-Id", tokenInfo.getUserId())
.header("X-User-Name", tokenInfo.getUserName())
.header("X-Tenant-Id", tokenInfo.getTenantId())
.build();
return chain.filter(exchange.mutate().request(newRequest).build());
}
@Override
public int getOrder() {
return -90;
}
}
package com.gateway.entity;
import lombok.Data;
import javax.persistence.*;
/**
 * JPA entity for one row of the {@code gateway_route} table, describing a dynamically loaded
 * gateway route. Getters/setters/equals/hashCode/toString are generated by Lombok {@code @Data}.
 *
 * <p>NOTE(review): AuthenticationFilter looks routes up with {@code routeRepository.findByPath},
 * but this entity declares no {@code path} field; confirm how that query is resolved.
 */
@Data
@Entity
@Table(name = "gateway_route")
public class RouteDefinition {
// Database-generated primary key.
@Id
@GeneratedValue(strategy = GenerationType.IDENTITY)
private Long id;
// Used as the gateway route id by DynamicRouteService.
private String routeId;
// Route target; parsed with URI.create by DynamicRouteService.
private String uri;
// Serialized predicate definitions, parsed by JsonUtils.parsePredicates (presumably JSON - confirm).
private String predicates;
// Serialized filter definitions, parsed by JsonUtils.parseFilters (presumably JSON - confirm).
private String filters;
private Integer authType; // 0 - no verification, 1 - optional verification, 2 - verification required
// Only rows with enabled = true are loaded by DynamicRouteService (findAllByEnabled(true)).
private Boolean enabled;
}
package com.gateway.service;
import com.gateway.entity.RouteDefinition;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
package com.gateway.service;
import com.gateway.entity.RouteDefinition;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.event.RefreshRoutesEvent;
import org.springframework.cloud.gateway.route.RouteDefinitionWriter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Mirrors the enabled {@link RouteDefinition} rows from the database into the gateway's
 * {@link RouteDefinitionWriter}, at startup and then every 30 seconds.
 *
 * <p>Each refresh deletes the routes written by the previous refresh that are no longer enabled,
 * saves every enabled route, and publishes a {@link RefreshRoutesEvent}. Writes are awaited
 * before the event is published. A row that cannot be converted or written (missing routeId,
 * invalid uri, unparseable predicates/filters, writer error) is logged and skipped, so one bad
 * row cannot abort the whole refresh or the application startup.
 */
@Service
public class DynamicRouteService {
    private static final java.util.logging.Logger LOG =
            java.util.logging.Logger.getLogger(DynamicRouteService.class.getName());

    @Autowired
    private RouteRepository routeRepository;
    @Autowired
    private RouteDefinitionWriter routeDefinitionWriter;
    @Autowired
    private ApplicationEventPublisher publisher;

    /** Route ids written by the previous refresh; only these are candidates for deletion. */
    private Set<String> currentRouteIds = new HashSet<>();

    /** Reloads enabled routes from the database and notifies the gateway. */
    @Scheduled(fixedRate = 30000) // refresh routes every 30 seconds
    public void refreshRoutes() {
        List<RouteDefinition> routes = routeRepository.findAllByEnabled(true);
        // Ids of all enabled routes; rows without an id cannot be addressed and are ignored.
        Set<String> newRouteIds = routes.stream()
                .map(RouteDefinition::getRouteId)
                .filter(id -> id != null)
                .collect(Collectors.toSet());
        // Remove routes that were written previously but are no longer enabled.
        currentRouteIds.stream()
                .filter(id -> !newRouteIds.contains(id))
                .forEach(this::deleteRoute);
        // Add new routes and overwrite existing ones.
        routes.forEach(this::updateRoute);
        currentRouteIds = newRouteIds;
        // Tell the gateway to rebuild its routes from the writer's definitions.
        publisher.publishEvent(new RefreshRoutesEvent(this));
    }

    /** Converts and saves one route, logging (not propagating) any failure. */
    private void updateRoute(RouteDefinition route) {
        if (route.getRouteId() == null) {
            LOG.warning("Skipping enabled route without routeId, id=" + route.getId());
            return;
        }
        try {
            org.springframework.cloud.gateway.route.RouteDefinition definition =
                    convertToGatewayRoute(route);
            // Block so the write has completed before RefreshRoutesEvent is published.
            routeDefinitionWriter.save(Mono.just(definition)).block();
        } catch (RuntimeException e) {
            LOG.log(java.util.logging.Level.WARNING,
                    "Failed to save route " + route.getRouteId(), e);
        }
    }

    /** Deletes one route from the writer, logging (not propagating) any failure. */
    private void deleteRoute(String routeId) {
        try {
            routeDefinitionWriter.delete(Mono.just(routeId)).block();
        } catch (RuntimeException e) {
            LOG.log(java.util.logging.Level.WARNING, "Failed to delete route " + routeId, e);
        }
    }

    /**
     * Builds a Spring Cloud Gateway route definition from a database row.
     *
     * @throws RuntimeException if the uri is null/invalid or predicates/filters cannot be parsed
     */
    private org.springframework.cloud.gateway.route.RouteDefinition convertToGatewayRoute(
            RouteDefinition route) {
        org.springframework.cloud.gateway.route.RouteDefinition definition =
                new org.springframework.cloud.gateway.route.RouteDefinition();
        definition.setId(route.getRouteId());
        definition.setUri(URI.create(route.getUri()));
        List<PredicateDefinition> predicates = JsonUtils.parsePredicates(route.getPredicates());
        definition.setPredicates(predicates);
        List<FilterDefinition> filters = JsonUtils.parseFilters(route.getFilters());
        definition.setFilters(filters);
        return definition;
    }

    /** Loads the routes once at startup, before the first scheduled refresh. */
    @PostConstruct
    public void init() {
        refreshRoutes();
    }
}