声明:大部分代码都是我从别人的博客里复制过来的,只是稍微修改了一下,使其能够支持对文件上传请求的过滤。话不多说,直接贴代码!
- 配置文件(application.properties)
# Switch for the XSS filter: the filter configuration class is only activated
# when this property is present and set to "true".
# Comment this line out or set it to false to disable XSS filtering.
common.xss-filter-open=true
- 过滤配置类(XSSFilterConfig.java)
/**
 * Registers the XSS filter with the servlet container.
 *
 * <p>The whole configuration is conditional: it is only applied when the
 * property {@code common.xss-filter-open} exists and equals {@code true}.
 */
@Configuration
@ConditionalOnProperty(prefix = "common", name = "xss-filter-open", havingValue = "true", matchIfMissing = false)
public class XSSFilterConfig {

    /**
     * Exposes the filter instance as a bean named {@code xssFilter}.
     *
     * @return a new {@link XssFilter}
     */
    @Bean(name = "xssFilter")
    public Filter xssFilter() {
        return new XssFilter();
    }

    /**
     * Builds the registration that maps the XSS filter onto every URL.
     *
     * @return the registration bean for the XSS filter
     */
    @Bean
    public FilterRegistrationBean filterRegistrationBean() {
        FilterRegistrationBean bean = new FilterRegistrationBean();
        bean.setName("xssFilter");
        bean.setFilter(xssFilter());
        // Apply the filter to all requests.
        bean.addUrlPatterns("/*");
        // URLs that must NOT be filtered; XssFilter splits this value on ';'.
        bean.addInitParameter("EXCLUDED_URLS", "");
        bean.addInitParameter("paramName", "paramValue");
        return bean;
    }
}
- Xss过滤器(XssFilter.java)
/**
 * Servlet filter that hands every non-excluded request to the rest of the chain
 * wrapped in an {@link XssHttpServletRequestWrapper}.
 *
 * <p>Requests whose URI contains one of the fragments configured through the
 * {@code EXCLUDED_URLS} init parameter (separated by {@code ';'}) are passed
 * through untouched.
 */
public class XssFilter implements Filter {
    private final static Logger LOG = LoggerFactory.getLogger(XssFilter.class);

    /** Raw value of the {@code EXCLUDED_URLS} init parameter; may be null or blank. */
    private String excludedUrls;

    /** Fragments parsed from {@link #excludedUrls}; never null so that iteration is always safe. */
    private String[] excludedUrlArray = new String[0];

    /**
     * Reads the {@code EXCLUDED_URLS} init parameter and splits it on {@code ';'}.
     *
     * @param filterConfig configuration supplied by the container
     * @throws ServletException declared by the {@link Filter} contract; not thrown here
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        excludedUrls = filterConfig.getInitParameter("EXCLUDED_URLS");
        // A missing or blank parameter means "exclude nothing" rather than a null array.
        excludedUrlArray = StringUtils.isNotBlank(excludedUrls)
                ? excludedUrls.split(";")
                : new String[0];
    }

    /**
     * Passes excluded requests through unchanged and wraps all others.
     *
     * @param req   incoming request; cast to {@link HttpServletRequest}
     * @param res   outgoing response
     * @param chain remaining filter chain
     * @throws IOException      if wrapping the request or the downstream chain fails with I/O errors
     * @throws ServletException if the downstream chain fails
     */
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest httpRequest = (HttpServletRequest) req;
        String requestUrl = httpRequest.getRequestURI();
        // The input stream is deliberately not touched here: obtaining it before the
        // parameters have been read can keep form data out of the parameter set.
        LOG.debug("Enter into xssFilter request url {}", requestUrl);

        if (isExcluded(requestUrl)) {
            chain.doFilter(req, res);
            return;
        }
        chain.doFilter(new XssHttpServletRequestWrapper(httpRequest), res);
    }

    /**
     * Tells whether the given URI contains one of the configured exclusion fragments.
     * Blank fragments are skipped because every string contains the empty string,
     * which would otherwise exclude all requests.
     */
    private boolean isExcluded(String requestUrl) {
        for (String url : excludedUrlArray) {
            if (StringUtils.isNotBlank(url) && requestUrl.contains(url.trim())) {
                return true;
            }
        }
        return false;
    }

    /** Nothing to release. */
    @Override
    public void destroy() {
    }
}
- 自己包装的Request(XssHttpServletRequestWrapper.java)
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
private HttpServletRequest orgRequest;
private byte[] body;
private Collection<Part> parts;
public XssHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
super(request);
this.orgRequest = request;
// 对文件的支持
try {
this.parts = request.getParts();
} catch(ServletException e) {
this.parts = null;
}
this.body = StreamUtils.copyToByteArray(request.getInputStream());
}
/**
* 覆盖getParameter方法,将参数名和参数值都做xss过滤。<br/>
* 如果需要获得原始的值,则通过super.getParameterValues(name)来获取<br/>
* getParameterNames,getParameterValues和getParameterMap也可能需要覆盖
*/
@Override
public String getParameter(String name) {
String value = super.getParameter(xssEncode(name, 0));
if (null != value) {
value = xssEncode(value, 0);
}
return value;
}
@Override
public String[] getParameterValues(String name) {
String[] values = super.getParameterValues(xssEncode(name, 0));
if (values == null) {
return null;
}
int count = values.length;
String[] encodedValues = new String[count];
for (int i = 0; i < count; i++) {
encodedValues[i] = xssEncode(values[i], 0);
}
return encodedValues;
}
@Override
public Map getParameterMap() {
HashMap paramMap = (HashMap) super.getParameterMap();
paramMap = (HashMap) paramMap.clone();
for (Iterator iterator = paramMap.entrySet().iterator(); iterator.hasNext(); ) {
Map.Entry entry = (Map.Entry) iterator.next();
String[] values = (String[]) entry.getValue();
for (int i = 0; i < values.length; i++) {
if (values[i] instanceof String) {
values[i] = xssEncode(values[i], 0);
}
}
entry.setValue(values);
}
return paramMap;
}
@Override
public ServletInputStream getInputStream() throws IOException {
ServletInputStream inputStream = null;
String bodyStr = new String(body);
if (StringUtils.isNotEmpty(bodyStr)) {
bodyStr = xssEncode(bodyStr, 1);
final ByteArrayInputStream bais = new ByteArrayInputStream(bodyStr.getBytes());
// import org.springframework.mock.web.DelegatingServletInputStream;
return new DelegatingServletInputStream(bais);
}
return inputStream;
}
/**
* 覆盖getHeader方法,将参数名和参数值都做xss过滤。<br/>
* 如果需要获得原始的值,则通过super.getHeaders(name)来获取<br/>
* getHeaderNames 也可能需要覆盖
*/
@Override
public String getHeader(String name) {
String value = super.getHeader(xssEncode(name, 0));
if (value != null) {
value = xssEncode(value, 0);
}
return value;
}
/**
* 将容易引起xss漏洞的半角字符直接替换成全角字符
*
* @param s
* @return
*/
private static String xssEncode(String s, int type) {
if (s == null || s.isEmpty()) {
return s;
}
StringBuilder sb = new StringBuilder(s.length() + 16);
for (int i = 0; i < s.length(); i++) {
char c = s.charAt(i);
if (type == 0) {
switch (c) {
case '\'':
// 全角单引号
sb.append('‘');
break;
case '\"':
// 全角双引号
sb.append('“');
break;
case '>':
// 全角大于号
sb.append('>');
break;
case '<':
// 全角小于号
sb.append('<');
break;
case '&':
// 全角&符号
sb.append('&');
break;
case '\\':
// 全角斜线
sb.append('\');
break;
case '#':
// 全角井号
sb.append('#');
break;
// < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c
case '%':
processUrlEncoder(sb, s, i);
break;
default:
sb.append(c);
break;
}
} else {
switch (c) {
case '>':
// 全角大于号
sb.append('>');
break;
case '<':
// 全角小于号
sb.append('<');
break;
case '&':
// 全角&符号
sb.append('&');
break;
case '#':
// 全角井号
sb.append('#');
break;
// < 字符的 URL 编码形式表示的 ASCII 字符(十六进制格式) 是: %3c
case '%':
processUrlEncoder(sb, s, i);
break;
default:
sb.append(c);
break;
}
}
}
return sb.toString();
}
public static void processUrlEncoder(StringBuilder sb, String s, int index) {
if (s.length() >= index + 2) {
// %3c, %3C
if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'c' || s.charAt(index + 2) == 'C')) {
sb.append('<');
return;
}
// %3c (0x3c=60)
if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '0') {
sb.append('<');
return;
}
// %3e, %3E
if (s.charAt(index + 1) == '3' && (s.charAt(index + 2) == 'e' || s.charAt(index + 2) == 'E')) {
sb.append('>');
return;
}
// %3e (0x3e=62)
if (s.charAt(index + 1) == '6' && s.charAt(index + 2) == '2') {
sb.append('>');
return;
}
}
sb.append(s.charAt(index));
}
/**
* 获取最原始的request
*
* @return
*/
public HttpServletRequest getOrgRequest() {
return orgRequest;
}
/**
* 获取最原始的request的静态方法
*
* @return
*/
public static HttpServletRequest getOrgRequest(HttpServletRequest req) {
if (req instanceof XssHttpServletRequestWrapper) {
return ((XssHttpServletRequestWrapper) req).getOrgRequest();
}
return req;
}
}
习惯性的谢谢!!!