使用自定义注解实现接口访问频率限制

1. 自定义注解

/**
 * Declares a call-rate limit on a controller method.
 *
 * <p>The annotation only carries configuration: {@link #count()} calls are permitted per window
 * of {@link #time()} milliseconds. Enforcement is done elsewhere (by an aspect that reads this
 * annotation).
 *
 * <p>NOTE(review): {@code @Order} on an annotation type does not influence the order in which
 * Spring applies the advice handling this annotation; ordering an aspect normally requires
 * {@code @Order}/{@code Ordered} on the aspect class itself — confirm the intent here.
 *
 * Created by hzq on 2017/5/24.
 */
@Documented
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Order(Ordered.HIGHEST_PRECEDENCE)
public @interface RequestLimit {

    /**
     * Maximum number of calls allowed within one window.
     *
     * @return the per-window call budget; defaults to {@link Integer#MAX_VALUE}, i.e. effectively
     *         unlimited
     */
    int count() default Integer.MAX_VALUE;

    /**
     * Length of the window in which {@link #count()} calls are allowed, in milliseconds.
     *
     * @return the window length; defaults to 60000 (one minute)
     */
    int time() default 60000;
}

2. 切面（Aspect）通知实现类

/**
 * Aspect that enforces {@link RequestLimit} on {@code @Controller} methods.
 *
 * <p>For each (request URI, client IP) pair it keeps a fixed-window counter in Redis under the key
 * {@code "req_limit_" + uri + "_" + ip}. The window starts with the first call and lasts
 * {@link RequestLimit#time()} milliseconds; at most {@link RequestLimit#count()} calls are let
 * through per window, later calls are short-circuited until the key expires.
 *
 * <p>NOTE(review): {@code within(@Controller *)} matches classes annotated with {@code @Controller}
 * directly; AspectJ type patterns do not follow meta-annotations, so {@code @RestController}
 * classes are presumably not covered — confirm whether they need limiting.
 *
 * Created by hzq on 2017/5/24.
 */
@Aspect
@Component
public class RequestLimitContract {

    private static final Logger logger = LoggerFactory.getLogger(RequestLimitContract.class);

    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    /**
     * Counts this call against the caller's current window and either proceeds or short-circuits.
     *
     * <p>NOTE(review): the rejection result is a {@code Map}; this only works when the advised
     * method's return type can accept a {@code Map} (e.g. a {@code @ResponseBody Map} handler).
     *
     * @param joinPoint the intercepted controller invocation; one of its arguments must be the
     *                  {@link HttpServletRequest}
     * @param limit     the limit declared on the intercepted method
     * @return the controller's own result, or a map with {@code success=false} and a
     *         {@code message} when the caller has used up its budget for the current window
     * @throws RuntimeException if no request argument is present, Redis fails, or the controller
     *                          throws; unchecked exceptions are rethrown unchanged, checked ones
     *                          are wrapped with their cause preserved
     */
    @Around("within(@org.springframework.stereotype.Controller *) && @annotation(limit)")
    public Object requestLimit(final ProceedingJoinPoint joinPoint, RequestLimit limit){
        try {
            HttpServletRequest request = findRequest(joinPoint.getArgs());
            if (request == null) {
                throw new IllegalStateException(
                        "no HttpServletRequest argument found on " + joinPoint.getSignature());
            }

            String ip = NginxUtils.getRealIp(request);
            String url = request.getRequestURI();
            String key = "req_limit_" + url + "_" + ip;

            long count = incrementWindowCounter(key, limit.time());
            if (count == 1) {
                logger.info("ip:{}, first request within {}ms", ip, limit.time());
            }

            // count already includes this call, so '>' lets exactly limit.count() calls through.
            if (count > limit.count()) {
                logger.info("url:\t{},  ip:\t{} limit request", url, ip);
                Map<String, Object> map = Maps.newHashMap();
                map.put("success", false);
                map.put("message", "the ip: " + ip + " is up to the limit " + limit.count() + " within " + limit.time()/1000 + "s");
                return map;
            }

            return joinPoint.proceed();
        } catch (Throwable t) {
            logger.info(t.getMessage(), t);
            // Rethrow unchecked exceptions as-is so controller exception types (and any
            // @ExceptionHandler mapped to them) survive; wrap the rest without losing the cause.
            if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            }
            throw new RuntimeException(t.getMessage(), t);
        }
    }

    /** Returns the first {@link HttpServletRequest} among {@code args}, or {@code null} if none. */
    private static HttpServletRequest findRequest(Object[] args) {
        for (Object arg : args) {
            if (arg instanceof HttpServletRequest) {
                return (HttpServletRequest) arg;
            }
        }
        return null;
    }

    /**
     * Atomically increments the counter stored at {@code key}, starting a window of
     * {@code windowMillis} ms when this call creates the counter.
     *
     * @return the counter value after this call (1 for the first call of a window)
     */
    private long incrementWindowCounter(String key, int windowMillis) {
        BoundValueOperations<String, String> ops = stringRedisTemplate.boundValueOps(key);
        // INCR is atomic: concurrent requests each get a distinct value, unlike a GET-then-SET
        // sequence where several callers can read the same stale count.
        Long current = ops.increment(1);
        long count = current == null ? 0L : current;
        if (count == 1) {
            // First call: open the window. INCR keeps an existing TTL, so subsequent calls in the
            // same window do not extend it.
            ops.expire(windowMillis, TimeUnit.MILLISECONDS);
        } else {
            // Self-heal a counter left without a TTL (e.g. process died between INCR and EXPIRE);
            // otherwise the caller would stay blocked forever. -1 means "exists, no TTL".
            Long ttl = stringRedisTemplate.getExpire(key, TimeUnit.MILLISECONDS);
            if (ttl != null && ttl.longValue() == -1L) {
                ops.expire(windowMillis, TimeUnit.MILLISECONDS);
            }
        }
        return count;
    }
}

3. 在 Controller 方法上使用注解

    /**
     * Example endpoint guarded by {@code @RequestLimit(count = 10, time = 1000)} (time in ms).
     *
     * <p>{@code request} is not read in the body, but it must stay in the signature: the
     * rate-limit aspect scans the handler's arguments for an {@link HttpServletRequest} to
     * identify the caller.
     *
     * @param request the current HTTP request (consumed by the rate-limit aspect)
     * @return {@code success=true} plus {@code info} on success; on failure {@code success=false}
     *         plus the exception {@code message}
     */
    @RequestMapping(value = "/api/info", method = RequestMethod.POST)
    @ResponseBody
    @RequestLimit(count = 10, time = 1000)
    public Map<String, Object> getApiInfo(HttpServletRequest request){
        Map<String, Object> result = Maps.newHashMap();
        try {
            // call the service layer here
            result.put("success", true);
            result.put("info", uidAndRoleList);
            return result;
        } catch (Throwable failure) {
            logger.info(failure.getMessage(), failure);
            result.put("success", false);
            result.put("message", failure.getMessage());
            return result;
        }
    }

4. 另一种方式：使用拦截器（HandlerInterceptor）实现

/**
 * Spring MVC interceptor that applies one global fixed-window rate limit to every request it
 * intercepts, keyed by request URI and client address.
 *
 * <p>Configured by {@code request.limit.count} (calls allowed per window) and
 * {@code request.limit.time} (window length in milliseconds). The window starts with the first
 * call and is not extended by later calls.
 *
 * <p>NOTE(review): behind a reverse proxy {@code getRemoteAddr()} returns the proxy's address, so
 * all clients would share one counter; the aspect-based variant uses {@code NginxUtils.getRealIp}
 * — confirm which is intended here.
 */
@Component
public class RequestLimitInterceptor implements HandlerInterceptor{

    private static final Logger logger = LoggerFactory.getLogger(RequestLimitInterceptor.class);

    @Autowired
    private StringRedisTemplate stringRedisTemplate;

    @Value("${request.limit.count}")
    private String limitCount;

    @Value("${request.limit.time}")
    private String limitTime;

    /**
     * Counts the request against its (URI, address) window and rejects it once the window's
     * budget is exhausted.
     *
     * <p>Consider answering with HTTP 429 and returning {@code false} instead of throwing; the
     * exception is kept so existing error handling keeps working.
     *
     * @return {@code true} when the request is within the limit
     * @throws RuntimeException when the limit is exceeded, when the configured limits are not
     *                          integers, or when Redis fails (the original exception is rethrown)
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (request == null) {
            throw new IllegalArgumentException("request is null");
        }

        String ip = request.getRemoteAddr();
        String url = request.getRequestURI();
        String key = "req_limit_" + url + "_" + ip;

        int maxCount;
        int windowMillis;
        long count;
        try {
            // Parse once per request; trim() tolerates trailing whitespace from .properties files.
            maxCount = Integer.parseInt(limitCount.trim());
            windowMillis = Integer.parseInt(limitTime.trim());
            count = incrementWindowCounter(key, windowMillis);
        } catch (RuntimeException e) {
            // Log with stack trace, then rethrow unchanged so type and cause are preserved.
            logger.info(e.getMessage(), e);
            throw e;
        }

        // Thrown outside the try above so it is neither logged twice nor re-wrapped.
        if (count > maxCount) {
            logger.info("url:\t{},  ip:\t{} limit request", url, ip);
            throw new RuntimeException("the ip:\t" + ip + " is up to the limit " + maxCount + " within " + windowMillis / 1000 + " s");
        }

        return true;
    }

    /**
     * Atomically increments the counter at {@code key}, starting a window of {@code windowMillis}
     * ms when this call creates it.
     *
     * @return the counter value after this call (1 for the first call of a window)
     */
    private long incrementWindowCounter(String key, int windowMillis) {
        Long current = stringRedisTemplate.opsForValue().increment(key, 1);
        long count = current == null ? 0L : current;
        if (count == 1) {
            stringRedisTemplate.expire(key, windowMillis, TimeUnit.MILLISECONDS);
        } else {
            // Self-heal a counter left without a TTL (e.g. process died between INCR and EXPIRE);
            // otherwise the client would stay blocked forever. -1 means "exists, no TTL".
            Long ttl = stringRedisTemplate.getExpire(key, TimeUnit.MILLISECONDS);
            if (ttl != null && ttl.longValue() == -1L) {
                stringRedisTemplate.expire(key, windowMillis, TimeUnit.MILLISECONDS);
            }
        }
        return count;
    }

    /** No post-processing needed. */
    @Override
    public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

    }

    /** No cleanup needed. */
    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {

    }
}

推荐阅读更多精彩内容

  • Spring Cloud为开发人员提供了快速构建分布式系统中一些常见模式的工具(例如配置管理,服务发现,断路器,智...
    卡卡罗2017阅读 87,421评论 13 122
  • Spring Boot微服务 http://docs.spring.io/spring-boot/docs/cur...
    鲁云飞_阅读 1,704评论 0 4
  • 什么是AOP? AOP(Aspect Orient Programming),也就是面向切面编程,作为面向对象编程...
    原来蜗牛不是牛阅读 74评论 0 2
  • 概述 Spring4是一套JAVA的MVC框架,经过一系列的自动化改良,如今变得非常简单易用。Spring4框架的...
    胖头鱼战士阅读 802评论 0 7
  • 最近参加了很多面试,其实从我大一开始就尝试过很多面试了,但是最后都接收到同样的短信:很抱歉,同学,您并没有通过面试...
    赤木麟子阅读 82评论 1 0