JAVA多线程-ThreadLocal线程本地存储

一、关于线程本地存储

线程本地存储是一种自动化机制,可以为使用相同变量的每个不同的线程都创建不同的存储,通过根除对变量的共享来防止任务在共享资源时产生冲突。

因此,如果你有5个线程都要使用变量x所表示的对象,那么线程本地存储就会生成5个用于x的不同的存储块,并且使得你可以将状态与线程关联起来。

二、ThreadLocal是什么?

ThreadLocal是线程本地存储的一种实现方案。它并不是一个Thread,我们也可以称之为线程局部变量,在多线程并发访问时,ThreadLocal类为每个使用该变量的线程都创建一个变量值的副本,每一个线程都可以独立地改变自己的副本,而不会和其他线程的副本发生冲突。从线程的角度来看,就感觉像是每个线程都完全拥有该变量一样。

三、什么情况下使用ThreadLocal?

ThreadLocal的使用场合主要用来解决多线程情况下对数据的读取因线程并发而产生数据不一致的问题。ThreadLocal为每个线程中并发访问的数据提供一个本地副本,然后通过对这个本地副本的访问来执行具体的业务逻辑操作,这样就可以大大减少线程并发控制的复杂度;然而这样做也需要付出一定的代价,需要耗费一部分内存资源,但是相比于线程同步所带来的性能消耗还是要好上那么一点点。

四、ThreadLocal的应用场景

最常见的ThreadLocal的使用场景是用来解决数据库连接、Session管理等等。如:

4.1、 数据库连接管理:

同一事务多DAO共享同一Connection,必须在一个共同的外部类中使用threadLocal保存Connection。

public class ConnectionManager {    
    
    private static ThreadLocal<Connection> connectionHolder = new ThreadLocal<Connection>() {    
        @Override    
        protected Connection initialValue() {    
            Connection conn = null;    
            try {    
                conn = DriverManager.getConnection(    
                        "jdbc:mysql://localhost:3306/test", "username",    
                        "password");    
            } catch (SQLException e) {    
                e.printStackTrace();    
            }    
            return conn;    
        }    
    };    
    
    public static Connection getConnection() {    
        return connectionHolder.get();    
    }    
    
    public static void setConnection(Connection conn) {    
        connectionHolder.set(conn);    
    }    
}

通过上面这种方式就保证了一个线程对应一个数据库连接,保证了事务。因为一般事务都是依赖一个个数据库连接来控制的,如commit,rollback等都是需要获取数据库连接来操作的。

4.2、session管理:

private static final ThreadLocal threadSession = new ThreadLocal();

public static Session getSession() throws InfrastructureException {
    Session s = (Session) threadSession.get();
    try {
        if (s == null) {
            s = getSessionFactory().openSession();
            threadSession.set(s);
        }
    } catch (HibernateException ex) {
        throw new InfrastructureException(ex);
    }
    return s;
}

五、如何使用ThreadLocal

直接看代码:

package com.feizi.java.concurrency.tool;

import java.util.Random;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/**
 * Created by feizi on 2018/5/29.
 */
public class ThreadLocalHolder {
    private static ThreadLocal<Integer> holder = new ThreadLocal<Integer>(){
        private Random rand = new Random(10);
        protected synchronized Integer initialValue(){
            return rand.nextInt(100);
        }
    };

    public static void increment(){
        holder.set(holder.get() + 1);
    }

    public static Integer get(){
        return holder.get();
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService threadPool = Executors.newCachedThreadPool();
        for (int i = 0; i < 5; i++){
            threadPool.execute(new Accessor(i));
        }
        threadPool.shutdown();
    }
}

class Accessor implements Runnable{
    private final int id;

    public Accessor(int id) {
        this.id = id;
    }

    @Override
    public void run() {
        while (!Thread.currentThread().isInterrupted()){
            try {
                TimeUnit.SECONDS.sleep(1);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            ThreadLocalHolder.increment();
            System.out.println(this);
            Thread.yield();
        }
    }

    @Override
    public String toString() {
        return "#" + id + " : " + ThreadLocalHolder.get();
    }
}

控制台输出结果:

#线程#3 : 14
#线程#0 : 81
#线程#4 : 94
#线程#2 : 91
#线程#1 : 47
#线程#1 : 48
#线程#4 : 95
#线程#3 : 15
#线程#2 : 92
#线程#0 : 82
#线程#3 : 16
#线程#0 : 83
#线程#4 : 96
#线程#2 : 93
#线程#1 : 49
#线程#2 : 94
#线程#3 : 17

从上面输出结果,我们看到:每个线程的输出的结果都是隔离的,相互并不影响,#线程#3首次输出14,到了下次再输出的时候变成15#线程#0首次输出81,再次输出82,其他类似。

因为每个单独的线程都被分配了自己的存储,因为它们每个都需要跟踪自己的计数值,即便只有一个ThreadLocalHolder对象。

六、ThreadLocal的实现

ThreadLocal中主要提供的方法:

1、public T get(){}

主要用于获取ThreadLocal在当前线程中保存的变量副本

2、public void set(T value) {}

主要用于设置当前线程中变量的副本

3、public void remove() {}

主要用于移除当前线程中变量的副本

4、protected T initialValue() {}

它是一个被protected修饰的方法,主要用于在实例化时进行重载的,是一个延时加载方法

6.1、get()方法的实现:

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

首先获取当前线程t,然后根据当前线程t调用getMap(t)方法获取一个ThreadLocalMap类型的map,之后判断这个map是否为空,如果不为空,则根据this(表示当前ThreadLocal对象)获取一个<key,value>键值对的的Entry,需要注意的是,这里传入的this,而不是当前线程t,如果map为空,则调用setInitialValue()初始化一个value,默认是返回null。

然后,我们跟一下getMap(t)中做了什么操作:

ThreadLocalMap getMap(Thread t) {
    //返回当前线程t中的一个成员变量threadLocals
    return t.threadLocals;
}

从上面可以看到,getMap(t)中返回了当前线程t中的一个成员变量threadLocals,接着再继续往下跟,看一下threadLocals是什么东西:

ThreadLocal.ThreadLocalMap threadLocals = null;

可以看到,threadLocals实际就是一个ThreadLocalMap,这个类是ThreadLocald的一个内部类,然后我们再看一下ThreadLocalMapd的定义:

static class ThreadLocalMap {

    /**
        * The entries in this hash map extend WeakReference, using
        * its main ref field as the key (which is always a
        * ThreadLocal object).  Note that null keys (i.e. entry.get()
        * == null) mean that the key is no longer referenced, so the
        * entry can be expunged from table.  Such entries are referred to
        * as "stale entries" in the code that follows.
        */
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    /**
        * The initial capacity -- MUST be a power of two.
        */
    private static final int INITIAL_CAPACITY = 16;

    /**
        * The table, resized as necessary.
        * table.length MUST always be a power of two.
        */
    private Entry[] table;

    /**
        * The number of entries in the table.
        */
    private int size = 0;

    /**
        * The next size value at which to resize.
        */
    private int threshold; // Default to 0
}

从上面的定义我们大致可以看出,ThreadLocalMap的键值对Entry类继承了WeakReference,并且使用ThreadLocal<?>作为key值进行存储。

6.2、setInitialValue()方法的实现

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}

首先,调用initialValue()进行初始化value值,我们跟一下这个initialValue()方法:

protected T initialValue() {
    return null;
}

从上述initialValue()方法中,我们可以看到直接return返回了一个null,获取当前线程t,根据当前线程t获取ThreadLocalMap类型的map,此时再判断map是否为空,不为空则直接设置<key,value>键值对,注意此处的key仍然还是this(表示当前threadLocal对象),为空则调用createMap(Thread t, T firstValue)方法:

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

从上面代码我们可以看出,直接new了一个ThreadLocalMap对象,以this(当前threadLocal对象)作为key,传入的value设置为值,并且赋给当前线程t的成员变量threadLocals。

6.3、set()方法的实现:

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

首先获取当前线程t,然后根据当前线程t获取ThreadLocalMap类型的map,判断map不为空就设置键值对,为空就调用createMap初始化一个新的map。

6.4、remove()方法的实现:

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

首先获取当前线程t,根据当前线程t获取ThreadLocalMap,判断不为空,就根据this(当前threadLocal对象)移除相对应的value。

通过上面的分析,我们就大致明白了ThreadLocal的基本工作原理:

首先,每个线程Thread内部都拥有一个threadLocals变量(这个是在Thread类中定义的),这个threadLocals是ThreadLocal.ThreadLocalMap类型的,也就是一个Map,这个map是整个threadLocal得以实现的核心,它用于存储变量的副本。key值为this(即当前threadLocal对象),value为变量副本(T类型的变量)。

当我们new一个ThreadLocal对象时,即初始化ThreadLocal),这个Thread类的threadLocals为null,然后进行get()或者set()的时候,都需要对这个Thread类的threadLocals进行初始化操作(步骤都是先获取当前线程t,然后根据t获取ThreadLocalMap,判断如果为空,就初始化new一个),然后以this(当前ThreadLocal变量)为key,以threadLocal需要保存的副本变量为value,存到Thread类的threadLocals中。之后,在线程里面,如果需要使用变量副本,就可以通过get()方法,根据当前线程t去获取threadLocals中对应存储的value副本值。

ok,上面还是有些啰嗦,我们再总结一下:

  1. Thread类中定义了一个ThreadLocal.ThreadLocalMap类型的成员变量threadLocals,用于保存变量的副本
  2. ThreadLocal类中定义了一个ThreadLocalMap的静态内部类
  3. ThreadLocalMap类中定义了一个继承WeakReference类的Entry键值对,并且这个Entry键值对有些特殊,特殊之处就在于它的key必须是ThreadLocal类型的
  4. threadLocals在保存变量副本的时候,以this(当前ThreadLocal变量)为key,以传入需要保存的变量副本为value进行存储
  5. ThreadLocal在get()的时候,会先获取当前线程t,然后根据t去获取ThreadLocalMap,之后对这个ThreadLocalMap进行判空,如果不为空,则根据this(当前ThreadLocal变量)获取ThreadLocalMap类的Entry键值对,再对Entry键值对进行判空,如果不为空就取出变量副本value进行return,如果ThreadLocalMap为空,就调用setInitialValue()方法就行初始化,并且返回一个null的value默认值。
  6. ThreadLocal在set()的时候,也会先获取当前线程t,然后根据t去获取一个ThreadLocalMap,之后对这个ThreadLocalMap进行判空,如果不为空,就以this(当前ThreadLocal变量)为key,需要保存的变量副本为value设置键值对,否则就调用createMap初始化一个ThreadLocalMap。
  7. ThreadLocal在setInitialValue()的时候,同上面的set()过程类似,唯一的区别是setInitialValue()方法会返回一个默认值为null的value(需要对value初始化)
  8. ThreadLocal在remove()的时候,同样也会先获取当前线程t,然后根据t获取一个ThreadLocalMap,之后再对这个ThreadLocalMap进行判空,如果不为空,则根据this(当前ThreadLocal变量)移除对应存储的变量副本value。

其实简单来说,大致就是每个线程都维护了一个map,而这个map的key就是当前threadLocal变量,而值则是我们需要set的那个变量的value,之后每次线程在get取值的时候都是从自己的变量中取值,既然是从自己的变量中取值,那么当然也就不存在线程安全的问题了。ThreadLocal只是充当一个key的角色,然后顺带给每个线程提供一个初始值。

多线程安全性解决方案

  1. 采用synchronized进行同步控制,但是效率略低,使得并发变同步(串行)
  2. 采用ThreadLocal线程本地存储,为每个使用该变量的线程都存储一个本地变量副本(线程互不相干)

两种线程安全方案的区别

  1. synchronized同步机制采用了“以时间换空间”的方式,仅仅只提供一份变量,让参与的多个不同的线程排队进行访问
  2. ThreadLocal采用“以空间换时间”的方式,为参与的每个线程都各自提供一份本地副本,因此可以做到同时访问而互不影响。

综上所述,ThreadLocal通常占用内存较大,但是速度快;而synchronized则占用内存小,速度相对而言比较慢。如果在内存比较充足的情况,对并发部分的执行效率要求很高的话,那么就是ThreadLocal派上用场的时候了,一般情况下还是synchronized用的居多。

原文参考

  1. http://www.cnblogs.com/dolphin0520/p/3920407.html
  2. https://www.cnblogs.com/zhangjk1993/archive/2017/03/29/6641745.html#_label4
  3. https://www.cnblogs.com/xinxin-ting/p/7070826.html
  4. https://blog.csdn.net/sean417/article/details/69948561

推荐阅读更多精彩内容