Oauth2 + JWT登出(一)(白名单方案)

说明

JWT的缺点上一篇已经讲了,不知道的请看上一篇。上篇已经说了,同一个用户id,调两次登录接口,如果这个token没过期,那应该返回同一个token,但是JWT是两次返回的都不是一个token,这样会多消耗资源,其实正常的场景是,只要token没过期,不管调多少次登录接口返回的都是同一个token。
每次登录把token都放Redis里进行白名单处理,资源接口去判断是否在Redis里,如果不在白名单里的都认为过期或不合法的token,就不让登录,这样还能做在线踢人功能,而且后期的续签功能也能在这个基础上做。Redis key 这里选userId,因为userId是唯一的,当然选jti也是一样的。

方案

以下方案都是基于Redis

  1. 登录前先判断当前userId的key是否在redis里,如果在,不走后续登录鉴权流程,直接返回token。否则走后续登录鉴权流程。
  2. 登录后基于切面(@AfterReturning)拿到返回结果,判断userId的key是否在redis里,如果在,则返回redis里的token。否则走原流程,返回Oauth2新生成的token。
  3. 覆盖JwtTokenStore.storeAccessToken()方法。方法内先判断userId的key是否在redis里,如果在,则改变形参OAuth2AccessToken值,返回redis里的token传到外面去(这里注意:必须redis存的是token的byte[])

上面推荐使用方案3,成本最小

方案一

有几种选择,可以使用过滤器、拦截器、AOP。过滤器范围太大,因为只需要过滤登录的接口(包含老登录接口和新登录接口),其他接口都不用管,所以这里选择拦截器或者AOP,指定过滤路径或方法。
因为老登录接口是@RequestBody方式传的参数形式,所以如果用拦截器,只能通过request去拿body参数,如下:

Map<String,Object> params = new HashMap<String, Object>();
        BufferedReader br;
        try {
            br = request.getReader();
            String str, wholeStr = "";
            while((str = br.readLine()) != null){
                wholeStr += str;
            }
            if(StringUtils.isNotEmpty(wholeStr)){
                params = JSON.parseObject(wholeStr,Map.class);
            }
        } catch (IOException e1) {
            logger.error(""+e1);
        }

但是一个流不能读两次异常,这种异常一般出现在框架或者拦截器中读取了request中的流的数据,我们在业务代码中再次读取(如@requestBody),由于流中的数据已经没了,所以第二次读取的时候就会抛出异常。

拦截器伪代码如下(不考虑二次流问题,这里直接从parameter拿出来):

/**
 * lbj
 * 登录之前判断token是否存在Redis中(白名单),如果存在,则直接返回;否则进行登录
 */
@Component
public class BeforeLoginAuthenticationFilter extends OncePerRequestFilter {
    @Autowired
    private RedisUtil redisUtil;

    @Autowired
    private ObjectMapper objectMapper;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
        String username = request.getParameter("username");
        User user = findbyUsername(username);
        String userId = user .getUserId();
        Object value = redisUtil.get(userIdRedisKey(userId));
        if (value != null) {
            ResponseUtil.responseSucceed(objectMapper, response, value);
            return;
        }
    }
    private String userIdRedisKey(String userId) {
        return SecurityConstants.CACHE_EXPIRE_TOKEN_WHITELIST + ":" + userId;
    }
}

AOP伪代码如下(跟上面效果一样,两种实现方式):

    @Around("execution(* com.dhgate.saas.uaa.controller.Oauth2Controller.userTokenInfo(..))")
    public Object doAround(ProceedingJoinPoint pjp) throws Throwable{
        log.info("OauthTokenAspect.doAround 开始...");
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletRequest request = Objects.requireNonNull(attributes).getRequest();
        HttpServletResponse response = Objects.requireNonNull(attributes).getResponse();
        String username = request.getParameter("username");
        User user = findbyUsername(username);
        String userId = user .getUserId();
        LoginTokenDto tokenDto = (LoginTokenDto) redisUtil.get(userIdRedisKey(userId));
        if (tokenDto != null) {
            tokenDto.setExpires_in(redisUtil.getExpire(userIdRedisKey(userId)));
            ResponseUtil.responseSucceed(objectMapper, response, tokenDto);
            return true;
        }
        Object proceed = pjp.proceed();
        return proceed;
    }

实现JwtTokenStore重写storeAccessToken()和removeAccessToken()方法。storeAccessToken()方法是把token存入redis;removeAccessToken()方法是把token从redis删除(登录接口会调这个方法),代码如下(该类在CustomJwtTokenStore):

    @Override
    public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
        LoginTokenDto loginTokenDto = (LoginTokenDto) redisUtil.get(userIdRedisKey(userId));
        if(Objects.isNull(loginTokenDto)){
            LoginTokenDto transFerToken = LoginTokenDto.builder().access_token(token.getValue())
                    .token_type(token.getTokenType())
                    .refresh_token(token.getRefreshToken().getValue())
//                    .expires_in(Long.valueOf(token.getExpiresIn()))
                    .scope(String.join(",", token.getScope()))
                    .userId(token.getAdditionalInformation().get("userId").toString())
                    .jti(token.getAdditionalInformation().get("jti").toString())
                    .build();
            redisUtil.set(userIdRedisKey(userId), transFerToken, token.getExpiresIn());
        }
        super.storeAccessToken(token, authentication);
    }

    @Override
    public void removeAccessToken(OAuth2AccessToken token) {
        if (token.getAdditionalInformation().containsKey("userId")) {
            String userId = token.getAdditionalInformation().get("userId").toString();
            RedisUtil redisUtil = Optional.ofNullable(SpringContextUtil.getBean(RedisUtil.class))
                    .orElseThrow(() -> new InnerException("internal bean acquisition failed."));
            redisUtil.del(userIdRedisKey(userId));
        }
        super.removeAccessToken(token);
    }

最后每个资源服务器拿token前,先去redis中判断userId的key是否存在,如果白名单不存在 或者 拿到的参数map中的jti和redis中token携带的jti不等,标识为不是一次token, 则直接返回"Invalid token!"异常标识。如果userId的key在白名单内,则继续后续流程。

方法extractAuthentication()在ResJwtAccessTokenConverter.JwtUserAuthenticationConverter.JWTfaultUserAuthenticationConverter中,部分代码如下:
下面用到了黑名单三种方案中的方案二,当然按需选择用哪个。

public Authentication extractAuthentication(Map<String, ?> map) {
                                 //白名单校验逻辑开始---------------------------
                RedisUtil redisUtil = SpringContextUtil.getBean(RedisUtil.class);
                if(map.containsKey("userId") && map.containsKey("jti")){
                    String userId = (String) map.get("userId");
                    String jti = (String) map.get("jti");
                    String tokenByteStr = (String) redisUtil.get(userIdRedisKey(userId));
                    if(!StringUtils.isEmpty(tokenByteStr)){
                        byte[] tokenByte = Base64.getDecoder().decode(tokenByteStr);
                        DefaultOAuth2AccessToken token = (DefaultOAuth2AccessToken) deserializeAccessToken(tokenByte);
                        if (token == null || !Objects.equals(jti, token.getAdditionalInformation().get("jti"))) {
                            throw new InvalidTokenException("Invalid token!");
                        }
                    }else{
                        throw new InvalidTokenException("Invalid token!");
                    }
                }
                            //白名单校验逻辑结束(以下跟白名单无关)---------------------------

                if (map.containsKey("user_info")) {
                    Object principal = map.get("user_info");
                //  Collection<? extends GrantedAuthority> authorities = getAuthorities(map);
                    LoginAppUser loginUser = new LoginAppUser();
                    if (principal instanceof Map) {
                        loginUser = BeanUtil.mapToBean((Map) principal, LoginAppUser.class, true);
                    }
                    return new UsernamePasswordAuthenticationToken(loginUser, "N/A", loginUser.getAuthorities());
                }
                return null;
            }

方案二

基于方案一反向思考,方案一是在登录前,方案二是在调用登录接口后,拿到Oauth2生成的token,然后去判断是否在Redis中,如果在,则返回Redis中的,Oauth2生成的token则达到一种“胎死腹中”的效果。如果不在,返回Oauth2新生成token。

这里可以用Spring AOP去做,因为Spring AOP能方便拿到返回值,而拦截器稍微麻烦些,代码如下:

    @AfterReturning(value = "pointcut()", returning = "rvt")
    public void afterReturning(JoinPoint joinPoint, Object rvt) throws Throwable {
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        HttpServletResponse response = Objects.requireNonNull(attributes).getResponse();
        Result result = (Result) rvt;
        if(result.getSuccess()){
            DefaultOAuth2AccessToken oAuth2AccessToken = (DefaultOAuth2AccessToken) result.getData();
            Map<String, Object> map = oAuth2AccessToken.getAdditionalInformation();
            if(map.containsKey("userId")){
                String userId = (String) map.get("userId");
                LoginTokenDto tokenDto = (LoginTokenDto) redisUtil.get(userIdRedisKey(userId));
                if (tokenDto != null) {
                    tokenDto.setExpires_in(redisUtil.getExpire(userIdRedisKey(userId)));
                    ResponseUtil.responseSucceed(objectMapper, response, tokenDto);
                    return;
                }
            }
        }
    }

其他地方改动类似方案一,修改storeAccessToken()方法。

方案三

第三种方案不需要过滤器、拦截器和AOP,直接在storeAccessToken()方法判断redis中是否有token,如果有,直接把redis中token(也是一个OAuth2AccessToken)去覆盖老的OAuth2AccessToken即可。值得注意的是,对象默认传递是引用传递,要改变整个对象。如果直接写 token = tokenSource; 是不行的,因为这不能改变对象的指针,默认拿到的还是之前的对象,想改变对象里面值,有两种办法,1.直接set属性(但是这样一个一个set比较麻烦) 2.直接深复制。

注意:这里存入的redis是一个byte[]流,用的是JdkSerializationStrategy的序列化方式(RedisTokenStore就是用的这种方式)。好处是时间不是存redis死的,然后最后解析反序列化也是OAuth2AccessToken对象。

    @Override
    public void storeAccessToken(OAuth2AccessToken token, OAuth2Authentication authentication) {
        // 这里token 和 accessToken 是一个东西,打断点,看他们的内存地址,都是一个
        DefaultOAuth2AccessToken accessToken = (DefaultOAuth2AccessToken) token;
        if (token.getExpiresIn() > 0 && token.getAdditionalInformation().containsKey("userId")) {
            String userId = token.getAdditionalInformation().get("userId").toString();
            RedisUtil redisUtil = Optional.ofNullable(SpringContextUtil.getBean(RedisUtil.class))
                    .orElseThrow(() -> new InnerException("internal bean acquisition failed."));

            String tokenByteStr = (String) redisUtil.get(userIdRedisKey(userId));
            if(!StringUtils.isEmpty(tokenByteStr)){
                byte[] tokenByte = Base64.getDecoder().decode(tokenByteStr);
                DefaultOAuth2AccessToken tokenSource = (DefaultOAuth2AccessToken) deserializeAccessToken(tokenByte);

                //对象默认传递是引用传递,要改变整个对象。如果直接写 token = tokenSource; 是不行的,因为这不能改变对象的指针,默认拿到的还是之前的对象
                //想改变对象里面值,有两种办法,1.直接set属性(但是这样一个一个set比较麻烦) 2.直接深复制,就是下面这句话
                BeanUtils.copyProperties(tokenSource, accessToken);
            }else{
                byte[] serializedAccessToken = this.serialize((Object)token);
                String encoded = Base64.getEncoder().encodeToString(serializedAccessToken);
                redisUtil.set(userIdRedisKey(userId), encoded, token.getExpiresIn());
            }
        }
        super.storeAccessToken(accessToken, authentication);
    }
禁止转载,如需转载请通过简信或评论联系作者。

推荐阅读更多精彩内容