分布式锁可以这么简单?

本文只讨论如何基于已实现了分布式锁的第三方框架进行二次封装,减少分布式锁的使用成本,而且当需要替换分布式锁实现时,只需要少量代码的调整,比如只需修改配置文件即可完成改造。
另外,本文是对另一篇文章中的实现的优化版,但主要思想是一致的。见 使用Redisson实现分布式锁,Spring AOP简化之

源码

在开始之前,我们先回想一个比较经典的场景——超卖,而解决超卖的一个方案就是 加锁,真正扣减库存之前,必须拿到对应的锁,下面来看一段示例代码,其中锁的实现是借助了 Redisson

@Transactional(rollbackFor = Throwable.class)
public void seckill(Long itemId, int purchaseCount) {
    RLock lock = redissonClient.getLock("item:" + itemId);
    boolean locked = false;

    try {
        locked = lock.tryLock(5000, TimeUnit.MILLISECONDS);
        if (locked) {
            doSeckill(itemId, purchaseCount);
        }

        
    } catch (InterruptedException e) {
        throw new RuntimeException("无法在指定时间内获得锁");
    } finally {
        // 是否未获得锁
        if (!locked) {
            throw new RuntimeException("尝试获取锁超时, 锁获取失败");
        }

        if (lock.isHeldByCurrentThread()) {
            lock.unlock();
        } else {
            throw new RuntimeException("锁释放失败, 当前线程不是锁的持有者");
        }
    }
}

public void doSeckill(Long itemId, int purchaseCount) {
    // 获取库存
    // 比较并扣减库存
    // 更新库存
    // 异步执行其他逻辑
}

上面的代码,看着也算还好,不太复杂,但其实除了真正的业务逻辑 doSeckill 外,其他的都是一些可模板化的代码结构。想象一下如果每一处使用分布式锁的地方都要写这堆东西,那还得了,更可怕的是,如果要换一种分布式锁的实现方式呢?

既然我们能预想到这种写法有这么严重的弊端,那就得想办法优化。最好能够实现:只需要关注真正的业务逻辑,需要使用分布式锁时,只需增加少量代码即可,比如加个注解;另外,如果需要更换分布式锁的实现方式,也不需要改任何代码。怎么实现呢?这就是这篇文章要解决的问题了。通过优化后,上面的示例代码将可以变得如下这么简单:

@DistributedLock(lockName = "#{#itemId}", lockNamePre = "item")
public void doSeckill(Long itemId, int purchaseCount) {
    // 获取库存
    // 比较并扣减库存
    // 更新库存
    // 异步执行其他逻辑
}

下面,正式开始~~~

Redisson概述

可参考另一篇文章的 Redisson概述

使用Redisson实现分布式锁

1. 定义回调接口

/**
 * 分布锁回调接口
 */
public interface DistributedLockCallback<T> {

    /**
     * 调用者必须在此方法中实现需要加分布式锁的业务逻辑
     *
     * @return
     */
    public T process() throws Throwable;

    /**
     * 得到分布式锁名称
     *
     * @return
     */
    public String getLockName();

}

定义分布式具体实现的锁模板接口

/**
 * 分布式锁具体实现的模板接口
 *
 * @author sprainkle
 * @date 2019.04.20
 */
public interface DistributedLockTemplate {

    /** 尝试获取锁的默认等待时间 */
    long DEFAULT_WAIT_TIME = DistributedLock.DEFAULT_WAIT_TIME;
    /** 锁的默认超时时间. 超时后, 锁会被自动释放 */
    long DEFAULT_TIMEOUT = DistributedLock.DEFAULT_WAIT_TIME;
    /** 时间单位。默认为毫秒。 */
    TimeUnit DEFAULT_TIME_UNIT = DistributedLock.DEFAULT_TIME_UNIT;
    /** 获得锁名时拼接前后缀用到的分隔符 */
    String DEFAULT_SEPARATOR = DistributedLock.DEFAULT_SEPARATOR;
    /** lockName后缀 */
    String LOCK = DistributedLock.LOCK;

    /**
     * 使用分布式锁,使用锁默认超时时间。
     *
     * @param callback
     * @param fairLock 是否使用公平锁
     * @return
     */
    <T> T lock(DistributedLockCallback<T> callback, boolean fairLock);

    /**
     * 使用分布式锁。自定义锁的超时时间
     *
     * @param callback
     * @param leaseTime 锁超时时间。超时后自动释放锁。
     * @param timeUnit
     * @param fairLock  是否使用公平锁
     * @return
     */
    <T> T lock(DistributedLockCallback<T> callback, long leaseTime, TimeUnit timeUnit, boolean fairLock);

    /**
     * 尝试分布式锁,使用锁默认等待时间、超时时间。
     *
     * @param callback
     * @param <T>
     * @param fairLock 是否使用公平锁
     * @return
     */
    <T> T tryLock(DistributedLockCallback<T> callback, boolean fairLock);

    /**
     * 尝试分布式锁,自定义等待时间、超时时间。
     *
     * @param callback
     * @param waitTime  获取锁最长等待时间
     * @param leaseTime 锁超时时间。超时后自动释放锁。
     * @param timeUnit
     * @param <T>
     * @param fairLock  是否使用公平锁
     * @return
     */
    <T> T tryLock(DistributedLockCallback<T> callback, long waitTime, long leaseTime, TimeUnit timeUnit, boolean fairLock);

    /**
     * 锁是否由当前线程持有
     *
     * @param lock
     * @return
     */
    boolean isHeldByCurrentThread(Object lock);

}

基于 Redisson 实现

/**
 * Base DistributedLockTemplate. 用于封装一些公共方法
 */
public abstract class AbstractDistributedLockTemplate implements DistributedLockTemplate {
    /**
     * 处理业务逻辑
     *
     * @param callback
     * @param <T>
     * @return 业务逻辑处理结果
     */
    protected <T> T process(DistributedLockCallback<T> callback) {
        try {
            return callback.process();
        } catch (Throwable e) {
            if (e instanceof BaseException) {
                throw (BaseException) e;
            }

            throw new BaseException(CommonResponseEnum.SERVER_ERROR, null, e.getMessage(), e);
        }
    }
}
@Slf4j
public class RedisDistributedLockTemplate extends AbstractDistributedLockTemplate {

    private final RedissonClient redisson;
    private final String namespace;

    private final long lockTimeoutMs;
    private final long waitTimeoutMs;

    // 锁前缀
    private final String lockPrefix;
    // 锁后缀
    private final String lockPostfix;

    public RedisDistributedLockTemplate(RedissonClient redisson, DistributedLockProperties properties) {
        this.redisson = redisson;
        this.namespace = properties.getNamespace();
        this.lockTimeoutMs = Optional.ofNullable(properties.getLockTimeoutMs()).orElse(DEFAULT_TIMEOUT);
        this.waitTimeoutMs = Optional.ofNullable(properties.getWaitTimeoutMs()).orElse(DEFAULT_WAIT_TIME);

        this.lockPrefix = namespace + DEFAULT_SEPARATOR;
        this.lockPostfix = ".lock";
    }

    @Override
    public <T> T lock(DistributedLockCallback<T> callback, boolean fairLock) {
        return lock(callback, lockTimeoutMs, DEFAULT_TIME_UNIT, fairLock);
    }

    @Override
    public <T> T lock(DistributedLockCallback<T> callback, long leaseTime, TimeUnit timeUnit, boolean fairLock) {
        RLock lock = getLock(callback.getLockName(), fairLock);

        try {
            lock.lock(leaseTime, timeUnit);
            return process(callback);
        } finally {
            if (lock != null && lock.isHeldByCurrentThread()) {
                lock.unlock();
            }
        }
    }

    @Override
    public <T> T tryLock(DistributedLockCallback<T> callback, boolean fairLock) {
        return tryLock(callback, waitTimeoutMs, lockTimeoutMs, DEFAULT_TIME_UNIT, fairLock);
    }

    @Override
    public <T> T tryLock(DistributedLockCallback<T> callback,
                         long waitTime,
                         long leaseTime,
                         TimeUnit timeUnit,
                         boolean fairLock) {
        RLock lock = getLock(callback.getLockName(), fairLock);
        boolean locked = false;

        DistributedLockContext context = DistributedLockContextHolder.getContext();
        context.setLock(lock);

        try {
            locked = lock.tryLock(waitTime, leaseTime, timeUnit);
            if (locked) {
                return process(callback);
            }
        } catch (InterruptedException e) {
            ResponseEnum.LOCK_NOT_YET_HOLD.assertFailWithMsg("无法在指定时间内获得锁", e);
        } finally {
            // 是否未获得锁
            if (!locked) {
                ResponseEnum.LOCK_NOT_YET_HOLD.assertFailWithMsg("尝试获取锁超时, 获取失败.");
            }

            if (lock.isHeldByCurrentThread()) {
                lock.unlock();
            } else {
                log.warn("锁释放失败, 当前线程不是锁的持有者");
            }
        }
        return null;
    }

    public RLock getLock(String lockName, boolean fairLock) {
        RLock lock;
        lockName = lockPrefix + lockName + lockPostfix;
        if (fairLock) {
            lock = redisson.getFairLock(lockName);
        } else {
            lock = redisson.getLock(lockName);
        }
        return lock;
    }

    @Override
    public boolean isHeldByCurrentThread(Object lock) {
        if (!(lock instanceof RLock)) {
            return false;
        }

        return ((RLock) lock).isHeldByCurrentThread();
    }
}

简单使用 RedisDistributedLockTemplate

DistributedLockTemplate lockTemplate = ...;
final String lockName = ...; 
lockTemplate.lock(new DistributedLockCallback<Object>() {
    @Override
    public Object process() {
        //do some business
        return null;
    }

    @Override
    public String getLockName() {
        return lockName;
    }
}, false);

会不会还是很麻烦?

虽说比使用原生 Redisson 时简单一点点,但是每次使用分布式锁都要写类似上面的重复代码,还是不够优雅。有没有什么方法可以只关注核心业务逻辑代码的编写,即上面的"do some business"。下面介绍如何使用Spring AOP来实现这一目标。

使用 Spring AOP 进一步封装

定义注解 DistributedLock

/**
 * <pre>
 *     可以使用该注解实现分布式锁。
 *
 *     获取lockName的优先级为:lockName > argNum > param
 *
 *     使用的是公平锁, 即先来先得.
 * </pre>
 */
@Target({ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DistributedLock {
    /**
     * 锁的名称。
     * 如果lockName可以确定,直接设置该属性。
     * <br><br>
     * 支持 SpEL, 格式为: #{expression}, 内置 #root, 属性包括: target, method, args 等, 其中 target 为注解所在类的 Spring Bean
     * 也支持 占位符 ${}
     */
    String lockName() default "";

    /**
     * lockName 前缀
     */
    String lockNamePre() default "";

    /**
     * lockName 后缀
     * @see #LOCK
     */
    String lockNamePost() default "";

    /**
     * 在开始加锁前, 执行某个方法进行校验.
     * <br><br>
     * 支持 SpEL, 格式为: #{expression}, 所以方法必须为 public, 如果方法的所在 Spring Bean 与注解的方法相同,
     * 写法为: #{#root.target.yourMethod(#param1, #param2, ...)}
     *
     * @return
     */
    String checkBefore() default "";

    /**
     * 获得锁名时拼接前后缀用到的分隔符
     * @see #DEFAULT_SEPARATOR
     */
    String separator() default DEFAULT_SEPARATOR;

    /**
     * <pre>
     *     获取注解的方法参数列表的某个参数对象的某个属性值来作为lockName。因为有时候lockName是不固定的。
     *     当param不为空时,可以通过argSeq参数来设置具体是参数列表的第几个参数,不设置则默认取第一个。
     * </pre>
     */
    String param() default "";

    /**
     * 将方法第argSeq个参数作为锁名. 0为无效值.
     */
    int argSeq() default 0;

    /**
     * 是否使用公平锁。
     * 公平锁即先来先得。
     */
    boolean fairLock() default false;

    /**
     * 是否使用尝试锁。
     */
    boolean tryLock() default true;

    /**
     * 最长等待时间。
     * 该字段只有当tryLock()返回true才有效。
     * @see #DEFAULT_WAIT_TIME
     */
    long waitTime() default DEFAULT_WAIT_TIME;

    /**
     * <pre>
     *     锁超时时间。
     *     超时时间过后,锁自动释放。
     *     建议:
     *       尽量缩简需要加锁的逻辑。
     * </pre>
     * @see #DEFAULT_TIMEOUT
     */
    long leaseTime() default DEFAULT_TIMEOUT;

    /**
     * 时间单位。默认为毫秒。
     */
    TimeUnit timeUnit() default TimeUnit.MILLISECONDS;

    /**
     * 尝试获取锁的默认等待时间
     */
    long DEFAULT_WAIT_TIME = 10000L;
    /**
     * 锁的默认超时时间. 超时后, 锁会被自动释放
     */
    long DEFAULT_TIMEOUT = 5000L;
    /**
     * 时间单位。默认为毫秒。
     */
    TimeUnit DEFAULT_TIME_UNIT = TimeUnit.MILLISECONDS;
    /**
     * 获得锁名时拼接前后缀用到的分隔符
     */
    String DEFAULT_SEPARATOR = ":";
    /**
     * lock
     */
    String LOCK = "lock";
}

定义切面织入的代码

/**
 * 分布式锁切面逻辑
 */
@Slf4j
@Aspect
public class DistributedLockAspect implements ApplicationContextAware, BeanFactoryAware, Ordered {

    /**
     * 解析模板
     */
    private static final ParserContext PARSER_CONTEXT = ParserContext.TEMPLATE_EXPRESSION;
    /**
     * SpEL 解析器
     */
    private static final SpelExpressionParser spElParser = new SpelExpressionParser();

    private ApplicationContext applicationContext;

    private BeanFactory beanFactory;
    /**
     * 用于解析 @BeanName 为对应的 Spring Bean
     */
    private BeanResolver beanResolver;

    private final DistributedLockTemplate lockTemplate;

    public DistributedLockAspect(DistributedLockTemplate lockTemplate) {
        this.lockTemplate = lockTemplate;
    }


    @Around(value = "@annotation(distributedLock)")
    public Object doAround(ProceedingJoinPoint pjp, DistributedLock distributedLock) {
        EvaluationContext evaluationCtx = getEvaluationContext(pjp);

        doCheckBefore(distributedLock, evaluationCtx);

        String lockName = getLockName(pjp, evaluationCtx);

        return lock(pjp, lockName);
    }

    /**
     * 执行 {@link DistributedLock#checkBefore()} 指定的方法
     *
     * @param distributedLock
     * @param evaluationCtx
     */
    private void doCheckBefore(DistributedLock distributedLock, EvaluationContext evaluationCtx) {
        String checkBefore = distributedLock.checkBefore();
        resolveExpression(evaluationCtx, checkBefore);
    }

    /**
     * 获取锁名
     *
     * @param jp
     * @return
     */
    private String getLockName(JoinPoint jp, EvaluationContext evaluationCtx) {
        DistributedLock annotation = getAnnotation(jp);
        String lockName = annotation.lockName();

        if (StrUtil.isNotBlank(lockName)) {
            lockName = resolveExpression(evaluationCtx, lockName);
        } else {
            Object[] args = jp.getArgs();
            if (args.length > 0) {
                String param = annotation.param();
                if (StrUtil.isNotBlank(param)) {
                    Object arg;
                    if (annotation.argSeq() > 0) {
                        arg = args[annotation.argSeq() - 1];
                    } else {
                        arg = args[0];
                    }
                    lockName = String.valueOf(getParam(arg, param));
                } else if (annotation.argSeq() > 0) {
                    lockName = String.valueOf(args[annotation.argSeq() - 1]);
                }
            }
        }

        if (StrUtil.isBlank(lockName)) {
            CommonResponseEnum.SERVER_ERROR.assertFailWithMsg("无法生成分布式锁锁名. annotation: {0}", annotation);
        }

        String preLockName = annotation.lockNamePre();
        String postLockName = annotation.lockNamePost();
        String separator = annotation.separator();

        if (StrUtil.isNotBlank(preLockName)) {
            lockName = preLockName + separator + lockName;
        }
        if (StrUtil.isNotBlank(postLockName)) {
            lockName = lockName + separator + postLockName;
        }

        return lockName;
    }

    /**
     * 从方法参数获取数据
     *
     * @param param
     * @param arg 方法的参数数组
     * @return
     */
    private Object getParam(Object arg, String param) {

        if (StrUtil.isNotBlank(param) && arg != null) {
            try {
                return BeanUtil.getFieldValue(arg, param);
            } catch (Exception e) {
                CommonResponseEnum.SERVER_ERROR.assertFailWithMsg("[{0}] 没有属性 [{1}]", arg.getClass(), param);
            }
        }
        return "";
    }

    /**
     * 获取锁并执行
     *
     * @param pjp
     * @param lockName
     * @return
     */
    private Object lock(ProceedingJoinPoint pjp, final String lockName) {
        DistributedLock annotation = PointCutUtils.getAnnotation(pjp, DistributedLock.class);

        boolean fairLock = annotation.fairLock();
        boolean tryLock = annotation.tryLock();

        if (tryLock) {
            return tryLock(pjp, annotation, lockName, fairLock);
        } else {
            return lock(pjp,lockName, fairLock);
        }
    }

    /**
     *
     * @param pjp
     * @param lockName
     * @param fairLock
     * @return
     */
    private Object lock(ProceedingJoinPoint pjp, final String lockName, boolean fairLock) {
        return lockTemplate.lock(new DistributedLockCallback<Object>() {
            @Override
            public Object process() throws Throwable {
                return pjp.proceed();
            }

            @Override
            public String getLockName() {
                return lockName;
            }
        }, fairLock);
    }

    /**
     *
     * @param pjp
     * @param annotation
     * @param lockName
     * @param fairLock
     * @return
     */
    private Object tryLock(ProceedingJoinPoint pjp, DistributedLock annotation, final String lockName, boolean fairLock) {

        long waitTime = annotation.waitTime(), leaseTime = annotation.leaseTime();
        TimeUnit timeUnit = annotation.timeUnit();

        return lockTemplate.tryLock(new DistributedLockCallback<Object>() {
            @Override
            public Object process() throws Throwable {
                return pjp.proceed();
            }

            @Override
            public String getLockName() {
                return lockName;
            }
        }, waitTime, leaseTime, timeUnit, fairLock);
    }

    // 省略若干
}

这里由于篇幅过长,只贴了主要代码,就不贴其他相关类的代码,如果有需要的可以去 Github 自行获取。

如何使用

如果是在本地测试,什么都不用配置,因为使用了 springboot-starter 的规范封装的,只需像其他 springboot-starter 一样,引入对应的依赖即可,类似如下。

<dependency>
    <groupId>com.sprainkle</groupId>
    <artifactId>spring-cloud-advance-common-lock</artifactId>
    <version>${spring-cloud-advance.version}</version>
</dependency>

然后,就可以像如下代码一样,直接在业务代码前加上分布式锁注解即可使用:

@DistributedLock(lockName = "#{#itemId}", lockNamePre = "item")
public void doSeckill(Long itemId, int purchaseCount) {
    // 获取库存
    // 比较并扣减库存
    // 更新库存
    // 异步执行其他逻辑
}

注解 DistributedLock 个参数的使用,可参考各参数的说明。

开始使用

因为这里直接在模块中编写测试用例,所以不用引入依赖。可参考 源码

定义实体类 TestItem

@Data
@TableName("test_item")
public class TestItem {

    @TableId
    private Integer id;
    private String name;
    private Integer stock;

    public TestItem() {
    }

    public TestItem(Integer id, String name) {
        this.id = id;
        this.name = name;
    }
}

服务实现类 TestItemService

@Slf4j
@Service
public class TestItemService extends ServiceImpl<TestItemMapper, TestItem> {
    // 具体逻辑见下文
}

本次测试 ORM 框架使用了 Mybatis Plus,其他类,如 TestItemMapper 等,由于篇幅过长,这里就不展示了。

测试用例

定义 Worker 类
public static class Worker implements Runnable {

        private final CountDownLatch startSignal;
        private final CountDownLatch doneSignal;
        private final Action action;

        public Worker(CountDownLatch startSignal, CountDownLatch doneSignal, Action action) {
            this.startSignal = startSignal;
            this.doneSignal = doneSignal;
            this.action = action;
        }

        @Override
        public void run() {
            try {
                System.out.println(Thread.currentThread().getName() + " start");
                // 阻塞, 直到接收到启动信号. 保证所有线程的起跑线是一样的, 即都是同时启动
                startSignal.await();
                // 具体逻辑
                action.execute();
                // 发送 已完成 信号
                doneSignal.countDown();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

该类是下文所有测试用例会用到的类。

testPlainLockName

首先我们定义 TestItemService 类,其中 initStock 为初始化库存,之后每个测试用例的第一步逻辑都是调用该方法,初始化/重置库存

testPlainLockName 才是该测试用例的主要逻辑,逻辑也很简单,就不废话了,重点看分布式锁注解。 可以看到要开启分布式锁,只需在方法签名加上 @DistributedLock(lockName = "item:1", lockNamePre = "item") 即可,非常方便。

那这么写是什么意思呢?其实就是,有很多线程在扣减商品(id=1)的库存前,需要拿到一把锁,锁名为 distributed-lock-test:item:1.lock,其中 distributed-lock-test:.lock 是框架自己补进去的,剩下的 item:1 则是根据 lockNamelockNamePre 拼接的,拼接符默认为 :

很明显,lockName 写死为 "1" 肯定不合适,这里只是演示需要,具体优化请继续往下,下文会给出。

@Slf4j
@Service
public class TestItemService extends ServiceImpl<TestItemMapper, TestItem> {

    private static final AtomicInteger i = new AtomicInteger(10);

    @Transactional(rollbackFor = Throwable.class)
    public TestItem initStock(Long id, Integer stock) {
        TestItem item = this.getById(id);

        if (item == null) {
            item = new TestItem(1, "牛奶");
        }

        item.setStock(stock);
        this.saveOrUpdate(item);

        return this.getById(id);
    }

    /**
     * 锁名为固定的字符串
     */
    @DistributedLock(lockName = "1", lockNamePre = "item")
    public Integer testPlainLockName(TestItem testItem) {
        TestItem item = this.getById(testItem.getId());
        Integer stock = item.getStock();

        if (stock > 0) {
            stock = stock - 1;
            item.setStock(stock);
            this.saveOrUpdate(item);
        } else {
            stock = -1;
        }

        return stock;
    }
}

接下来,定义测试用例类。

@Slf4j
@RunWith(SpringRunner.class)
@SpringBootTest(classes = DistributedLockTestApplication.class)
public class DistributedLockTests {
    // 10 个线程一起跑
    private static int count = 10;

    @Autowired
    private TestItemService testItemService;

    @Test
    public void testPlainLockName() {
        Consumer<TestItem> consumer = testItem -> {
            Integer stock = testItemService.testPlainLockName(testItem);
            if (stock >= 0) {
                System.out.println(Thread.currentThread().getName() + ": rest stock = " + stock);
            } else {
                System.out.println(Thread.currentThread().getName() + ": sold out.");
            }
        };

        commonTest(consumer);
    }

    private void commonTest(Consumer<TestItem> consumer)  {
        try {
            CountDownLatch startSignal = new CountDownLatch(1);
            CountDownLatch doneSignal = new CountDownLatch(count);

            TestItem item = testItemService.initStock(1L, 8);

            for (int i = 0; i < count; ++i) {
                Action action = () -> consumer.accept(item);
                new Thread(new Worker(startSignal, doneSignal, action)).start();
            }
            
            // let all threads proceed
            startSignal.countDown(); 
            doneSignal.await();
            System.out.println("All processors done. Shutdown connection");
        } catch (Exception e) {
            log.error("", e);
        }
    }
}

其中,每次启动的线程数为 10,然后,commonTest(Consumer<TestItem> consumer) 为这里及之后所有测试用例的公共代码,所以抽象出来了,而 Consumer<TestItem> consumer 才是测试用例间不同的逻辑,并且每次都会把库存初始化为 8

  1. 这里省略了各种配置文件和配置类等。
  2. 下文的所有截图,有一处单词拼写错误,rest stock 写成了 reset stock,还请将就着看。

启动测试用例,可以看到类似如下的控制台打印:


testPlainLockName

理论上应该只有8个线程能正常扣减库存,而结果也与预想的一样。这时如果去掉注解 @DistributedLock(lockName = "1", lockNamePre = "item"),会出现什么结果呢?类似如下:

without DistributedLock annotation
testSpel

很明显,上一个测试用例中,lockName 写死为 "1" 是不可取,而是应该取入参 testItemid 的值,接下来,使用 SpEL 来实现该需求。

public class TestItemService {
    @DistributedLock(
            lockName = "#{#testItem.id}",
            lockNamePre = "item"
    )
    public Integer testSpel(TestItem testItem) {
        TestItem item = this.getById(testItem.getId());
        Integer stock = item.getStock();

        if (stock > 0) {
            stock = stock - 1;
            item.setStock(stock);
            this.saveOrUpdate(item);
        } else {
            stock = -1;
        }

        return stock;
    }
}

可以看到表达式 #{#testItem.id} 可能与大家以前使用的不太一样,这是因为在解析 SpEL 表达式时,使用了 解析模板 #{},即表达式必须使用 #{} 包裹起来。其中对 SpEL 的支持,因为不是本文的重点,可以参考: https://cloud.tencent.com/developer/article/1497676

public class DistributedLockTests {
    @Test
    public void testSpel() {
        Consumer<TestItem> consumer = testItem -> {
            Integer stock = testItemService.testSpel(testItem);
            if (stock >= 0) {
                System.out.println(Thread.currentThread().getName() + ": rest stock = " + stock);
            } else {
                System.out.println(Thread.currentThread().getName() + ": sold out.");
            }
        };

        commonTest(consumer);
    }
}

启动测试用例,可以看到控制台类似输出:


testSpel
testCheckBefore
public class TestItemService
    @DistributedLock(
            lockName = "#{#testItem.id}",
            lockNamePre = "item",
            checkBefore = "#{#root.target.check(#testItem)}"
    )
    public Integer testCheckBefore(TestItem testItem) {
        TestItem item = this.getById(testItem.getId());
        Integer stock = item.getStock();

        if (stock > 0) {
            stock = stock - 1;
            item.setStock(stock);
            this.saveOrUpdate(item);
        } else {
            stock = -1;
        }

        return stock;
    }

    public void check(TestItem testItem) {
        int randomInt = RandomUtil.randomInt(100, 10000);
        if (randomInt % 3 == 0) {
            System.out.println(String.format("current thread: %s, randomInt: %d", getCurrentThreadName(), randomInt));
            CommonResponseEnum.SERVER_BUSY.assertFail();
        }
    }
}

其中,参数 checkBefore 用于在开始加锁前, 执行某个方法进行校验,比如这里,使用方法 check 模拟了流量控制,符合一定条件时,直接抛异常返回。

public class DistributedLockTests {
    @Test
    public void testCheckBefore() {
        Consumer<TestItem> consumer = testItem -> {

            Integer stock = -1;

            try {
                stock = testItemService.testCheckBefore(testItem);
            } catch (Exception e) {
                System.out.println(Thread.currentThread().getName() + ": 系统繁忙");
                return;
            }

            if (stock >= 0) {
                System.out.println(Thread.currentThread().getName() + ": rest stock = " + stock);
            } else {
                System.out.println(Thread.currentThread().getName() + ": sold out.");
            }
        };

        commonTest(consumer);
    }
}

启动测试用例,类似控制台输出如下:


testCheckBefore

这里有一点需要注意,方法 check 必须为 public

testTryLock、testFairLock、testWaitTime

这几个测试用例比较简单,这里就不展示了,有兴趣的可以自己跑一下。

testLeaseTime

首先来理解一下参数 leaseTime 的作用,即:锁超时时间,超时时间过后,锁自动释放

挺好理解的,但重点是,锁被自动释放后,之前执行的逻辑需要怎么处理。
在最后释放锁的时候,发现锁已经不是当前线程持有,有可能已经被其他持有,那之前获得锁后执行的逻辑,都变得不可信了,所以理论上需要撤销,如果是数据库操作,那就是回滚。

首先来看 testLeaseTime 的第一个测试用例

public class TestItemService
    @DistributedLock(
            lockName = "#{#testItem.id}",
            lockNamePre = "item",
            leaseTime = 2000
    )
    public Integer testLeaseTime(TestItem testItem) {
        int ci = TestItemService.i.getAndDecrement();
        if (ci == 10) {
            sleep(5000L);
            log.info("模拟阻塞完成");
        } else {
            sleep(300L);
        }

        TestItem item = this.getById(testItem.getId());
        Integer stock = item.getStock();

        if (stock > 0) {
            stock = stock - 1;
            item.setStock(stock);
            this.saveOrUpdate(item);
        } else {
            stock = -1;
        }

        return stock;
    }

    private void sleep(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}
public class DistributedLockTests {
    @Test
    public void testLeaseTime() {
        Consumer<TestItem> consumer = testItem -> {

            Integer stock = testItemService.testLeaseTime(testItem);

            if (stock >= 0) {
                System.out.println(Thread.currentThread().getName() + ": rest stock = " + stock);
            } else {
                System.out.println(Thread.currentThread().getName() + ": sold out.");
            }
        };

        commonTest(consumer);
    }
}

启动测试用例,结果类似如下:


testLeaseTime

这样乍一看,看下没什么毛病,数据库的库存也是正常的,为 0。但再细看一下代码,模拟阻塞是在一开始就进行的,那如果把该部分代码挪到从数据库获取到数据之后,会发生什么呢?

来看第二个测试用例:

public class TestItemService
    /**
     * 超卖
     *
     * @param testItem
     * @return
     */
    @DistributedLock(
            lockName = "#{#testItem.id}",
            lockNamePre = "item",
            leaseTime = 2000
    )
    public Integer testLeaseTimeOversold(TestItem testItem) {
        TestItem item = this.getById(testItem.getId());

        int ci = TestItemService.i.getAndDecrement();
        if (ci == 10) {
            sleep(5000L);
            log.info("模拟阻塞完成");
        } else {
            sleep(300L);
        }

        Integer stock = item.getStock();

        if (stock > 0) {
            stock = stock - 1;
            item.setStock(stock);
            this.saveOrUpdate(item);
        } else {
            stock = -1;
        }

        return stock;
    }
}
public class DistributedLockTests {
    @Test
    public void testLeaseTimeOversold() {
        Consumer<TestItem> consumer = testItem -> {

            Integer stock = testItemService.testLeaseTimeOversold(testItem);

            if (stock >= 0) {
                System.out.println(Thread.currentThread().getName() + ": rest stock = " + stock);
            } else {
                System.out.println(Thread.currentThread().getName() + ": sold out.");
            }
        };

        commonTest(consumer);
    }
}

启动测试用例,结果类似如下:


testLeaseTimeOversold

可以看到,很明显超卖了,sold out 之后,居然还能继续扣减库存,这时候数据库的库存应该也是 7,为什么呢?
因为它扣减库存时,锁已经被释放了,未持有锁的情况下,扣减的库存大概率都是有问题的,只要并发足够大。

testLeaseTimeWithTransactional
public class TestItemService
    @DistributedLock(
            lockName = "#{#testItem.id}",
            lockNamePre = "item",
            leaseTime = 2000
    )
    @Transactional(rollbackFor = Throwable.class)
    public Integer testLeaseTimeWithTransactional(TestItem testItem) {
        TestItem item = this.getById(testItem.getId());
        Integer stock = item.getStock();

        int ci = TestItemService.i.getAndDecrement();
        if (ci == 10) {
            sleep(5000L);
            log.info("模拟阻塞完成");
        } else {
            sleep(300L);
        }

        if (stock > 0) {
            stock = stock - 1;
            item.setStock(stock);
            this.saveOrUpdate(item);
        } else {
            stock = -1;
        }

        return stock;
    }
}
public class DistributedLockTests {
    @Test
    public void testLeaseTimeWithTransactional() {
        Consumer<TestItem> consumer = testItem -> {

            Integer stock = testItemService.testLeaseTimeWithTransactional(testItem);

            if (stock >= 0) {
                System.out.println(Thread.currentThread().getName() + ": rest stock = " + stock);
            } else {
                System.out.println(Thread.currentThread().getName() + ": sold out.");
            }
        };

        commonTest(consumer);
    }
}

启动测试用例,结果类似如下:


testLeaseTimeWithTransactional

可以看到,虽然只有一个线程打印 "sold out", 但 "Thread-11" 也扣减失败了,所以超卖并没有出现。这时的库存为 0
上图中,输出的东西比其他其他测试用例多了异常的堆栈日志,在哪里抛出来的呢?如下:

2021-08-16 16:23:32.332  INFO 96595 --- [      Thread-11] c.s.s.c.a.c.l.service.TestItemService    : 模拟阻塞完成
2021-08-16 16:23:32.349  WARN 96595 --- [      Thread-11] c.a.c.l.a.i.RedisDistributedLockTemplate : 锁释放失败, 当前线程不是锁的持有者
com.sprainkle.spring.cloud.advance.common.core.exception.BusinessException: 系统繁忙,请稍后重试
    at com.sprainkle.spring.cloud.advance.common.core.exception.assertion.BusinessExceptionAssert.newException(BusinessExceptionAssert.java:35)
    at com.sprainkle.spring.cloud.advance.common.core.exception.assertion.Assert.newExceptionWithMsg(Assert.java:52)
    at com.sprainkle.spring.cloud.advance.common.core.exception.assertion.Assert.assertFailWithMsg(Assert.java:789)
    at com.sprainkle.spring.cloud.advance.common.lock.api.UnlockFailureProcessor.beforeCommit(UnlockFailureProcessor.java:28)
    at org.springframework.transaction.support.TransactionSynchronizationUtils.triggerBeforeCommit(TransactionSynchronizationUtils.java:96)
    // ... 省略若干
    at com.sprainkle.spring.cloud.advance.common.lock.DistributedLockTests.lambda$commonTest$8(DistributedLockTests.java:185)
    at com.sprainkle.spring.cloud.advance.common.lock.DistributedLockTests$Worker.run(DistributedLockTests.java:217)
    at java.lang.Thread.run(Thread.java:748)
Caused by: com.sprainkle.spring.cloud.advance.common.core.exception.WrapMessageException: 释放锁时, 当前线程不是锁的持有者
    at com.sprainkle.spring.cloud.advance.common.core.exception.assertion.Assert.newExceptionWithMsg(Assert.java:51)
    ... 33 more

这里可以抛出一个问题,为什么加上 @Transactional(rollbackFor = Throwable.class) 注解后,就不会出现超卖呢?仅仅是因为 @Transactional 注解,还是还有其他原因?

与该项目相关的第一行日志为:at *.lock.api.UnlockFailureProcessor.beforeCommit(UnlockFailureProcessor.java:28),且最后的 Caused by释放锁时, 当前线程不是锁的持有者,大概可以猜出:数据库的交互在最后的 commit 前,判断了当前线程是否为锁的持有者,如果不是,则抛异常让数据回滚。

这里可以简单给出 UnlockFailureProcessor 的源码:

/**
 * 锁释放失败时的处理器. 如果当前线程不是锁的持有者, 直接抛异常让数据回滚
 */
public class UnlockFailureProcessor implements TransactionSynchronization {

    private final Object lock;

    private final DistributedLockTemplate distributedLockTemplate;

    public UnlockFailureProcessor(DistributedLockTemplate distributedLockTemplate, Object lock) {
        this.distributedLockTemplate = distributedLockTemplate;
        this.lock = lock;
    }

    @Override
    public void beforeCommit(boolean readOnly) {
        boolean heldByCurrentThread = distributedLockTemplate.isHeldByCurrentThread(lock);

        if (!heldByCurrentThread) {
            ResponseEnum.LOCK_NO_MORE_HOLD.assertFailWithMsg("释放锁时, 当前线程不是锁的持有者");
        }
    }

}

因此,可以得出结论,其实真正起作用的并不仅仅是因为加了 @Transactional 注解,还需要有相应的其他支持,@Transactional 注解只是让其拥有管理事务的环境,方便数据回滚。

如果有兴趣,可以将 ResponseEnum.LOCK_NO_MORE_HOLD.assertFailWithMsg("释放锁时, 当前线程不是锁的持有者"); 注释,然后再跑一遍,结果为:

testLeaseTimeWithoutRollback

不出意外的话,这时的库存为 7

基于 ZooKeeper 实现

基于 ZooKeeper 的分布式锁实现,在源码中已给出,请参考实现类 ZooDistributedLockTemplate。使用的时候,只需将配置调整为:

sca-common:
  distributed:
    lock:
      impl: zoo
      zoo:
        # zookeeper服务器地址. 多个时用','分开
        connectString: "127.0.0.1:2181"
        # zookeeper的session过期时间. 即锁的过期时间. 可用于全局配置锁的过期时间
        sessionTimeoutMs: 10000
        # zookeeper的连接超时时间
        connectionTimeoutMs: 15000

当然也是需要引入相关依赖:

        <dependency>
            <groupId>org.apache.curator</groupId>
            <artifactId>curator-recipes</artifactId>
        </dependency>

结语

至此,本文的主要内容已介绍完毕, 由于本文的篇幅过长,贴了太多代码,所以只演示了简单使用,当业务比较复杂的时候,上面的使用方法可能没办法很好的支持;另外,在使用过程中也有需要注意的地方,不然有可能出现分布式锁注解不生效的情况;还有部分关键代码的原理以及背后的原因,都不好在这里一一说明,只能放到另外一篇文章做详细分析。

然后,这篇文章仅作为抛砖引玉,如果有其他更好的方案,欢迎留言,一起讨论学习。
谢谢!!!

推荐阅读更多精彩内容