JAVA concurrency 之ThreadLocal源码详解,80%人不会

在这里插入图片描述

概述

1、在并发编程中,为了控制数据的正确性,我们往往需要使用锁来来保证代码块的执行隔离性。但是在很多时候锁的开销太大了,而在某些情况下,我们的局部变量是线程私有的,每个线程都会有自己的独自的变/量,这个时候我们可以不对这部分数据进行加锁操作。于是ThredLocal应运而生。

2、ThredLocal顾名思义,是线程持有的本地变量,存放在ThredLocal中的变量不会同步到其他线程以及主线程,所有线程对于其他的线程变量都是不可见的。那么我们来看下它是如何实现的吧。
3、注意:光理论是不够的。在此免费赠送5大JAVA架构项目实战教程及大厂面试题库,有兴趣的可以进裙 783802103获取,没基础勿进哦!

实现原理
ThredLocal在内部实现了一个静态类ThreadLocalMap来对于变量进行存储,并且在Thread类的内部使用到了这两个成员变量

 ThreadLocal.ThreadLocalMap threadLocals = null;
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

来调用ThreadLocalMap存储当前线程的内部变量。

ThreadLocalMap的实现
ThreadLocalMap是键值对结构的map,但是他没有直接使用HashMap,而是自己实现了一个。

Entry

Entry是ThreadLocalMap中定义的map节点,他以ThreadLocal弱引用为key,以Object为value的K-V形式的节点。使用弱引用是为了可以及时释放内存避免内存泄漏。

 static class Entry extends WeakReference<ThreadLocal<?>> {
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

这里和HashMap不一样的地方在于两者解决hash冲突的方式的不同,HashMap采用的是链地址法,遇到冲突的时候将冲突的数据放入同一链表之中,等到链表到了一定程度再将链表转化为红黑树。而ThreadLocalMap实现采用的是开放寻址法,它内部没有使用链表结构,因此Entry内部没有next或者是prev指针。ThreadLocalMap的开放寻址法是怎么实现的,请看接下来的源码。

成员变量

// map默认的初始化大小
    private static final int INITIAL_CAPACITY = 16;

    // 存储map节点数据的数组
    private Entry[] table;

    // map大小
    private int size = 0;

    // 临界值,达到这个值的时候需要扩容
    private int threshold;

    // 当临界值达到2/3的时候扩容
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

这里的数组大小始终是2的幂次,原因和HashMap一样,是为了在计算hash偏移的时候减少碰撞。

构造函数

 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        // 初始化table
        table = new Entry[INITIAL_CAPACITY];
        // 计算第一个值的hash值
        int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
        // 创建新的节点
        table[i] = new Entry(firstKey, firstValue);
        size = 1;
        setThreshold(INITIAL_CAPACITY);
    }

set方法

private void set(ThreadLocal<?> key, Object value) {

        // 获取ThreadLocal的hash值偏移量
        Entry[] tab = table;
        int len = tab.length;
        int i = key.threadLocalHashCode & (len-1);

        // 遍历数组直到节点为空
        for (Entry e = tab[i];
                e != null;
                e = tab[i = nextIndex(i, len)]) {
            ThreadLocal<?> k = e.get();

            // 如果节点key相等,即找到了我们想要的节点,
            // 将值赋予节点
            if (k == key) {
                e.value = value;
                return;
            }

            // 如果节点的key为空,说明弱引用已经把key回收了,那么需要做一波清理
            if (k == null) {
                replaceStaleEntry(key, value, i);
                return;
            }
        }

        // 如果没有找到对应的节点说明该key不存在,创建新节点
        tab[i] = new Entry(key, value);
        int sz = ++size;
        // 进行清理,如果清理结果没能清理掉任何的旧节点,
        // 并且数组大小超出了临界值,就进行rehash操作扩容
        if (!cleanSomeSlots(i, sz) && sz >= threshold)
            rehash();
    }

看到这段代码,开放寻址法的实现原理可以说是非常清楚了。首先计算节点的hash值,找到对应的位置,查看该位置是否为空,如果是空则插入,如果不为空,则顺延至下个节点,直到找到空的位置插入。那么我们的查询逻辑也呼之欲出:计算节点的hash值,找到对应的位置,查看该节点是否是我们想要找的节点,如果不是,则继续往下顺序寻找。

get方法

 private Entry getEntry(ThreadLocal<?> key) {
        // 计算hash值
        int i = key.threadLocalHashCode & (table.length - 1);
        // 获取该hash值对应的数组节点
        Entry e = table[i];
        if (e != null && e.get() == key)
            // 如果节点不为空并且key一致,说明是我们找的节点,直接返回
            return e;
        else
            // 否则继续往后寻找
            return getEntryAfterMiss(key, i, e);
    }
    
    private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
        Entry[] tab = table;
        int len = tab.length;

        // 如果节点不为空就一直找下去
        while (e != null) {
            ThreadLocal<?> k = e.get();
            // key相同则说明找到,返回该节点
            if (k == key)
                return e;
            // key为空进行一次清理
            if (k == null)
                expungeStaleEntry(i);
            else
                i = nextIndex(i, len);
            e = tab[i];
        }
        return null;
    }

replaceStaleEntry

 // 这个方法的作用是在set操作的时候进行清理
    private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        Entry e;
        
        // slotToExpunge是之后开始清理的节点位置
        int slotToExpunge = staleSlot;
        // 往前寻找找到第一个为空的节点记录下位置
        for (int i = prevIndex(staleSlot, len);
                (e = tab[i]) != null;
                i = prevIndex(i, len))
            if (e.get() == null)
                slotToExpunge = i;

        // 从staleSlot开始向后遍历直到节点为空
        for (int i = nextIndex(staleSlot, len);
                (e = tab[i]) != null;
                i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();

            if (k == key) {
                // 如果节点的key一致,替换value的值
                e.value = value;

                // 将当前节点和staleSlot上的节点互换位置(将后方的值放到前方来,之前的值等待回收)
                tab[i] = tab[staleSlot];
                tab[staleSlot] = e;

                // 如果slotToExpunge和staleSlot相等,说明前面没有需要清理的节点
                // 则从当前节点开始进行清理
                if (slotToExpunge == staleSlot)
                    slotToExpunge = i;
                // 进行节点清理
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                return;
            }

            // 如果key为空并且slotToExpunge和staleSlot相等
            // 把slotToExpunge赋值为当前节点
            if (k == null && slotToExpunge == staleSlot)
                slotToExpunge = i;
        }

        // 如果没法找到key相等的节点,
        // 则清空当前节点的value并生成新的节点
        tab[staleSlot].value = null;
        tab[staleSlot] = new Entry(key, value);

        // 如果slotToExpunge和staleSlot不相等则需要进行清理(因为前方发现空的节点)
        if (slotToExpunge != staleSlot)
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    }

expungeStaleEntry

// 对节点进行清理
    private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;

        // 释放当节点
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;

        Entry e;
        int i;
        // 循环寻找到第一个空节点
        for (i = nextIndex(staleSlot, len);
                (e = tab[i]) != null;
                i = nextIndex(i, len)) {
            ThreadLocal<?> k = e.get();
            // key为空进行节点释放
            if (k == null) {
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                // 如果key不为空,找到对应的节点应该在的位置
                int h = k.threadLocalHashCode & (len - 1);
                if (h != i) {
                    // 如果和当前节点位置不同,
                    // 则清理节点并且循环找到后面的非空节点移到前面来
                    tab[i] = null;

                    while (tab[h] != null)
                        h = nextIndex(h, len);
                    tab[h] = e;
                }
            }
        }