CAS OAuth2 源码分析

在工程中引入以下依赖,方便看代码:

                <!-- 开启oauth支持 -->
                <dependency>
                    <groupId>org.apereo.cas</groupId>
                    <artifactId>cas-server-support-oauth-webflow</artifactId>
                    <version>${cas.version}</version>
                </dependency>
                <dependency>
                    <groupId>org.apereo.cas</groupId>
                    <artifactId>cas-server-support-actions</artifactId>
                    <version>${cas.version}</version>
                </dependency>
                <dependency>
                    <groupId>org.jasig.cas.client</groupId>
                    <artifactId>cas-client-core</artifactId>
                    <version>3.4.1</version>
                </dependency>

CAS 对Oauth2的支持主要分5个包,最主要的是 cas-server-support-oauth

包名 功能
cas-server-support-oauth-api 4种token的接口
cas-server-support-oauth-core 实现了cas-server-support-oauth-api定义的4个接口
cas-server-support-oauth 关键工程,主要功能都在这里实现
cas-server-support-oauth-services 处理配置的service.JSON文件,定义不同的客户端用的
cas-server-support-oauth-webflow 流程处理

以下一个一个分析:

  • cas-server-support-oauth-api

定义了4个接口:

  1. OAuthToken 只声明了 一个获取认证信息的方法Authentication getAuthentication()
  2. AccessToken 继承OAuthToken 增加了个 权限范围的方法 Collection<String> getScopes(),该token有一个比较短的有效时间
  3. OAuthCode 只定义了一个变量,该code只能用一次,并且只有很短的一段时间有效
  4. RefreshToken 刷新token,当AccessToken 快过期了,用RefreshToken来获取一个新的AccessToken

  • cas-server-support-oauth-core

只有5个类:3个实现类实现了上面 cas-server-support-oauth-api 的3个接口 + 1个配置类 + 1个 OAuth20Constants

  1. OAuthProtocolTicketCatalogConfiguration 创建并配置上面3个接口的实现类
  2. OAuth20Constants 包含了一些OAuth变量,如:redirect_uri,response_type,grant_type
  3. OAuthCodeImpl
  4. AccessTokenImpl
  5. RefreshTokenImpl
    上面这几个类只是3个实体类,从类上注释来看是可以持久化到数据库的.没什么内容,类图如下:
    OAuthCodeImpl.png
  • cas-server-support-oauth

CAS实现Oauth2主要是在这里,该包中分几块,

  • 先分析怎么生成ticket,主要有下面这几个类

image.png

套路都是一样的,給三种token定义创建工厂.如:RefreshTokenFactory,有一个默认实现类DefaultRefreshTokenFactory ,主要是根据UniqueTicketIdGeneratorExpirationPolicy来生成token
ExpirationPolicy 的一个实现 OAuthRefreshTokenExpirationPolicyCasOAuthConfiguration 定义的时候会用到.

    private ExpirationPolicy refreshTokenExpirationPolicy() {
        return new OAuthRefreshTokenExpirationPolicy(casProperties.getAuthn().getOauth().getRefreshToken().getTimeToKillInSeconds());
    }

OAuthRefreshTokenExpirationPolicy 继承自 AbstractCasExpirationPolicy ,比其他的Policy 多了一个属性而已timeToKillInSeconds,从上面的配置从可以看到是为了方面配置文件配置的.

同理: OAuthCodeFactory 的默认token工厂实现DefaultOAuthCodeFactory,里面也用到了定制的过期策略OAuthCodeExpirationPolicy
获取配置的代码如下:

    private ExpirationPolicy oAuthCodeExpirationPolicy() {
        final OAuthProperties oauth = casProperties.getAuthn().getOauth();
        return new OAuthCodeExpirationPolicy(oauth.getCode().getNumberOfUses(), oauth.getCode().getTimeToKillInSeconds());
    }

CasOAuthConfiguration
同理: AccessTokenFactory 默认token工厂实现DefaultAccessTokenFactory和定制的过期策略OAuthAccessTokenExpirationPolicy

  • 3个过期策略的实现对比
类名 覆写的方法 构造器传递的参数 说明
OAuthRefreshTokenExpirationPolicy isExpired timeToKillInSeconds 添加了timeToKillInSeconds属性,获取系统的时间都是ZonedDateTime.now(ZoneOffset.UTC),注意时区
OAuthCodeExpirationPolicy numberOfUses timeToKillInMilliSeconds 和其他两个不一样,这个是继承自MultiTimeUseOrTimeoutExpirationPolicy,这个不仅受时间限制,而且也受使用次数限制
OAuthAccessTokenExpirationPolicy isExpired maxTimeToLiveInSeconds timeToKillInSeconds 这个受最长时间maxTimeToLiveInSeconds 限制,和创建时间对比;timeToKillInSeconds和上次使用时间对比

从代码中可以看到,isExpired方法的参数 TicketState是包含了Ticket的重要信息在里面,是过期策略的核心数据载体.而 上面的 OAuthCodeImpl等3个实体类就实现了TicketState接口. TicketState 有个update方法,实现在 AbstractTicket 中,代码如下:

    @Override
    public void update() {
        this.previousLastTimeUsed = this.lastTimeUsed;
        this.lastTimeUsed = ZonedDateTime.now(ZoneOffset.UTC);
        this.countOfUses++;
      //每次更新完token值都会更新 TicketGrantingTicket 值的相应信息
        if (getGrantingTicket() != null && !getGrantingTicket().isExpired()) {
            final TicketState state = TicketState.class.cast(getGrantingTicket());
            state.update();
        }
    }

从代码可以看到,每次更新完token值都会更新 TicketGrantingTicket 值的相应信息


  • 分析下4个Controller,如下图所示
image.png
类名 对应的端口 备注
OAuth20AccessTokenEndpointController /oauth2.0/accessToken 和 /oauth2.0/token 获取token
OAuth20AuthorizeEndpointController /oauth2.0/authorize 认证用户信息
OAuth20CallbackAuthorizeEndpointController /oauth2.0/callbackAuthorize
OAuth20UserProfileControllerController /oauth2.0/profile 返回用户信息,JSON格式

详细信息可见官网资料
4个Controller都是继承自BaseOAuth20Controller,入口很显然都是handleRequest

  • OAuth20AccessTokenEndpointController 分析
    @PostMapping(path = {OAuth20Constants.BASE_OAUTH20_URL + '/' + OAuth20Constants.ACCESS_TOKEN_URL,
            OAuth20Constants.BASE_OAUTH20_URL + '/' + OAuth20Constants.TOKEN_URL})
    public void handleRequest(final HttpServletRequest request, final HttpServletResponse response) throws Exception {
        try {
            response.setContentType(MediaType.TEXT_PLAIN_VALUE);
            //验证请求是否合法,各种认证方式时需要的参数是否存在,是否是配置的registerService
            if (!verifyAccessTokenRequest(request, response)) {
                LOGGER.error("Access token request verification failed");
                OAuth20Utils.writeTextError(response, OAuth20Constants.INVALID_REQUEST);
                return;
            }

            final AccessTokenRequestDataHolder responseHolder;
            try {
                //检查并且转换成  AccessTokenRequestDataHolder,系统中实现了4中Extractor,在初始化的时候初始化进来
                //该方法找到第一个支持的(调用其supports方法),一般也只有一个支持的
                //找到后调用对应的extract方法 转换成需要的 AccessTokenRequestDataHolder
                responseHolder = examineAndExtractAccessTokenGrantRequest(request, response);
                LOGGER.debug("Creating access token for [{}]", responseHolder);
            } catch (final Exception e) {
                LOGGER.error("Could not identify and extract access token request", e);
                OAuth20Utils.writeTextError(response, OAuth20Constants.INVALID_GRANT);
                return;
            }

            final J2EContext context = Pac4jUtils.getPac4jJ2EContext(request, response);
            //调用对应token的生成器 OAuth20DefaultTokenGenerator 生成token ,这个生成器会同时一对token,AccessToken 和RefreshToken
            final Pair<AccessToken, RefreshToken> accessToken = accessTokenGenerator.generate(responseHolder);
            LOGGER.debug("Access token generated is: [{}]. Refresh token generated is [{}]", accessToken.getKey(), accessToken.getValue());
            //用 OAuth20AccessTokenResponseGenerator 生成需要返回的格式数据,如果配置了jsonFormat 就会生成JSON格式,否则就是默认的text
            //这里直接将response传入了,没有返回值,直接做最后的响应
            generateAccessTokenResponse(request, response, responseHolder, context, accessToken.getKey(), accessToken.getValue());
            response.setStatus(HttpServletResponse.SC_OK);
        } catch (final Exception e) {
            LOGGER.error(e.getMessage(), e);
            throw new RuntimeException(e.getMessage(), e);
        }
    }

从上面可以看出,除了常规的校验和最后生成token外,比较有意思的就是那个Extractor了,看CAS是怎么实现将requestresponse 转换成(extract方法) AccessTokenRequestDataHolder
CAS中同样配套了4个Extractor,分别对应了我们4种场景.类图如下:

BaseAccessTokenGrantRequestExtractor.png

  • BaseAccessTokenGrantRequestExtractor: 这个抽象类定义了3个抽象方法,分别是:

    1. 获取是哪一种授权模式getGrantType方法;
    2. 是否支持getGrantType的授权模式的supports方法
    3. 真正干活的转换方法extract
  • AccessTokenAuthorizationCodeGrantRequestExtractor :
    关键代码:getOAuthTokenFromRequest 根据URL中的code参数值到ticketRegistry 中获取OAuthToken
    然后就直接new一个 对象了return new AccessTokenRequestDataHolder(token, registeredService, getGrantType(), isAllowedToGenerateRefreshToken(), scopes);

  • AccessTokenRefreshTokenGrantRequestExtractor
    这个类继承自 AccessTokenAuthorizationCodeGrantRequestExtractor,主要逻辑和上面这个一样,只是一些参数返回值不同. isAllowedToGenerateRefreshToken=false

  • AccessTokenPasswordGrantRequestExtractor
    同样,主要逻辑是在extract方法中,其他都是返回特定参数而已.

    1. AccessTokenAuthorizationCodeGrantRequestExtractor 一样,首先都是根据URL中的clientId得到配置的registeredService
    2. 比较特别的是会根据request, response得到 J2EContext ,这个是在pac4j中定义的,最终会得到 UserProfile ,这个profile在后面获取TGT时候有用
      怎么获取TGT的呢?关键是下面这几行代码
        // 根据 OAuth20CasAuthenticationBuilder 构造出  service,
        final Service service = this.authenticationBuilder.buildService(registeredService, context, requireServiceHeader);

        LOGGER.debug("Authenticating the OAuth request indicated by [{}]", service);
        // 生成 Authentication
        final Authentication authentication = this.authenticationBuilder.build(uProfile, registeredService, context, service);
        //确保 registeredService 合法
        RegisteredServiceAccessStrategyUtils.ensurePrincipalAccessIsAllowedForService(service, registeredService, authentication);      
        final AuthenticationResult result = new DefaultAuthenticationResult(authentication, requireServiceHeader ? service : null);
        //关键代码:生成TGT centralAuthenticationService 在 CasOAuthConfiguration 注入的
        final TicketGrantingTicket ticketGrantingTicket = this.centralAuthenticationService.createTicketGrantingTicket(result);

TicketGrantingTicket 在CAS中是保存时间最长的ticket,后续的token都可以根据这个来生成,有这个ticket就相当于在CAS中创建了会话.

  • AccessTokenClientCredentialsGrantRequestExtractor
    这个类继承自AccessTokenPasswordGrantRequestExtractor,没啥特别代码. 略...

  • OAuth20AuthorizeEndpointController

返回一个code或者accessToken

    @GetMapping(path = OAuth20Constants.BASE_OAUTH20_URL + '/' + OAuth20Constants.AUTHORIZE_URL)
    public ModelAndView handleRequest(final HttpServletRequest request, final HttpServletResponse response) throws Exception {
        final J2EContext context = Pac4jUtils.getPac4jJ2EContext(request, response);
        final ProfileManager manager = Pac4jUtils.getPac4jProfileManager(request, response);
        //验证请求是否合法,定义了接口 OAuth20RequestValidator 有很多实现类在 org.apereo.cas.support.oauth.validator 包下
        //和上面的套路一样接口定义了一个support方法,找到一个支持的就用它了,然后调用 validate 方法验证即可     
        if (!verifyAuthorizeRequest(context) || !isRequestAuthenticated(manager, context)) {
            LOGGER.error("Authorize request verification failed. Either the authorization request is missing required parameters, "
                    + "or the request is not authenticated and contains no authenticated profile/principal.");
            return OAuth20Utils.produceUnauthorizedErrorView();
        }

        final String clientId = context.getRequestParameter(OAuth20Constants.CLIENT_ID);
        final OAuthRegisteredService registeredService = getRegisteredServiceByClientId(clientId);
        try {
            //验证是否有该service的权限
            RegisteredServiceAccessStrategyUtils.ensureServiceAccessIsAllowed(clientId, registeredService);
        } catch (final Exception e) {
            LOGGER.error(e.getMessage(), e);
            return OAuth20Utils.produceUnauthorizedErrorView();
        }
        // 返回授权界面,confirm.html
        final ModelAndView mv = this.consentApprovalViewResolver.resolve(context, registeredService);
        if (!mv.isEmpty() && mv.hasView()) {
            return mv;
        }
        // 调用 OAuth20AuthorizationResponseBuilder 接口返回URL
        return redirectToCallbackRedirectUrl(manager, registeredService, context, clientId);
    }
OAuth20RequestValidator.png

  • OAuth20CallbackAuthorizeEndpointController

这个代码很少.主要是回调了一下,

final DefaultCallbackLogic callback = new DefaultCallbackLogic();
        callback.perform(context, oauthConfig, J2ENopHttpActionAdapter.INSTANCE, null, false, false);

  • OAuth20UserProfileControllerController
    @GetMapping(path = OAuth20Constants.BASE_OAUTH20_URL + '/' + OAuth20Constants.PROFILE_URL, produces = MediaType.APPLICATION_JSON_VALUE)
    public ResponseEntity<String> handleRequest(final HttpServletRequest request, final HttpServletResponse response) throws Exception {
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        //从请求参数或者header中获取token值
        final String accessToken = getAccessTokenFromRequest(request);
        if (StringUtils.isBlank(accessToken)) {
            LOGGER.error("Missing [{}]", OAuth20Constants.ACCESS_TOKEN);
            return buildUnauthorizedResponseEntity(OAuth20Constants.MISSING_ACCESS_TOKEN);
        }
        //转换成内部的token,验证其是否过期
        final AccessToken accessTokenTicket = this.ticketRegistry.getTicket(accessToken, AccessToken.class);
        if (accessTokenTicket == null || accessTokenTicket.isExpired()) {
            LOGGER.error("Expired/Missing access token: [{}]", accessToken);
            return buildUnauthorizedResponseEntity(OAuth20Constants.EXPIRED_ACCESS_TOKEN);
        }
        //验证对应的TGT是否有效
        final TicketGrantingTicket ticketGrantingTicket = accessTokenTicket.getGrantingTicket();
        if (ticketGrantingTicket == null || ticketGrantingTicket.isExpired()) {
            LOGGER.error("Ticket granting ticket [{}] parenting access token [{}] has expired or is not found", ticketGrantingTicket, accessTokenTicket);
            this.ticketRegistry.deleteTicket(accessToken);
            return buildUnauthorizedResponseEntity(OAuth20Constants.EXPIRED_ACCESS_TOKEN);
        }
        //更新token的状态,修改最后使用时间等
        updateAccessTokenUsage(accessTokenTicket);
        //根据token得到认证信息,返回map
        final Map<String, Object> map = writeOutProfileResponse(accessTokenTicket);
        //如果是Oauth2的就添加另外两个值,client_id和service
        finalizeProfileResponse(accessTokenTicket, map);
        //处理返回的用户信息,JSON格式,OAuth20DefaultUserProfileViewRenderer 处理了FLAT 和 NESTED 模式
        final String value = this.userProfileViewRenderer.render(map, accessTokenTicket);
        return new ResponseEntity<>(value, HttpStatus.OK);
    }

  • cas-server-support-oauth-services

    只有一个类 OAuthRegisteredService,用于处理配置不同的客户端services.json文件,管理其密钥和ID等相关信息,继承自RegexRegisteredService

  • cas-server-support-oauth-webflow

    3个类:
  1. CasOAuthWebflowConfiguration 配置两个bean oauth20LogoutWebflowConfigureroauth20RegisteredServiceUIAction
  2. OAuth20RegisteredServiceUIAction .
  3. OAuth20WebflowConfigureroauth20RegisteredServiceUIAction加入到登录流程当中

tips :
CasWebflowConstants.STATE_ID_VIEW_LOGIN_FORM ->String STATE_ID_VIEW_LOGIN_FORM = "viewLoginForm"; 流程stateviewLoginForm 对应界面casLoginView.html:<view-state id="viewLoginForm" view="casLoginView"


  • 杂项:

利用org.apache.commons.lang3.builder 包中的方法覆写下面OAuthRegisteredService的这三个方法
@Entity 都要覆写这三个方法.

  • 覆写 toString方法
    @Override
    public String toString() {
        final ToStringBuilder builder = new ToStringBuilder(this);
        builder.appendSuper(super.toString());
        builder.append("clientId", getClientId());
        builder.append("approvalPrompt", isBypassApprovalPrompt());
        builder.append("generateRefreshToken", isGenerateRefreshToken());
        builder.append("jsonFormat", isJsonFormat());
        builder.append("supportedResponseTypes", getSupportedResponseTypes());
        builder.append("supportedGrantTypes", getSupportedGrantTypes());

        return builder.toString();
    }
    /**
     * Build a normalized "toString" text for an object.
     *CommonHelper.toString(this.getClass(), "size", size, "timeout", timeout, "timeUnit", timeUnit)
     * @param clazz class
     * @param args  arguments
     * @return a normalized "toString" text
     */
    public static String toString(final Class<?> clazz, final Object... args) {
        final StringBuilder sb = new StringBuilder();
        sb.append("#");
        sb.append(clazz.getSimpleName());
        sb.append("# |");
        boolean b = true;
        for (final Object arg : args) {
            if (b) {
                sb.append(" ");
                sb.append(arg);
                sb.append(":");
            } else {
                sb.append(" ");
                sb.append(arg);
                sb.append(" |");
            }
            b = !b;
        }
        return sb.toString();
    }

  • 覆写 equals方法

    @Override
    public boolean equals(final Object obj) {
        if (obj == null) {
            return false;
        }
        if (obj == this) {
            return true;
        }
        if (obj.getClass() != getClass()) {
            return false;
        }
        final OAuthRegisteredService rhs = (OAuthRegisteredService) obj;
        final EqualsBuilder builder = new EqualsBuilder()
                .appendSuper(super.equals(obj))
                .append(this.clientSecret, rhs.clientSecret)
                .append(this.clientId, rhs.clientId)
                .append(this.bypassApprovalPrompt, rhs.bypassApprovalPrompt)
                .append(this.generateRefreshToken, rhs.generateRefreshToken)
                .append(this.jsonFormat, rhs.jsonFormat)
                .append(this.supportedResponseTypes, rhs.supportedResponseTypes)
                .append(this.supportedGrantTypes, rhs.supportedGrantTypes);
        
        return builder.isEquals();
    }
  • 覆写 hashCode 方法
    @Override
    public int hashCode() {
        return new HashCodeBuilder()
                .appendSuper(super.hashCode())
                .append(this.clientSecret)
                .append(this.clientId)
                .append(this.bypassApprovalPrompt)
                .append(this.generateRefreshToken)
                .append(this.jsonFormat)
                .append(this.supportedResponseTypes)
                .append(this.supportedGrantTypes)
                .toHashCode();
    }
    @Override
    public int hashCode() {
        return new HashCodeBuilder(13, 133).append(this.getId()).toHashCode();
    }
  • 覆写 compareTo方法
    @Override
    public int compareTo(final RegisteredService other) {
        return new CompareToBuilder()
                .append(getEvaluationOrder(), other.getEvaluationOrder())
                .append(StringUtils.defaultIfBlank(getName(), StringUtils.EMPTY).toLowerCase(),
                        StringUtils.defaultIfBlank(other.getName(), StringUtils.EMPTY).toLowerCase())
                .append(getServiceId(), other.getServiceId())
                .append(getId(), other.getId())
                .toComparison();
    }

推荐阅读更多精彩内容