面试 - handle之看Looper详谈ThreadLocal(3)

面试handler系列:
面试 - handle使用及原理(1)
面试 - handle之详谈Message(2)

回顾Looper知识点

在第一篇中 面试 - handle使用及原理(1) 在讲到Looper部分源码的时候,Looper生成跟返回有以下源码:

    static final ThreadLocal<Looper> sThreadLocal = new ThreadLocal<Looper>();
    private static void prepare(boolean quitAllowed) {
        if (sThreadLocal.get() != null) {
            throw new RuntimeException("Only one Looper may be created per thread");
        }
        sThreadLocal.set(new Looper(quitAllowed));
    }
    public static @Nullable Looper myLooper() {
        return sThreadLocal.get();
    }

可以看到Looper的对象是放到sThreadLocal里面的,sThreadLocal保持只有一个Looper对象。我们知道,Handler需要获取当前线程中的Looper对象,Looper.loop()是作用于当前线程,并且不同线程拥有的Looper对象不同。使用ThreadLocalLooper进行保存,那就实现了在不同的线程中读取到的Looper对象就是相应的那个线程中的。

这里的ThreadLocal是何方神圣,可以保持每个线程中拥有独立的Looper对象?

ThreadLocal的定义

ThreadLocal 是 JDK底层提供的一个解决多线程并发问题的工具类,它为每个线程提供了一个本地的副本变量机制,实现了和其它线程隔离,这种变量只在本线程的生命周期内起作用。当使用ThreadLocal维护变量时,ThreadLocal为每个使用该变量的线程提供独立的变量副本,所以每一个线程都可以独立地改变自己的副本,而不会影响其它线程所对应的副本。

言简意赅: ThreadLocal是线程内部的数据存储类,自己线程存储的数据只有自己才能获取到,其它线程获取不到。

大意图.png

ThreadLocal例子

class AllTestMain{
    companion object{
       var mThread = ThreadLocal<String>()

        @JvmStatic
        fun main(args: Array<String>) {
            mThread.set("main()")

            //线程A
           var threadA = Thread(Runnable {
                mThread.set("线程A")
                println(mThread.get())
            })

            //线程B
            var threadB = Thread(Runnable {
                mThread.set("线程B")
                println(mThread.get())
            })

            println(mThread.get())
            Thread.sleep(500)
            threadA.start()

            Thread.sleep(500)
            threadB.start()

            Thread.sleep(500)
            println(mThread.get())

        }
    }
}

----------------------打印出来的结果---------------------------------------------------------------------------
main()
线程A
线程B
main()

根据上面的结果可以看到,同一个mThread对象,不同线程中读取出来的信息是不同的,读取的信息都是自己线程下设置的信息。

为什么会有这种效果? 这个咱们就得来看看实现的原理了

ThreadLocal原理

分析原理的话,我们从mThread.set()方法开始为切入点吧

set()方法
    public void set(T value) {
        Thread t = Thread.currentThread();
      //关键代码1
        ThreadLocalMap map = getMap(t);
      //关键代码2
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

从上面的源码可以看到
关键代码1:传入线程的当前线程,并返回一个ThreadLocalMap类型的map对象
关键代码2:判断map是否为null 不为null传入ThreadLocalvalue, 为null则创建map并传入线程对象tvalue

先来看一下createMap(t, value)方法

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

可以看到实际上存储到ThreadLocal中的数据是存到ThreadLocalMap中,这个一看就是HashMap的类似数据结构,key传入的是this对象是ThreadLocal类型的,value是泛型,我们自己定义的一个数据类型。

ThreadLocalMap类

接下来看 ThreadLocalMap类的大致一些信息

static class ThreadLocalMap {
        //内部类
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        private static final int INITIAL_CAPACITY = 16;
        private Entry[] table;
    
          ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
          //新建Entry[]数组,初始化16
            table = new Entry[INITIAL_CAPACITY];
          //hash算法计算一下,位于数组的下标 哈希映射
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
          //将 key 跟 value 存储进去table[i]对应的哈希映射值
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            //设置阈值
            setThreshold(INITIAL_CAPACITY);
        } 
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }
}

ThreadLocalMap中可以看到Entry是存储单位,而存储结构实际上是一个Entry[] table数组。
上面的构造函数可以看到类似hashmap先用Key值进行hash算法计算映射,求出位于数组的下标的多少,然后将传入的keyvalue值初始化一个Entry对象,并放到Entry[] table中。然后设置阈值 len*2/3 (len = 16) 如果超过阈值便要重新分配table

Entry的数据类中继承了WeakReference<ThreadLocal<?>>使用弱引用挟持着ThreadLocal防止挟持线程的数据,导致线程生命周期受到影响进而导致内存泄漏

这里简要说明一下常用的几个引用的作用
强应用:这个是我们常用的平时都是new Object(),就算是OOM都不会被虚拟机回收
软引用: new SoftReference<T>() ,当要发生OOM的时候会被回收掉
弱引用:new WeakReference<T>(),当GC的时候就会被回收掉

返回最开始的mThread.set()方法里面的map.set(this, value)

        private void set(ThreadLocal<?> key, Object value) {
            //获取table数组
            Entry[] tab = table;
            //tab数组的长度
            int len = tab.length;
           // 计算一下下标
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                ThreadLocal<?> k = e.get();
              //如果存在旧数据则替换为新数据
                if (k == key) {
                    e.value = value;
                    return;
                }
              //如果k为null则
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
            //正常插入数据
            tab[i] = new Entry(key, value);
            int sz = ++size;
            //超过阈值的时候 重新计算哈希并重组数组
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

上面的简单讲解一下
如果传入的key计算出下标,根据下标有三种情况处理:

  1. tab[i]为null :直接插入新的数据,如果sz >= threshold则重新扩增数组rehash()
    2.tab[i]不为null分下面两种情况
  • key不为null :则更新新的数据e.value = value并返回
  • key为null:则调用replaceStaleEntry(key, value, i)将数据存储进去,并返回

最主要看下面的几个方法cleanSomeSlots

        private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                i = nextIndex(i, len);
                Entry e = tab[i];
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    i = expungeStaleEntry(i);
                }
            } while ( (n >>>= 1) != 0);
            return removed;
        }

        private void rehash() {
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            if (size >= threshold - threshold / 4)
                resize();
        }



cleanSomeSlots主要就是将那些key不为null,但是value为空的调用expungeStaleEntry(i)清除掉,如果清除了就是返回ture否则返回false

expungeStaleEntry(i)是真正删除数据的地方,删除那些key存在,但是value不存在的Entry数据

rehash()就是先调用expungeStaleEntry(),然后再根据if (size >= threshold - threshold / 4)去决定要不要resize()

resize()这个就是现在的关键了

        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            //关键代码1
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;
          // 关键代码2
            for (int j = 0; j < oldLen; ++j) {
                Entry e = oldTab[j];
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }
          //关键代码3
            setThreshold(newLen);
            size = count;
            table = newTab;
        }

简单的解释一下上面的源码:

关键代码1:生成新的数组,新的数组是旧的数据的两倍

关键代码2:轮训旧数据,将e!= nullEntry取出
如果key为null,则清除掉对应的value
如果key不为null,则计算出对应的新的数组的下标 并存储进去

关键代码3:设置新的阈值。

get 方法

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

先看一下没有数据的时候返回什么的吧

    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }
  protected T initialValue() {
        return null;
    }

看见了吗,返回了一个空值,并向ThreadLocalMap中插入(key=this,value = null)

那有数据的时候ThreadLocalMap.Entry e = map.getEntry(this)
看一下源码

        private Entry getEntry(ThreadLocal<?> key) {
            int i = key.threadLocalHashCode & (table.length - 1);
            Entry e = table[i];
            if (e != null && e.get() == key)
                return e;
            else
                return getEntryAfterMiss(key, i, e);
        }

计算出数组下标,然后比较当前的key是不是要找的那个 ThreadLocal,如果不是则调用 getEntryAfterMiss(key, i, e) 从当前节点开始线性查找。

        private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
            Entry[] tab = table;
            int len = tab.length;

            while (e != null) {
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                if (k == null)
                    expungeStaleEntry(i);
                else
                    i = nextIndex(i, len);
                e = tab[i];
            }
            return null;
        }

有便返回没有便返回空,发现k==null的时候再调用了expungeStaleEntry(i)去掉该元素。

总结:
1.每个线程都存在ThreadLocalMap,在线程中通过ThreadLoad<T>获取到对应线程下的ThreadLocalMap然后在里面传入ThreadLoad作为键去进行数据的存储跟读取。

例子.png

自己实现类似功能

//定义基本的功能
class LocalThreadExp<T>() {

    var mHashMap = ConcurrentHashMap<Thread, T>()

    fun set(value: T) {
        mHashMap[Thread.currentThread()] = value
    }

    fun get(): T? {
        val thread = Thread.currentThread()
        var  o = mHashMap[thread]
        return o
    }
}
//这个是测试类
class AllTestMain {
    companion object {
        var local = LocalThreadExp<String>()
        @JvmStatic
        fun main(args: Array<String>) {
            local.set("main()")
            //线程A
            var threadA = Thread(Runnable {
                local.set("线程A")
                println(local.get())
            })
            //线程B
            var threadB = Thread(Runnable {
                local.set("线程B")
                println(local.get())
            })
            println(local.get())
            Thread.sleep(500)
            threadA.start()
            Thread.sleep(500)
            threadB.start()
            Thread.sleep(500)
            println(local.get())
        }
    }
}

------------输出打印的信息--------------------------------------------------------
main()
线程A
线程B
main()

关于ThreadLocal的就到这里,如果有什么疑问或者面试碰到的问题欢迎在评论区留言

推荐阅读更多精彩内容