AbstractQueuedSynchronizer深入分析

完整代码:https://github.com/shawntime/test-enjoy-architect/tree/master/test-concurrency

什么是AQS?

AQS(Abstract Queued Synchronizer):抽象队列同步器,是Java并发包中一系列并发锁(可重入锁、读锁、写锁、读写锁等),同步器(CountDownLatch、CyclicBarrier、Semaphore等)的实现基础框架,这些锁和同步器都继承了AQS

image-20200922164011748
AQS实现原理
  • 在AQS中维护了一个FIFO的双端队列,用一个int类型维护状态,如果当前线程竞争失败后,AQS会将当前线程封装成一个Node节点加入到队列队尾,同时会阻塞当前线程。
image-20200922170702846
  • 当头结点释放同步状态后,且后继节点对应的线程被阻塞,此时头结点线程将会去唤醒后继节点线程。后继节点线程恢复运行并获取同步状态后,会将旧的头结点从队列中移除,并将自己设为头结点。
image
AQS使用方式

AQS本身是一个抽象类,基于模板方法模式,将同步状态的方法和释放的方法过程封装交给子类去实现

AQS的两种功能
  • 独占锁:每次只有一个线程能获取到锁,其他的线程在队列中阻塞等待被唤醒
  • 共享锁:允许多个线程同时获取锁,共享的方式并发访问资源
重要模板方法
accquire() // 独占式获取
acquireInterruptibly() // 带中断的独占式获取
tryAcquireNanos() // 带超时的独占式获取
release() // 独占式释放

acquireShared() // 分享式获取
acquireSharedInterruptibly() // 带中断的分享式获取
tryAcquireSharedNanos() // 带超时的分享式获取
releaseShared() // 分享式释放
需要子类去实现的方法
tryAcquire() // 独占式获取,尝试去获取资源
tryRelease() // 独占式释放,尝试去释放资源
tryAcquireShared() // 共享式获取,尝试获取资源。负数表示失败;0表示成功,但没有剩余可用资源;正数表示成功,且有剩余资源。
tryReleaseShared() // 共享式释放,尝试释放资源,如果释放后允许唤醒后续等待结点返回true,否则返回false。
isHeldExclusively() //  当前同步器是否处于独占模式,只有用到condition才需要去实现它
同步状态state
private volatile int state;

AQS中定义一个整型state,它基本是整个工具的核心,通常整个工具都是在设置和修改状态,很多方法的操作都依赖于当前状态是什么。由于状态是全局共享的,一般会被设置成volatile类型,以保证其修改的可见性

getState() // 获取当前的同步状态
setState() // 设置当前同步状态
compareAndSetState() // 使用CAS设置状态,保证状态设置的原子性
独占锁和共享锁执行流程区分

独占锁:以ReentrantLock为例,state的初始状态值为0,第一个线程A调用tryAcquire()方法成功后独占锁并state+1,此后其他线程调用tryAcquire()时都会失败,转入双端队列中阻塞等待,直到线程A调用tryRelease()释放锁后将state--,其他线程才有机会获取锁,当然,释放锁之前,A线程自己是可以重复获取此锁的(state会累加),这就是可重入的概念。但要注意,获取多少次就要释放多么次,这样才能保证state是能回到零态的。

共享锁:以CountDownLatch为例,主任务拆分N个子任务去完成,state的初始值被设置成N,与子任务线程数相等,主线程(可能有多个等待的线程)调用await()方法时判断state是否为0,如果不为0(代表子任务还没有结束)则进入队列阻塞等待唤醒,子任务每执行完成后都调用一次countDown()方法,对state--,当state==0时去唤醒队列中头节点任务,头节点继续唤醒下一节点,依次唤醒

CAS

CAS :比较和交换
CPU指令级别保证这是一个原子操作
三个运算符: 一个内存地址V,一个期望的值A,一个新值B
基本思路:如果地址V上的值和期望的值A相等,就给地址V赋给新值B,如果不是,不做任何操作。
循环(死循环,自旋)里不断的进行CAS操作

为什么使用双向队列?

如果你的队列是单向的如:Head -> N1 -> N2 -> Tail。出队的时候你要获取N1很简单,Head.next就行了,但入队时要遍历整个链表到N2,然后N2.next = N3;N3.next = Tail。入队的复杂度就是O(n)。相反双向链表出队和入队都是O(1)时间复杂度。空间换时间。

Node

AQS维护了一个双向队列,将所有阻塞的线程封装成Node对象添加到队列中等待唤醒

volatile Node next; // 后序节点next
volatile Node prev; // 前驱节点prev
volatile Thread thread; // 当前线程
volatile int waitStatus; // 当前节点状态

waitStatus有四种状态

  • CANCELLED = 1:表示当前线程在等待的时候已经超时了或者被取消了
  • SIGNAL = -1:当前线程释放了同步状态或者被取消的时候会通知后继节点,使后继节点得以运行
  • CONDITION = -2:节点在等待队列中,等待Condition,当其他线程在Condition上调用signal的时候,该线程会从等待队列转移到同步队列中去,加入到同步状态的获取
  • PROPAGATE = -3:表示下一次共享式同步状态会无条件的传播下去
  • 默认值 = 0:表示初始值
自定义一个独占锁
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * 自定义一个独占锁
 */
public class MyLock implements Lock {

    private static class Sync extends AbstractQueuedSynchronizer {

        /**
         * 获取锁
         */
        @Override
        protected boolean tryAcquire(int arg) {
            if (compareAndSetState(0, 1)) {
                setExclusiveOwnerThread(Thread.currentThread());
                return true;
            }
            return false;
        }

        /**
         * 释放锁
         */
        @Override
        protected boolean tryRelease(int arg) {
            if (Thread.currentThread() != getExclusiveOwnerThread()) {
                throw new IllegalMonitorStateException();
            }
            if (getState() == 0) {
                throw new UnsupportedOperationException();
            }
            setState(0);
            setExclusiveOwnerThread(null);
            return true;
        }

        /**
         * 锁是否独占
         */
        @Override
        protected boolean isHeldExclusively() {
            return getState() == 1;
        }

        final Condition createCondition() {
            return new ConditionObject();
        }
    }

    private final Sync sync = new Sync();

    @Override
    public void lock() {
        sync.acquire(1);
    }

    @Override
    public void lockInterruptibly() throws InterruptedException {
        sync.acquireInterruptibly(1);
    }

    @Override
    public boolean tryLock() {
        return sync.tryAcquire(1);
    }

    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        return sync.tryAcquireNanos(1, unit.toNanos(time));
    }

    @Override
    public void unlock() {
        sync.release(1);
    }

    @Override
    public Condition newCondition() {
        return sync.createCondition();
    }
}
自定义一个限流锁
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractQueuedSynchronizer;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * 实现限流锁,同时最多只有N个线程处理,其他线程等待
 * 获取到锁的条件:N > 0
 * 阻塞条件 N <= 0
 */
public class CurrentLimitingLock implements Lock {

    private Sync sync;

    public CurrentLimitingLock(int limitNum) {
        sync = new Sync(limitNum);
    }

    @Override
    public void lock() {
        sync.acquireShared(1);
    }

    @Override
    public void lockInterruptibly() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }

    @Override
    public boolean tryLock() {
        return sync.tryAcquireShared(1) >= 0;
    }

    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(time));
    }

    @Override
    public void unlock() {
        sync.releaseShared(1);
    }

    @Override
    public Condition newCondition() {
        return sync.newCondition();
    }

    private static final class Sync extends AbstractQueuedSynchronizer {

        public Sync(int maxNum) {
            setState(maxNum);
        }

        /**
         * 尝试去获取锁
         */
        @Override
        protected int tryAcquireShared(int acquireNum) {
            for (;;) {
                int state = getState();
                int newState = state - acquireNum;
                if (newState < 0) {
                    // 不足
                    return -1;
                }
                if (compareAndSetState(state, newState)) {
                    return newState;
                }
            }
        }

        @Override
        protected boolean tryReleaseShared(int releaseNum) {
            for (;;) {
                int state = getState();
                int newState = state + releaseNum;
                if (compareAndSetState(state, newState)) {
                    return true;
                }
            }
        }

        protected Condition newCondition() {
            return new ConditionObject();
        }
    }
}