Spring Cloud Gateway: TokenAuthFilterFactory

package com.gateway.filter.factory;

import lombok.Data;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

import java.util.Arrays;
import java.util.List;

/**
 * Route-level gateway filter that enforces token authentication by request path.
 *
 * <p>The request path is checked, in order, against:
 * <ol>
 *   <li>{@code whiteList} - passed through without looking at the token;</li>
 *   <li>{@code mustVerifyList} - a valid token is required;</li>
 *   <li>{@code optionalVerifyList} - a token is validated only if one is present;</li>
 *   <li>any other path - a valid token is required.</li>
 * </ol>
 * Patterns use Ant-style syntax ({@link AntPathMatcher}). The token is read from the
 * {@code Authorization} header; a missing (where required) or invalid token yields
 * {@code 401 Unauthorized}. On success the user's identity is forwarded downstream in the
 * {@code X-User-Id}, {@code X-User-Name} and {@code X-Tenant-Id} headers. Client-supplied copies
 * of those headers are removed from every request first, so they cannot be forged.
 *
 * <p>NOTE(review): {@code TokenInfo} and {@code TokenUtil} are not imported in this compilation
 * unit; confirm their package. If {@code AuthenticationFilter} (a global filter in this codebase)
 * also applies to the same route and presumably runs earlier, the identity headers it adds are
 * discarded here on white-listed paths; confirm only one mechanism is enabled per route.
 */
@Component
public class TokenAuthFilterFactory extends AbstractGatewayFilterFactory<TokenAuthFilterFactory.Config> {

    /** Headers this filter derives from a validated token; never trusted from the client. */
    private static final List<String> IDENTITY_HEADERS =
            Arrays.asList("X-User-Id", "X-User-Name", "X-Tenant-Id");

    private final AntPathMatcher pathMatcher = new AntPathMatcher();

    public TokenAuthFilterFactory() {
        super(Config.class);
    }

    @Override
    public GatewayFilter apply(Config config) {
        return (exchange, chain) -> {
            // Strip spoofable identity headers up front so that even white-listed and anonymous
            // (optional) requests cannot impersonate a user to downstream services.
            ServerWebExchange sanitized = stripIdentityHeaders(exchange);
            ServerHttpRequest request = sanitized.getRequest();
            String path = request.getPath().value();
            String token = request.getHeaders().getFirst("Authorization");

            // White list: no authentication at all.
            if (matchesAny(path, config.getWhiteList())) {
                return chain.filter(sanitized);
            }

            // The must-verify list wins over the optional list; paths on neither list
            // default to mandatory verification.
            boolean tokenOptional = !matchesAny(path, config.getMustVerifyList())
                    && matchesAny(path, config.getOptionalVerifyList());

            if (token == null) {
                return tokenOptional ? chain.filter(sanitized) : unauthorized(sanitized);
            }
            // A token that is present is always validated, even on optional paths.
            return validateTokenAndContinue(sanitized, chain, token);
        };
    }

    /**
     * Returns whether {@code path} matches at least one Ant-style pattern.
     *
     * @param path     request path
     * @param patterns patterns to test; {@code null} matches nothing
     * @return {@code true} if any pattern matches
     */
    private boolean matchesAny(String path, List<String> patterns) {
        return patterns != null && patterns.stream()
                .anyMatch(pattern -> pathMatcher.match(pattern, path));
    }

    /**
     * Parses the token and, if valid, forwards the request with identity headers added.
     *
     * @param exchange current exchange (identity headers already stripped)
     * @param chain    remaining filter chain
     * @param token    raw {@code Authorization} header value
     * @return completion of the downstream chain, or a completed 401 response
     */
    private Mono<Void> validateTokenAndContinue(ServerWebExchange exchange,
                                                GatewayFilterChain chain,
                                                String token) {
        TokenInfo tokenInfo;
        try {
            tokenInfo = TokenUtil.parseToken(token);
        } catch (Exception e) {
            // Any parse/verification failure is an authentication failure. Only parsing is
            // guarded: errors raised by downstream filters must not be reported as 401.
            return unauthorized(exchange);
        }
        if (tokenInfo == null) {
            return unauthorized(exchange);
        }

        ServerHttpRequest newRequest = exchange.getRequest().mutate()
                .header("X-User-Id", tokenInfo.getUserId())
                .header("X-User-Name", tokenInfo.getUserName())
                .header("X-Tenant-Id", tokenInfo.getTenantId())
                .build();

        return chain.filter(exchange.mutate().request(newRequest).build());
    }

    /**
     * Returns an exchange whose request carries none of {@link #IDENTITY_HEADERS}; the original
     * exchange is returned untouched when none are present.
     */
    private static ServerWebExchange stripIdentityHeaders(ServerWebExchange exchange) {
        ServerHttpRequest request = exchange.getRequest();
        boolean present = IDENTITY_HEADERS.stream().anyMatch(request.getHeaders()::containsKey);
        if (!present) {
            return exchange;
        }
        ServerHttpRequest cleaned = request.mutate()
                .headers(headers -> IDENTITY_HEADERS.forEach(headers::remove))
                .build();
        return exchange.mutate().request(cleaned).build();
    }

    /** Completes the response with {@code 401 Unauthorized} and an empty body. */
    private static Mono<Void> unauthorized(ServerWebExchange exchange) {
        exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED);
        return exchange.getResponse().setComplete();
    }

    @Override
    public List<String> shortcutFieldOrder() {
        return Arrays.asList("whiteList", "mustVerifyList", "optionalVerifyList");
    }

    /** Filter configuration: Ant-style path patterns for each verification mode. */
    @Data
    public static class Config {
        /** Paths that skip authentication entirely. */
        private List<String> whiteList;
        /** Paths that always require a valid token. */
        private List<String> mustVerifyList;
        /** Paths where a token is validated only if present. */
        private List<String> optionalVerifyList;
    }
}
package com.gateway.filter;

import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

/**
 * Global filter that limits each client IP to {@value #LIMIT_PER_MINUTE} requests per fixed
 * one-minute window. The window starts at the IP's first request and is counted in Redis under
 * {@code ip_limit:<ip>}. An IP that exceeds the limit is banned for {@value #BAN_MINUTES} minutes
 * via the key {@code ip_banned:<ip>}. Rejected requests receive {@code 429 Too Many Requests}.
 *
 * <p>NOTE(review): {@code RedisTemplate} is the blocking Spring Data Redis API and is called here
 * on the reactive request thread; consider {@code ReactiveRedisTemplate}. The IP is the TCP peer
 * address; behind a load balancer or proxy this is presumably the proxy's address. If the process
 * dies between the increment and {@code expire}, the counter key has no TTL. {@code Autowired} is
 * not imported in this compilation unit; confirm it resolves.
 */
@Component
public class IpRateLimitFilter implements GlobalFilter, Ordered {
    @Autowired
    private RedisTemplate<String, String> redisTemplate;

    private static final int LIMIT_PER_MINUTE = 100;
    private static final int BAN_MINUTES = 20;
    /** HTTP 429 Too Many Requests; raw code because HttpStatus is not imported in this unit. */
    private static final int TOO_MANY_REQUESTS = 429;

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        String ip = clientIp(exchange);
        String key = "ip_limit:" + ip;
        String banKey = "ip_banned:" + ip;

        // Banned IPs are rejected without touching the counter.
        if (Boolean.TRUE.equals(redisTemplate.hasKey(banKey))) {
            return reject(exchange);
        }

        Long count = redisTemplate.opsForValue().increment(key);
        if (count == null) {
            // increment() returns null inside a Redis pipeline/transaction; the count is unknown,
            // so fail open rather than throw NullPointerException.
            return chain.filter(exchange);
        }
        if (count == 1L) {
            // First request of a new window: start the one-minute window.
            redisTemplate.expire(key, 1, java.util.concurrent.TimeUnit.MINUTES);
        }

        if (count > LIMIT_PER_MINUTE) {
            redisTemplate.opsForValue().set(banKey, "1", BAN_MINUTES, java.util.concurrent.TimeUnit.MINUTES);
            return reject(exchange);
        }

        return chain.filter(exchange);
    }

    /**
     * Returns the textual peer IP of the request, or {@code "unknown"} when the remote address is
     * not available (all such requests then share one counter).
     */
    private static String clientIp(ServerWebExchange exchange) {
        java.net.InetSocketAddress remote = exchange.getRequest().getRemoteAddress();
        if (remote == null) {
            return "unknown";
        }
        java.net.InetAddress address = remote.getAddress();
        // An unresolved socket address has no InetAddress; fall back to its host string.
        return address != null ? address.getHostAddress() : remote.getHostString();
    }

    /** Completes the response with 429 instead of an empty 200. */
    private static Mono<Void> reject(ServerWebExchange exchange) {
        exchange.getResponse().setRawStatusCode(TOO_MANY_REQUESTS);
        return exchange.getResponse().setComplete();
    }

    @Override
    public int getOrder() {
        return -100;
    }
}
package com.gateway.filter;

import com.gateway.service.TokenService;
import org.springframework.cloud.gateway.filter.GatewayFilterChain;
import org.springframework.cloud.gateway.filter.GlobalFilter;
import org.springframework.core.Ordered;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;

/**
 * Global filter that applies a per-route authentication policy stored in the database.
 *
 * <p>The route is looked up by the exact request path via {@code RouteRepository.findByPath};
 * requests with no matching route are passed through. The route's {@code authType} selects the
 * policy:
 * <ul>
 *   <li>{@code 0}: no verification;</li>
 *   <li>{@code 1}: optional, where a token is checked only when present;</li>
 *   <li>{@code 2}: mandatory.</li>
 * </ul>
 * A missing {@code authType} is treated as mandatory. A missing token on a mandatory route, or a
 * token that {@code TokenService} cannot parse, is rejected with {@code 401 Unauthorized}.
 * Client-supplied {@code X-User-*}/{@code X-Tenant-Id} headers are always removed, so downstream
 * services only see identity derived from a validated token.
 *
 * <p>NOTE(review): this compilation unit does not import {@code Autowired},
 * {@code RouteRepository}, {@code RouteDefinition} or {@code TokenInfo}; confirm they resolve.
 * {@code findByPath} is presumably a blocking repository call on the reactive request thread.
 */
@Component
public class AuthenticationFilter implements GlobalFilter, Ordered {
    /** Identity headers derived from a validated token; client-supplied copies are discarded. */
    private static final String[] IDENTITY_HEADERS = {"X-User-Id", "X-User-Name", "X-Tenant-Id"};

    /** HTTP 401; raw code because HttpStatus is not imported in this compilation unit. */
    private static final int UNAUTHORIZED = 401;

    private static final int AUTH_NONE = 0;
    private static final int AUTH_OPTIONAL = 1;
    private static final int AUTH_REQUIRED = 2;

    @Autowired
    private TokenService tokenService;
    @Autowired
    private RouteRepository routeRepository;

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // Remove forged identity headers before anything else, on every request.
        ServerWebExchange sanitized = stripIdentityHeaders(exchange);
        ServerHttpRequest request = sanitized.getRequest();
        String path = request.getPath().value();

        RouteDefinition route = routeRepository.findByPath(path);
        if (route == null) {
            return chain.filter(sanitized);
        }

        String token = request.getHeaders().getFirst("Authorization");

        // authType is an Integer: switching on null would throw NullPointerException, so a
        // missing policy fails closed (mandatory verification).
        int authType = route.getAuthType() == null ? AUTH_REQUIRED : route.getAuthType();
        switch (authType) {
            case AUTH_NONE:
                return chain.filter(sanitized);
            case AUTH_OPTIONAL:
                return token == null ? chain.filter(sanitized) : processToken(sanitized, chain, token);
            case AUTH_REQUIRED:
                return token == null ? unauthorized(sanitized) : processToken(sanitized, chain, token);
            default:
                // NOTE(review): unknown auth types are let through unauthenticated (fail open);
                // confirm this is intended.
                return chain.filter(sanitized);
        }
    }

    /**
     * Parses the token and forwards the request with identity headers, or rejects it with 401
     * when {@code TokenService.parseToken} returns {@code null}.
     */
    private Mono<Void> processToken(ServerWebExchange exchange, GatewayFilterChain chain, String token) {
        TokenInfo tokenInfo = tokenService.parseToken(token);
        if (tokenInfo == null) {
            return unauthorized(exchange);
        }

        // Add the user's identity for downstream services.
        ServerHttpRequest newRequest = exchange.getRequest().mutate()
            .header("X-User-Id", tokenInfo.getUserId())
            .header("X-User-Name", tokenInfo.getUserName())
            .header("X-Tenant-Id", tokenInfo.getTenantId())
            .build();

        return chain.filter(exchange.mutate().request(newRequest).build());
    }

    /**
     * Returns an exchange whose request has none of the identity headers; the original exchange
     * is returned untouched when none are present.
     */
    private static ServerWebExchange stripIdentityHeaders(ServerWebExchange exchange) {
        ServerHttpRequest request = exchange.getRequest();
        boolean present = false;
        for (String name : IDENTITY_HEADERS) {
            present |= request.getHeaders().containsKey(name);
        }
        if (!present) {
            return exchange;
        }
        ServerHttpRequest cleaned = request.mutate()
            .headers(headers -> {
                for (String name : IDENTITY_HEADERS) {
                    headers.remove(name);
                }
            })
            .build();
        return exchange.mutate().request(cleaned).build();
    }

    /** Completes the response with 401 instead of an empty 200. */
    private static Mono<Void> unauthorized(ServerWebExchange exchange) {
        exchange.getResponse().setRawStatusCode(UNAUTHORIZED);
        return exchange.getResponse().setComplete();
    }

    @Override
    public int getOrder() {
        return -90;
    }
}
package com.gateway.entity;

import lombok.Data;
import javax.persistence.*;

/**
 * JPA entity for one row of the {@code gateway_route} table: a gateway route definition stored
 * in the database.
 *
 * <p>Not to be confused with Spring Cloud Gateway's own
 * {@code org.springframework.cloud.gateway.route.RouteDefinition}, which has the same simple name.
 *
 * <p>NOTE(review): Lombok {@code @Data} generates {@code equals}/{@code hashCode} over all fields,
 * including the database-generated {@code id}. The hash of a new instance therefore changes once
 * it is persisted; confirm instances are not kept in hash-based collections across a save.
 */
@Data
@Entity
@Table(name = "gateway_route")
public class RouteDefinition {
    /** Primary key, generated by the database (identity column). */
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;
    
    /** Identifier of the route in the gateway; presumably unique - TODO confirm. */
    private String routeId;
    /** Target URI of the route, stored as a string. */
    private String uri;
    /** Serialized route predicates; presumably JSON - confirm the format with the loader. */
    private String predicates;
    /** Serialized route filters; presumably JSON - confirm the format with the loader. */
    private String filters;
    private Integer authType; // 0 = no verification, 1 = optional verification, 2 = mandatory verification
    /** Whether the route is enabled. */
    private Boolean enabled;
}
package com.gateway.service;

import com.gateway.entity.RouteDefinition;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;

package com.gateway.service;

import com.gateway.entity.RouteDefinition;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.gateway.event.RefreshRoutesEvent;
import org.springframework.cloud.gateway.route.RouteDefinitionWriter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Mono;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Periodically synchronizes Spring Cloud Gateway's route definitions with the enabled rows of the
 * {@code gateway_route} table.
 *
 * <p>Each refresh does the following:
 * <ul>
 *   <li>saves every enabled route;</li>
 *   <li>deletes routes written by the previous refresh that are no longer enabled or present;</li>
 *   <li>publishes a {@link RefreshRoutesEvent} only after all of those writes have completed.</li>
 * </ul>
 * A route that fails to convert or save is logged and skipped, and the remaining routes are still
 * applied.
 */
@Service
public class DynamicRouteService {
    private static final java.util.logging.Logger LOG =
        java.util.logging.Logger.getLogger(DynamicRouteService.class.getName());

    @Autowired
    private RouteRepository routeRepository;
    
    @Autowired
    private RouteDefinitionWriter routeDefinitionWriter;
    
    @Autowired
    private ApplicationEventPublisher publisher;

    /** Route IDs written by the previous refresh; used to detect routes that must be removed. */
    private Set<String> currentRouteIds = new java.util.HashSet<>();
    
    /**
     * Reconciles gateway routes with the database (every 30 seconds, and once at startup).
     *
     * <p>All writes are composed into a single pipeline, so the refresh event is published only
     * after every save and delete has finished. Publishing right after fire-and-forget
     * {@code subscribe()} calls could race with an asynchronous {@link RouteDefinitionWriter}.
     */
    @Scheduled(fixedRate = 30000) // refresh routes every 30 seconds
    public void refreshRoutes() {
        List<RouteDefinition> routes = routeRepository.findAllByEnabled(true);
        
        // IDs of all enabled routes in the database
        Set<String> newRouteIds = routes.stream()
            .map(RouteDefinition::getRouteId)
            .collect(Collectors.toSet());

        // Routes written previously that no longer exist or are disabled
        List<Mono<Void>> deletions = currentRouteIds.stream()
            .filter(id -> !newRouteIds.contains(id))
            .map(this::deleteRoute)
            .collect(Collectors.toList());

        // Create or overwrite every enabled route
        List<Mono<Void>> saves = routes.stream()
            .map(this::updateRoute)
            .collect(Collectors.toList());

        currentRouteIds = newRouteIds;

        // Apply all writes, then tell the gateway to reload its routes. Per-route errors are
        // handled inside deleteRoute/updateRoute, so the event is published even if some fail.
        Mono.when(deletions)
            .then(Mono.when(saves))
            .then(Mono.fromRunnable(() -> publisher.publishEvent(new RefreshRoutesEvent(this))))
            .onErrorResume(error -> {
                LOG.log(java.util.logging.Level.SEVERE, "Failed to publish RefreshRoutesEvent", error);
                return Mono.empty();
            })
            .subscribe();
    }
    
    /**
     * Saves (creates or overwrites) one route. Conversion runs inside {@code Mono.defer}, so a
     * malformed row (e.g. an invalid URI) fails only this route. The error is logged and
     * swallowed, so the remaining routes are still applied and the last good version, if any,
     * stays in place.
     */
    private Mono<Void> updateRoute(RouteDefinition route) {
        return Mono.defer(() -> routeDefinitionWriter.save(Mono.just(convertToGatewayRoute(route))))
            .onErrorResume(error -> {
                LOG.log(java.util.logging.Level.WARNING, "Failed to save route " + route.getRouteId(), error);
                return Mono.empty();
            });
    }
    
    /**
     * Deletes one route from the gateway. Failures (presumably including a route that is
     * already absent) are logged and swallowed so they do not abort the refresh.
     * NOTE(review): a failed delete is not retried, because its ID is dropped from
     * currentRouteIds.
     */
    private Mono<Void> deleteRoute(String routeId) {
        return routeDefinitionWriter.delete(Mono.just(routeId))
            .onErrorResume(error -> {
                LOG.log(java.util.logging.Level.WARNING, "Failed to delete route " + routeId, error);
                return Mono.empty();
            });
    }
    
    /**
     * Converts a database row into Spring Cloud Gateway's route definition.
     *
     * @param route database entity
     * @return gateway route definition with ID, URI, predicates and filters
     * @throws IllegalArgumentException if the stored URI is not a valid URI
     */
    private org.springframework.cloud.gateway.route.RouteDefinition convertToGatewayRoute(
            RouteDefinition route) {
        org.springframework.cloud.gateway.route.RouteDefinition definition = 
            new org.springframework.cloud.gateway.route.RouteDefinition();
        definition.setId(route.getRouteId());
        definition.setUri(java.net.URI.create(route.getUri()));
        
        // Convert predicates
        List<PredicateDefinition> predicates = JsonUtils.parsePredicates(route.getPredicates());
        definition.setPredicates(predicates);
        
        // Convert filters
        List<FilterDefinition> filters = JsonUtils.parseFilters(route.getFilters());
        definition.setFilters(filters);
        
        return definition;
    }
    
    /** Loads the routes once at startup. */
    @PostConstruct
    public void init() {
        refreshRoutes();
    }
}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值