ThreadLocal 原理

总述

ThreadLocal 在面试中经常提到,关于ThreadLocal使用不当造成OOM以及在特殊场景下,通过ThreadLocal可以轻松实现一些看起来复杂的功能,都说明值得花时间研究其原理。

ThreadLocal 不是 Thread,是一个线程内部的数据存储类,通过它可以在指定的线程中存储数据,对数据存储后,只有在线程中才可以获取到存储的数据,对于其他线程来说是无法获取到数据。可能这才是Local的真正含义吧。

使用场景

对于 ThreadLocal 的使用场景,一般来说,当某些数据是以线程为作用域并且不同线程具有不同的数据副本的时候,就可以考虑采用ThreadLocal。比如对于Handler来说,它需要获取当前线程的Looper,很显然Looper的作用域就是线程并且不同线程具有不同的Looper,这个时候通过ThreadLocal就可以轻松实现Looper在线程中的存取,如果不采用ThreadLocal,那么系统就必须提供一个全局的哈希表供Handler查找指定线程的Looper,这样一来就必须提供一个类似于LooperManager的类了,但是系统并没有这么做而是选择了ThreadLocal,这就是ThreadLocal的好处。

ThreadLocal另一个使用场景是复杂逻辑下的对象传递,比如监听器的传递,有些时候一个线程中的任务过于复杂,这可能表现为函数调用栈比较深以及代码入口的多样性,在这种情况下,我们又需要监听器能够贯穿整个线程的执行过程,这个时候可以怎么做呢?其实就可以采用ThreadLocal,采用ThreadLocal可以让监听器作为线程内的全局对象而存在,在线程内部只要通过get方法就可以获取到监听器。而如果不采用ThreadLocal,那么我们能想到的可能是如下两种方法:第一种方法是将监听器通过参数的形式在函数调用栈中进行传递,第二种方法就是将监听器作为静态变量供线程访问。上述这两种方法都是有局限性的。第一种方法的问题时当函数调用栈很深的时候,通过函数参数来传递监听器对象这几乎是不可接受的,这会让程序的设计看起来很糟糕。第二种方法是可以接受的,但是这种状态是不具有可扩充性的,比如如果同时有两个线程在执行,那么就需要提供两个静态的监听器对象,如果有10个线程在并发执行呢?提供10个静态的监听器对象?这显然是不可思议的,而采用ThreadLocal每个监听器对象都在自己的线程内部存储,根据就不会有方法2的这种问题。

使用 Demo

public class ThreadLocalTest {

    public static void main(String[] args) {
        final ThreadLocal<String> threadLocal1 = new ThreadLocal<>();
        final ThreadLocal<Integer> threadLocal2 = new ThreadLocal<>();

        new Thread(new Runnable() {
            @Override
            public void run() {
                threadLocal1.set("A");
                threadLocal2.set(1);
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName());
                System.out.println(threadLocal1.get());
                System.out.println(threadLocal2.get());
            }
        }).start();


        new Thread(new Runnable() {
            @Override
            public void run() {
                threadLocal1.set("B");
                threadLocal2.set(2);
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
                System.out.println(Thread.currentThread().getName());
                System.out.println(threadLocal1.get());
                System.out.println(threadLocal2.get());
            }
        }).start();


        new Thread(new Runnable() {
            @Override
            public void run() {
                System.out.println(Thread.currentThread().getName());
                System.out.println(threadLocal1.get());
                System.out.println(threadLocal2.get());
            }
        }).start();
    }
}

结果如下:

Thread-2
null
null
Thread-1
B
2
Thread-0
A
1

为了简单理解,这时候只看threadLocal1 就好,从上面日志可以看出,虽然在不同线程中访问的是同一个ThreadLocal对象,但是它们通过ThreadLocal来获取到的值却是不一样的,这就是ThreadLocal的奇妙之处。

ThreadLocal之所以有这么奇妙的效果,是因为不同线程访问同一个ThreadLocal的get方法,ThreadLocal内部会从各自的线程中取出一个数组,然后再从数组中根据当前ThreadLocal的索引去查找出对应的value值,很显然,不同线程中的数组是不同的,这就是为什么通过ThreadLocal可以在不同的线程中维护一套数据的副本并且彼此互不干扰。

可能这样说还是很懵,后面讲原理后会给出他的UML图。

ThreadLocal 的内部实现

JDK和SDK的ThreadLocal其实在构想上是一样的,只不过具体代码实现是有些不同。这里讲解的是 Android API 25 的源码。

说了这么多,都是虚的,看源码啦

public class ThreadLocal<T>

抬头一看,泛型类,仔细的朋友估计在前面的使用的时候估计就已经猜到了。而传进来的泛型T的类型就是ThreadLocal需要保存的数据类型。

ThreadLocal.ThreadLocalMap 内部类

参数

在弄清存取过程之前先解决放在哪里的问题。 ThreadLocalMap 就是用来存储的内部类,现在就先介绍存储的ThreadLocalMap的部分参数和构造方法。

 static class ThreadLocalMap {

       /**
         * The entries in this hash map extend WeakReference, using
         * its main ref field as the key (which is always a
         * ThreadLocal object).  Note that null keys (i.e. entry.get()
         * == null) mean that the key is no longer referenced, so the
         * entry can be expunged from table.  Such entries are referred to
         * as "stale entries" in the code that follows.
         */
        static class Entry extends WeakReference<ThreadLocal> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * The table, resized as necessary.
         * table.length MUST always be a power of two.
         */
        private Entry[] table;

        /**
         * The number of entries in the table.
         */
        private int size = 0;

        /**
         * The next size value at which to resize.
         */
        private int threshold; // Default to 0
   
        // ...

        ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
       }

       /**
         * Set the resize threshold to maintain at worst a 2/3 load factor.
         */
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }
}

由上述代码可知,Entry 是一个包含 key 和 value 的一个对象,Entry的key为ThreadLocal,value为ThreadLocal对应的值,只不过是对这个Entry做了一些特殊处理,即 使用 WeakReference<ThreadLocal>ThreadLocal对象变成一个弱引用的对象,这样做的好处就是在线程销毁的时候,对应的实体就会被回收,不会出现内存泄漏。

其余的都很简单,需要说的是 Entry[] table 就是最后存放数据的地方,而默认的大小呢,就是 16,当大于等于容量的 2/3 的时候重新分配table,具体什么时候分配下面再介绍。

set 方法

既然 ThreadLocal是线程内部的数据存储类,只要弄清楚ThreadLocal的get和set方法就可以明白它的工作原理。

接下来就是重点了,当然是 set(),源代码如下:

 public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

在上述源码中,首先会通过 getMap() 方法来获取当前线程中的 ThreadLocal 数据。获取的方法就是:直接去当前Thread t 中访问。因为在 Thread 类中有一个成员变量 ThreadLocal.ThreadLocalMap threadLocals = null;专门用于存储线程的 ThreadLocal 数据,他们的关系的UML图请看下面。这时候如果 threadLocals 为 null 的时候,就调用 createMap(t, value); 进行初始化,并把数据放进去,这个构造方法就在上面。

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

接下来就分析一下上面中的 map.set(this, value) 的 set 方法,需要说明一下的是这个方法是在内部类ThreadLocalMap里面。

private void set(ThreadLocal key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    // 通过传入的key的hashCode计算出索引的位置
    // 且运算,得到下标,这样子不容易重复
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
        e != null;
        e = tab[i = nextIndex(i, len)]) {
        ThreadLocal k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

我们来讲接一下 set里面的算法,首先是拿到当前 table 的长度,计算出下标,然后从计算出的下标开始循环:

  1. 如果当前指向的 Entry 是存储过的 ThreadLocal,就直接将以前的数据覆盖掉,并结束。
  2. 如果当前这个的 Entry 是一个陈旧Entry(有对象但是k == null),那就调用 replaceStaleEntry(key, value, i);将数据存储进去,并结束。

如果循环的找到一个空的数组,将退出循环,然后将值存在这里,size+1。
值得注意的是最后一段代码,这里执行了一次cleanSomeSlots(int i, int n),这个方法呢,也很简单,就是清除部分的陈旧Entry,如果清除不成功,并且大于等于负载阈值 threshold (当前size的2/3)的时候就会 rehash。至此数据就成功存储进去了。

set() 方法讲完了,简单理一下ThreadThreadLocalThreadLocalMap之间的关系:

Thread,ThreadLocal,ThreadLocalMap 关系

这时候是不是返回去看那个例子就能看懂了呢


使用Memo

每一个 Thread 中都保存着自己的一个 ThreadLocalMap,这就是为什么每个 ThreadLocal 保存进去的东西独立而多样,ThreadLocal 就像是定义了一次操作,当前 ThreadLocal 能够对指定的线程进行存取一份数据。

面试题:ThreadLocal 如何保证Local属性?

当需要使用多线程时,有个变量恰巧不需要共享,此时就不必使用synchronized这么麻烦的关键字来锁住,每个线程都相当于在堆内存中开辟一个空间,线程中带有对共享变量的缓冲区,通过缓冲区将堆内存中的共享变量进行读取和操作,ThreadLocal相当于线程内的内存,一个局部变量。每次可以对线程自身的数据读取和操作,并不需要通过缓冲区与 主内存中的变量进行交互。并不会像synchronized那样修改主内存的数据,再将主内存的数据复制到线程内的工作内存。ThreadLocal可以让线程独占资源,存储于线程内部,避免线程堵塞造成CPU吞吐下降。

在每个Thread中包含一个ThreadLocalMap,ThreadLocalMap的key是ThreadLocal的对象,value是独享数据。

rehash() 扩容

这一部分呢,比较简单,就简单讲解一下。

   private void rehash() {
        expungeStaleEntries();

       // Use lower threshold for doubling to avoid hysteresis
        if (size >= threshold - threshold / 4)
             resize();
  }

首先一来就调用了 expungeStaleEntries() 来去除陈旧无用的Entry(key == null),那怎么去除就请接着看。

  private void expungeStaleEntries() {
        Entry[] tab = table;
        int len = tab.length;
        for (int j = 0; j < len; j++) {
            Entry e = tab[j];
            if (e != null && e.get() == null)
                expungeStaleEntry(j);
        }
  }

这也是特别简单的,就去遍历一遍 table 数组,挨个判断每一个是不是陈旧(key==null)的Entry,但是具体怎么去除单个无用的Entry呢?

 private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot 
            // 将tab上staleSlot位置的对象清空
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            for (i = nextIndex(staleSlot, len); // 遍历下一个元素, 即(i+1)%len位置的元素
                 (e = tab[i]) != null;   // 遍历到Entry为空时, 跳出循环并返回索引位置
                 i = nextIndex(i, len)) {
                ThreadLocal k = e.get();
                if (k == null) {  // 当前遍历Entry的key为空, 则将该位置的对象清空
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {    // 当前遍历Entry的key不为空
                    int h = k.threadLocalHashCode & (len - 1);  // 重新计算该Entry的索引位置
                    if (h != i) { // 如果索引位置不为当前索引位置i
                        tab[i] = null;  // 则将i位置对象清空, 替当前Entry寻找正确的位置(当前对象已经保存在e中了)

                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        // 如果h位置不为null,则向h后寻找当前Entry的位置
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            return i;
        }

回收的算法也很简单,staleSlot开始,清除key为null的Entry,并将不为空的元素放到合适的位置,最后遍历到Entry为空的元素时,跳出循环返回当前索引位置。

这里说一下,这里的 tab[h] != null,这种情况就是哈希碰撞,这种处理hash碰撞的方法就是开放地址法中的线性探测再散列,这里不细讲,列两个公式,相信你会懂:
Hi = (H(key)+ di)MOD m i = 1,2,3,4...,k (k<=m-1)
H(key)为哈希函数;m为哈希表表长;di为增量序列
线性探测再散列 di = 1,2,3,4,...,m-1
二次探测再散列 di = 1^2, -1^2, 2^2, -2^2, 3^2, -3^2, ... +-k^2 (k<=m/2)
伪随机探测再散列 di = 伪随机序列数
这里的哈希函数是:key.threadLocalHashCode & (table.length - 1)

至此,去除陈旧无用的 expungeStaleEntries() 就执行完了,接下来就是一个判断,因为当前又清除了一遍,table里面使用了的size已经变化,当 size >= threshold - threshold / 4 即 数组table长度 len * 2 / 3 - len * 2 / 3 / 4 = 1/2 * len,意味着当清除后如果还是超过一半的话,就进行扩容。那如何扩容呢?resize()啊。

  private void resize() {
        Entry[] oldTab = table;
        int oldLen = oldTab.length;
        int newLen = oldLen * 2;
        Entry[] newTab = new Entry[newLen];
        int count = 0;

        for (int j = 0; j < oldLen; ++j) {
            Entry e = oldTab[j];
            if (e != null) {
                ThreadLocal k = e.get();
                if (k == null) {
                    e.value = null; // Help the GC
                } else {
                    int h = k.threadLocalHashCode & (newLen - 1);
                    // 检测碰撞,
                    while (newTab[h] != null)
                        h = nextIndex(h, newLen);
                    newTab[h] = e;
                    count++;
               }
           }
        }

        setThreshold(newLen);
        size = count;
        table = newTab;
  }

这一部分也是很简单,最重要的就是 int newLen = oldLen * 2; 说明扩容是以两倍进行扩容。resize() 其实就是先申请两倍长度的table数组,然后将数据拷贝到合适位置,然后将新的table数组的引用赋值给原来的table。

get 方法

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null)
                return (T)e.value;
        }
        return setInitialValue();
    }

在外面调用 get 就会当前线程存储的数据,首先拿到当前Thread中的保存的 ThreadLocal.ThreadLocalMap threadLocals,判空,本着先易后难的原则,先看 setInitialValue():

private T setInitialValue() {
        T value = initialValue();  // return null
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
}

这里的 value 是恒为null的,在get调用的时候呢,map一定null,就会初始化一个 ThreadLocalMap 给当前Thread,并将为null的value存进去(啥都没存好么)。那这里返回值就为null,意味着当当前Thread没有ThreadLocal时,返回null,符合直觉。

那退一步说,如果当前Thread村过值了呢,那 ThreadLocalMap map 就不会为空,接着调用 ThreadLocalMap 中的 getEntry() 得到想要的Entry。

private Entry getEntry(ThreadLocal key) {
      int i = key.threadLocalHashCode & (table.length - 1);
      Entry e = table[i];
      if (e != null && e.get() == key)
          return e;
      else
          return getEntryAfterMiss(key, i, e);
}

同过哈希函数算出下标,然后比较当前的key(ThreadLocal)是不是要找的那个Thread 的 ThreadLocal,如果不是则调用 getEntryAfterMiss(key, i, e) 从当前节点开始线性查找。

private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
           Entry[] tab = table;
           int len = tab.length;

           while (e != null) {
               ThreadLocal k = e.get();
               if (k == key)
                   return e;
               if (k == null)
                   expungeStaleEntry(i);
               else
                   i = nextIndex(i, len);
               e = tab[i];
           }
           return null;
       }

没什么好说的,从给定位置进行线性探测(循环),如果是就返回,只不过他在比较好的是判断了当前Entry 是否是陈旧无用的,如果是,就调用expungeStaleEntry(i) 去掉(前面有讲到这个方法)。

需要解释一下两个问题

  1. 为什么循环的终止条件为什么是一旦找到一个空对象就停止返回null(表示没找到)呢?

答: 在进行放的时候,如果哈希碰撞了,就会进行线性探测再散列,现在挨着挨着找,如果当时是存放了数据的话,那么就会放到第一个是空的地方,然后第一个为空的地方不为空了,而现在取的时候都出现null的现象了,说明根本没有存过。

  1. expungeStaleEntry(i) 中的重新放置不会放到当前i之前么?从而导致存了,却取不到数据现象。

答:不会,首先能保证的是从哈希函数算出的下标 H(key) 开始到当前的Entry 都是有效的,因为i开始就判断了 k == key 的,其次 expungeStaleEntry(staleSlot) 是从staleSlot开始,清除key为null的Entry,试想如果当前处理位置的下一位就是 目标Thread 的 ThreadLocalMap ,那么它将会被放在当前位置,因为,当前位置一定为空,从H(key)到当前位置一定都有其他Entry占着位置,这时候在 getEntryAfterMiss(ThreadLocal key, int i, Entry e) 中会再一次取当前位置的值,然后判断。

总结:

  1. 每一个线程都有变量 ThreadLocal.ThreadLocalMap threadLocals保存着自己的 ThreadLocalMap。
  2. ThreadLocal 所操作的是当前线程的 ThreadLocalMap 对象中的 table 数组,并把操作的 ThreadLocal 作为键存储。

自定义ThreadLocal

问题:多线程下,如何实现一个ThreadLoacl
笔者提供的简单例子

public class SimpleThreadLocal<T>{
    /**
     * Key为线程对象,Value为传入的值对象
     */
    private Map<Thread, T> valueMap = Collections.synchronizedMap(new HashMap<Thread, T>());

    /**
     * 设值
     * @param value Map键值对的value
     */
    public void set(T value) {
        valueMap.put(Thread.currentThread(), value);
    }

    /**
     * 取值
     * @return
     */
    public T get() {
        Thread currentThread = Thread.currentThread();
        //返回当前线程对应的变量
        T t = valueMap.get(currentThread);
        //如果当前线程在Map中不存在,则将当前线程存储到Map中
        if (t == null && !valueMap.containsKey(currentThread)) {
            t = initialValue();
            valueMap.put(currentThread, t);
        }
        return t;
    }

    public void remove() {
        valueMap.remove(Thread.currentThread());
    }

    public T initialValue() {
        return null;
    }

    public static void main(String[] args) {

        SimpleThreadLocal<List<String>> threadLocal = new SimpleThreadLocal<>();

        new Thread(() -> {
            List<String> params = new ArrayList<>(3);
            params.add("张三");
            params.add("李四");
            params.add("王五");
            threadLocal.set(params);
            System.out.println(Thread.currentThread().getName());
            threadLocal.get().forEach(param -> System.out.println(param));
        }).start();

        new Thread(() -> {
            try {
                Thread.sleep(1000);
                List<String> params = new ArrayList<>(2);
                params.add("Chinese");
                params.add("English");
                threadLocal.set(params);
                System.out.println(Thread.currentThread().getName());
                threadLocal.get().forEach(param -> System.out.println(param));
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        }).start();
    }
}

运行结果:

Thread-0
张三
李四
王五
Thread-1
Chinese
English

自此,ThreadLocal的分析就结束了。

参考资料:
Android 开发艺术探索
thinking in java
Android与Java中的ThreadLocal
对ThreadLocal实现原理的一点思考
Java并发:ThreadLocal详解
Java多线程编程-(8)-多图深入分析ThreadLocal原理
轻松使用线程 不共享有时是最好的 利用 ThreadLocal 提高可伸缩性

推荐阅读更多精彩内容