0
点赞
收藏
分享

微信扫一扫

中文版Google浏览器支持小于12px的文字

穿裙子的程序员 2023-10-31 阅读 48
spring

在给 Spring WebFlux 接口做签名校验、防重放时,往往需要获取请求参数、请求方法等信息,而 Spring WebFlux 无法像 Spring MVC 那样方便地获取这些信息。这里根据之前的实践专门说明一下:

总体思路:
1、利用过滤器,从原 request 中获取到信息后缓存在一个上下文对象中,然后构造新的 request 传给后面的过滤器。因为原 request 的 body 是流式的,读取一次后便无法再次读取。
2、通过exchange的Attributes传递上下文对象,在不同的过滤器中使用即可。

1、上下文对象

/**
 * Per-request snapshot of inbound request data, stored in the exchange attributes under
 * {@link #CACHE_GATEWAY_CONTEXT} so that later filters can read values the reactive request
 * can no longer provide (its body stream can only be consumed once).
 */
@Getter
@Setter
@ToString
public class GatewayContext {

    /** Exchange attribute key under which the context instance is stored. */
    public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";

    /** Cached HTTP method of the request. */
    private String requestMethod;

    /** Cached URL query parameters. */
    private MultiValueMap<String, String> queryParams;

    /** Cached request body as text (JSON requests). */
    private String requestBody;

    /** Cached response body. */
    private Object responseBody;

    /** Headers of the original request. */
    private HttpHeaders requestHeaders;

    /** Cached x-www-form-urlencoded form fields. */
    private MultiValueMap<String, String> formData;

    /** Every request parameter gathered in one place: query parameters plus form/body fields. */
    private MultiValueMap<String, String> allRequestData = new LinkedMultiValueMap<>(0);

    /** Cached body bytes (used as signature material). */
    private byte[] requestBodyBytes;

}

2、在过滤器中获取请求参数、请求方法。
这里我们只对 application/json 和 application/x-www-form-urlencoded 这两种类型拦截并缓存 body 参数;对于其他请求,则可以直接从 URL 中获取 query 参数。

/**
 * Captures request metadata (method, headers, query parameters) and, for JSON and
 * x-www-form-urlencoded requests, the body into a {@link GatewayContext} stored under
 * {@link GatewayContext#CACHE_GATEWAY_CONTEXT}.
 *
 * <p>The reactive request body can only be consumed once. Bodies are therefore buffered, and the
 * request is replaced by a decorator that replays the cached bytes to downstream filters and
 * handlers.
 */
@Slf4j
@Component
public class GatewayContextFilter implements WebFilter, Ordered {

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        ServerHttpRequest request = exchange.getRequest();
        HttpHeaders headers = request.getHeaders();
        GatewayContext gatewayContext = new GatewayContext();
        gatewayContext.setRequestHeaders(headers);
        gatewayContext.getAllRequestData().addAll(request.getQueryParams());
        // HTTP method tokens are ASCII; avoid locale-sensitive upper-casing (e.g. Turkish dotless i)
        gatewayContext.setRequestMethod(request.getMethodValue().toUpperCase(java.util.Locale.ROOT));
        gatewayContext.setQueryParams(request.getQueryParams());
        // save gateway context into exchange so later filters can read it
        exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);

        MediaType contentType = headers.getContentType();
        // chunked requests carry a body without a Content-Length header
        boolean hasBody = headers.getContentLength() > 0 || headers.containsKey(HttpHeaders.TRANSFER_ENCODING);
        if (hasBody) {
            // compare type/subtype only, so "application/json;charset=UTF-8" is cached as well
            if (MediaType.APPLICATION_JSON.equalsTypeAndSubtype(contentType)) {
                return readBody(exchange, chain, gatewayContext);
            }
            if (MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {
                return readFormData(exchange, chain, gatewayContext);
            }
        }

        String path = request.getPath().value();
        if (!"/".equals(path)) {
            log.info("{} Gateway context is set with {}-{}", path, contentType, gatewayContext);
        }
        return chain.filter(exchange);
    }


    @Override
    public int getOrder() {
        return Integer.MIN_VALUE + 1;
    }


    /**
     * Reads the form body via {@link ServerWebExchange#getFormData()}, caches it, and re-serializes
     * it so downstream consumers can read the body again.
     *
     * <p>Note: {@code requestBodyBytes} holds the RE-ENCODED form (keys in parse order, repeated keys
     * grouped, RFC 3986-style percent-encoding), not the raw bytes the client sent.
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {
        return exchange.getFormData()
                .doOnNext(multiValueMap -> {
                    gatewayContext.setFormData(multiValueMap);
                    gatewayContext.getAllRequestData().addAll(multiValueMap);
                    log.debug("[GatewayContext]Read FormData Success");
                })
                .then(Mono.defer(() -> {
                    MultiValueMap<String, String> formData = gatewayContext.getFormData();
                    // formData is empty: nothing to re-serialize
                    if (formData == null || formData.isEmpty()) {
                        return chain.filter(exchange);
                    }
                    log.info("1. Gateway Context formData: {}", formData);
                    HttpHeaders originalHeaders = exchange.getRequest().getHeaders();
                    Charset charset = resolveCharset(originalHeaders.getContentType());
                    String formDataBodyString = encodeFormData(formData, charset);
                    byte[] bodyBytes = formDataBodyString.getBytes(charset);
                    gatewayContext.setRequestBodyBytes(bodyBytes);

                    HttpHeaders httpHeaders = new HttpHeaders();
                    httpHeaders.putAll(originalHeaders);
                    // the re-encoded body is replayed with an explicit, matching length
                    httpHeaders.remove(HttpHeaders.TRANSFER_ENCODING);
                    httpHeaders.setContentLength(bodyBytes.length);
                    log.info("2. GatewayContext Rewrite Form Data :{}", formDataBodyString);
                    return chain.filter(exchange.mutate().request(withCachedBody(exchange, httpHeaders, bodyBytes)).build());
                }));
    }


    /**
     * Buffers a JSON body, caches its bytes and text in the context, and only then continues the
     * chain with a request that replays the cached bytes.
     */
    private Mono<Void> readBody(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {
        return DataBufferUtils.join(exchange.getRequest().getBody())
                .map(dataBuffer -> {
                    try {
                        byte[] bytes = new byte[dataBuffer.readableByteCount()];
                        dataBuffer.read(bytes);
                        return bytes;
                    } finally {
                        DataBufferUtils.release(dataBuffer);
                    }
                })
                // join() completes empty for an empty body; still continue the chain
                .defaultIfEmpty(new byte[0])
                .flatMap(bytes -> {
                    gatewayContext.setRequestBodyBytes(bytes);
                    HttpHeaders headers = exchange.getRequest().getHeaders();
                    // populate the context BEFORE invoking the next filter, which may read it synchronously
                    cacheJsonBody(gatewayContext, new String(bytes, resolveCharset(headers.getContentType())));
                    return chain.filter(exchange.mutate().request(withCachedBody(exchange, headers, bytes)).build());
                });
    }


    /**
     * Stores the body text and, when it looks like a JSON object, copies its top-level entries
     * (values stringified) into {@code allRequestData}. Parse failures are logged, not propagated.
     */
    private static void cacheJsonBody(GatewayContext gatewayContext, String body) {
        gatewayContext.setRequestBody(body);
        if (!body.trim().startsWith("{")) {
            return;
        }
        try {
            Map<?, ?> json = JsonUtil.fromJson(body, Map.class);
            if (json != null) {
                MultiValueMap<String, String> allRequestData = gatewayContext.getAllRequestData();
                // stringify values: a raw Map would smuggle Integer/Map values into a String-typed map
                json.forEach((key, value) -> allRequestData.set(String.valueOf(key), value == null ? null : String.valueOf(value)));
            }
        } catch (Exception e) {
            log.warn("Gateway context Read JsonBody error:{}", e.getMessage(), e);
        }
    }


    /** Wraps the current request so it reports {@code headers} and replays {@code bodyBytes} on every subscription. */
    private static ServerHttpRequest withCachedBody(ServerWebExchange exchange, HttpHeaders headers, byte[] bodyBytes) {
        return new ServerHttpRequestDecorator(exchange.getRequest()) {
            @Override
            public HttpHeaders getHeaders() {
                return headers;
            }

            @Override
            public Flux<DataBuffer> getBody() {
                // a fresh wrapper per subscription; the consumer releases it, so no extra retain
                return Flux.defer(() -> Mono.just(exchange.getResponse().bufferFactory().wrap(bodyBytes)));
            }
        };
    }


    /** Serializes form fields as {@code k=v&k=v}; a {@code null} value (field without '=') is written as the bare key. */
    private static String encodeFormData(MultiValueMap<String, String> formData, Charset charset) {
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String, List<String>> entry : formData.entrySet()) {
            String encodedKey = encodeComponent(entry.getKey(), charset);
            for (String value : entry.getValue()) {
                if (builder.length() > 0) {
                    builder.append('&');
                }
                builder.append(encodedKey);
                if (value != null) {
                    builder.append('=').append(encodeComponent(value, charset));
                }
            }
        }
        return builder.toString();
    }


    /** Percent-encodes with RFC 3986 conventions: space as %20, '*' as %2A, '~' left unescaped. */
    private static String encodeComponent(String text, Charset charset) {
        try {
            return URLEncoder.encode(text, charset.name())
                    .replace("+", "%20").replace("*", "%2A").replace("%7E", "~");
        } catch (UnsupportedEncodingException e) {
            // unreachable in practice: the name comes from an existing Charset instance
            throw new IllegalStateException("Unsupported charset: " + charset, e);
        }
    }


    /** Returns the charset declared in {@code contentType}, defaulting to UTF-8. */
    private static Charset resolveCharset(MediaType contentType) {
        Charset charset = contentType == null ? null : contentType.getCharset();
        return charset == null ? StandardCharsets.UTF_8 : charset;
    }

}

3、签名、防重放校验
这里我们从上下文对象中取出参数即可
签名算法逻辑:
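(原文此处为图片,已缺失)签名算法:将 method、client-id、`client-id=…&nonce=…&timestamp=…`、`uri?query`、body 的 MD5 依次用换行符 `\n` 连接(末尾也带一个 `\n`)作为签名原文,再用前后端约定的密钥做 HMAC-SHA256 得到签名;body 为空时取空串的 MD5,即 `d41d8cd98f00b204e9800998ecf8427e`。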

/**
 * Verifies the request signature and rejects stale or replayed requests.
 *
 * <p>All inputs are read from the {@link GatewayContext} that an earlier filter stored under
 * {@link GatewayContext#CACHE_GATEWAY_CONTEXT}. This filter's order ({@code Integer.MIN_VALUE + 2})
 * places it after that filter.
 */
@Slf4j
@Component
public class GatewaySignCheckFilter implements WebFilter, Ordered {

    /** Paths (relative to {@code api.rest.prefix}) that skip signature checks; "/**" matches a whole subtree. */
    private static final List<String> IGNORED_PATHS = Lists.newArrayList("/open/**", "/open/login/params", "/open/image");

    /** Maximum allowed skew between the client timestamp and server time. */
    private static final long TIMESTAMP_TOLERANCE_MILLIS = 60 * 5 * 1000L;

    /** How long a used nonce is remembered (assumed milliseconds); must exceed twice the timestamp tolerance. */
    private static final int NONCE_TTL_MILLIS = 60 * 15 * 1000;

    @Value("${api.rest.prefix}")
    private String apiPrefix;

    @Autowired
    private RedisUtil redisUtil;

    // Signing secret agreed between front end and back end.
    // NOTE(review): hard-coded; consider moving it to external configuration.
    private static final String API_SECRET = "secret-xxx";

    @Override
    public int getOrder() {
        return Integer.MIN_VALUE + 2;
    }

    /**
     * Runs the checks on subscription, so failures surface as an error signal rather than being
     * thrown out of {@code filter()}.
     * NOTE(review): checkOnce() calls RedisUtil; if that client is blocking, it runs on the event loop.
     */
    @NotNull
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, @NotNull WebFilterChain chain) {
        return Mono.fromRunnable(() -> verifyRequest(exchange))
                .then(Mono.defer(() -> chain.filter(exchange)));
    }


    /** Decides whether the request needs a signature check and performs it. */
    private void verifyRequest(ServerWebExchange exchange) {
        ServerHttpRequest request = exchange.getRequest();
        String uri = request.getURI().getPath();
        GatewayContext gatewayContext = (GatewayContext) exchange.getAttributes().get(GatewayContext.CACHE_GATEWAY_CONTEXT);
        if (gatewayContext == null) {
            // fail closed instead of an NPE when the context filter did not run
            throw new IllegalStateException("GatewayContext missing from exchange attributes");
        }
        HttpHeaders headers = gatewayContext.getRequestHeaders();
        MediaType contentType = headers.getContentType();
        log.info("check url:{},method:{},contentType:{}", uri, gatewayContext.getRequestMethod(), contentType == null ? "" : contentType.toString());
        // without a content type only GET is allowed
        if (contentType == null || StringUtils.isBlank(contentType.toString())) {
            if (request.getMethod() != HttpMethod.GET) {
                throw new RuntimeException("非法访问");
            }
            checkSign(uri, gatewayContext);
        } else if (MediaType.APPLICATION_JSON.equalsTypeAndSubtype(contentType)
                || MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {
            // type/subtype comparison: "application/json;charset=UTF-8" must not bypass the check
            checkSign(uri, gatewayContext);
        }
        // NOTE(review): any other content type (e.g. multipart) is not signature-checked at all.
    }


    private void checkSign(String uri, GatewayContext gatewayContext) {
        if (isIgnored(uri)) {
            log.info("check sign ignore:{}", uri);
            return;
        }
        String method = gatewayContext.getRequestMethod();
        log.info("start check sign {}-{}", method, uri);
        HttpHeaders headers = gatewayContext.getRequestHeaders();
        // headers may carry credentials/cookies: keep them out of INFO logs
        if (log.isDebugEnabled()) {
            log.debug("headers:{}", JsonUtils.objectToJson(headers));
        }
        String clientId = getHeaderAttr(headers, SystemSign.CLIENT_ID);
        String timestamp = getHeaderAttr(headers, SystemSign.TIMESTAMP);
        String nonce = getHeaderAttr(headers, SystemSign.NONCE);
        String sign = getHeaderAttr(headers, SystemSign.SIGN);
        checkTime(timestamp);
        String headerStr = String.format("%s=%s&%s=%s&%s=%s", SystemSign.CLIENT_ID, clientId,
                SystemSign.NONCE, nonce, SystemSign.TIMESTAMP, timestamp);
        String queryUri = uri + getQueryParam(gatewayContext.getQueryParams());
        // never log the secret
        log.info("headerStr:{},queryUri:{}", headerStr, queryUri);
        String realSign = calculatorSign(clientId, queryUri, gatewayContext, headerStr, API_SECRET);
        log.debug("sign:{}, realSign:{}", sign, realSign);
        if (!constantTimeEquals(realSign, sign)) {
            log.warn("wrong sign");
            throw new RuntimeException("Illegal sign");
        }
        // consume the nonce only after the request is proven authentic, so unsigned traffic can't burn nonces
        checkOnce(nonce);
    }

    /** True when {@code uri} equals an ignored path, or lies under an ignored prefix (entries ending in "/**" match a subtree). */
    private boolean isIgnored(String uri) {
        for (String ignore : IGNORED_PATHS) {
            String fullPath = apiPrefix + ignore;
            if (uri.equals(fullPath) || uri.startsWith(fullPath.replace("/**", "/"))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Renders query parameters as {@code ?k=v&k=v}. Every value of a repeated key is included
     * (values are grouped by key), so extra values cannot be appended without invalidating the
     * signature. A key with no value ({@code ?a}) is rendered as the bare key.
     */
    private String getQueryParam(MultiValueMap<String, String> queryParams) {
        if (queryParams == null || queryParams.isEmpty()) {
            return StringUtils.EMPTY;
        }
        StringBuilder builder = new StringBuilder();
        for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {
            for (String value : entry.getValue()) {
                builder.append(builder.length() == 0 ? '?' : '&').append(entry.getKey());
                if (value != null) {
                    builder.append('=').append(value);
                }
            }
        }
        return builder.toString();
    }

    private String getHeaderAttr(HttpHeaders headers, String key) {
        List<String> values = headers.get(key);
        if (CollectionUtils.isEmpty(values)) {
            log.warn("GatewaySignCheckFilter empty header:{}", key);
            throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);
        }
        String value = values.get(0);
        if (StringUtils.isBlank(value)) {
            log.warn("GatewaySignCheckFilter empty header:{}", key);
            throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);
        }
        return value;
    }


    /** Builds "method\nclientId\nheaderStr\nqueryUri\nbodyMd5\n" and returns its HMAC-SHA256 under {@code signSecret}. */
    private String calculatorSign(String clientId, String queryUri, GatewayContext gatewayContext, String headerStr, String signSecret) {
        String method = gatewayContext.getRequestMethod();
        byte[] bodyBytes = gatewayContext.getRequestBodyBytes();
        if (bodyBytes == null) {
            // MD5 of an empty body is always d41d8cd98f00b204e9800998ecf8427e
            bodyBytes = new byte[]{};
        }
        String bodyMd5 = UaaSignUtils.getMd5(bodyBytes);
        String ori = String.format("%s\n%s\n%s\n%s\n%s\n", method, clientId, headerStr, queryUri, bodyMd5);
        log.info("clientId:{},headerStr:{},bodyMd5:{},queryUri:{},ori:{}", clientId, headerStr, bodyMd5, queryUri, ori);
        return UaaSignUtils.sha256HMAC(ori, signSecret);
    }

    /** Compares two signatures without an early exit on the first differing character. */
    private static boolean constantTimeEquals(String expected, String actual) {
        if (expected == null || actual == null) {
            return false;
        }
        return java.security.MessageDigest.isEqual(
                expected.getBytes(java.nio.charset.StandardCharsets.UTF_8),
                actual.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }

    /** Records the nonce with set-if-absent; a nonce that is blank or already present is rejected as a replay. */
    private void checkOnce(String nonce) {
        if (StringUtils.isBlank(nonce)) {
            log.warn("GatewaySignCheckFilter checkOnce Illegal");
            throw new RuntimeException("checkOnce Illegal");
        }
        String key = "api:auth:" + nonce;
        Boolean succ = redisUtil.setNxWithExpire(key, "1", NONCE_TTL_MILLIS);
        if (succ == null || !succ) {
            log.warn("GatewaySignCheckFilter checkOnce Repeat");
            throw new RuntimeException("checkOnce Repeat");
        }
    }


    /** Rejects timestamps that cannot be parsed or are further than the tolerance from server time in either direction. */
    private void checkTime(String timestamp) {
        long time;
        try {
            time = Long.parseLong(timestamp);
        } catch (Exception ex) {
            log.error("GatewaySignCheckFilter checkTime error:{}", ex.getMessage(), ex);
            throw new RuntimeException("checkTime error");
        }
        long now = DateTimeUtil.now();
        log.info("now: {}, time: {}", DateTimeUtil.millsToStr(now), DateTimeUtil.millsToStr(time));
        // compare against bounds instead of negating (now - time), which overflows for Long.MIN_VALUE
        if (time < now - TIMESTAMP_TOLERANCE_MILLIS || time > now + TIMESTAMP_TOLERANCE_MILLIS) {
            log.warn("GatewaySignCheckFilter checkTime Late");
            throw new RuntimeException("checkTime Late");
        }
    }

    /** Header names used by the signing protocol. */
    public interface SystemSign {
        /**
         * Client ID: a fixed value issued to the front end by the back end.
         */
        String CLIENT_ID = "client-id";

        /**
         * Signature computed by the client.
         */
        String SIGN = "sign";

        /**
         * Timestamp.
         */
        String TIMESTAMP = "timestamp";

        /**
         * Unique value (nonce).
         */
        String NONCE = "nonce";
    }

}
举报

相关推荐

0 条评论