zoukankan      html  css  js  c++  java
  • JDK源码之ThreadLocal 类分析

    一 概述

    ThreadLocal类提供了线程局部 (thread-local) 变量。这些变量与普通变量不同,每个线程都可以通过其 get 或 set方法来访问自己的独立初始化的变量副本
    ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联,类里面定义了一个map,key为ThreadLocal,value为值,存储每一个线程的变量
    Thread类里面引用了这个内部类的map实例,从而达到线程隔离

    二 源码分析

    属性

            // 获取下一个hashCode值
            private final int threadLocalHashCode = nextHashCode();
    
            // 获取下一个hashCode,ThreadLocal 中使用了斐波那契散列法,来保证哈希表的离散度
            private static int nextHashCode() {
                // 每一次获取值时,加上HASH_INCREMENT为下一次获取的值
                return nextHashCode.getAndAdd(HASH_INCREMENT);
            }
            // 开始为0,每次创建ThreadLocal实例,值都会累加
            private static AtomicInteger nextHashCode = new AtomicInteger();
    
            // 加数值: 0x61c88647 = 2^32 * 黄金分割比(0.618),
            // 斐波那契数列: 当n趋向于无穷大时,前一项与后一项的比值越来越逼近黄金比,即0.618
            private static final int HASH_INCREMENT = 0x61c88647;
            
            //提供给子类初始化值使用
            protected T initialValue() {
                return null;
            }
    

    核心方法

            // 通过Supplier函数初始化变量值,使用子类构造器进行初始化
            public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
                return new ThreadLocal.SuppliedThreadLocal<>(supplier);
            }
    
            // ThreadLocal扩展子类,接收一个Supplier函数进行值初始化
            static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
                private final Supplier<? extends T> supplier;
                SuppliedThreadLocal(Supplier<? extends T> supplier) {
                    this.supplier = Objects.requireNonNull(supplier);
                }
                // 覆写父类方法
                @Override
                protected T initialValue() {
                    return supplier.get();
                }
            }
    
            public ThreadLocal() {}
    
            //  返回当前线程的变量副本值
            public T get() {
                Thread t = Thread.currentThread();
                // 获取线程实例引用的ThreadLocalMap
                ThreadLocal.ThreadLocalMap map = getMap(t);
                if (map != null) {
                    ThreadLocal.ThreadLocalMap.Entry e = map.getEntry(this);
                    if (e != null) {
                        @SuppressWarnings("unchecked")
                        T result = (T)e.value;
                        return result;
                    }
                }
                // 值为空,重新初始化
                return setInitialValue();
            }
    
            ThreadLocal.ThreadLocalMap getMap(Thread t) { return t.threadLocals; }
    
            private T setInitialValue() {
                T value = initialValue();
                Thread t = Thread.currentThread();
                ThreadLocal.ThreadLocalMap map = getMap(t);
                // 赋值
                if (map != null)
                    map.set(this, value);
                else
                    createMap(t, value);
                return value;
            }
    
            void createMap(Thread t, T firstValue) {
                t.threadLocals = new ThreadLocal.ThreadLocalMap(this, firstValue);
            }
    
            // 设置ThreadLocal的值,还是设置的thread实例引用的map
            public void set(T value) {
                Thread t = Thread.currentThread();
                ThreadLocal.ThreadLocalMap map = getMap(t);
                if (map != null)
                    map.set(this, value);
                else
                    createMap(t, value);
            }
    
            //清除值,key是ThreadLocal,所以直接使用this
            public void remove() {
                ThreadLocal.ThreadLocalMap m = getMap(Thread.currentThread());
                if (m != null)
                    // 调用map的remove方法
                    m.remove(this);
            }
    
            // Thread 类里面创建线程时候初始化init方法调用的,主要是复制参数中的table构建一个新map返回
            static ThreadLocal.ThreadLocalMap createInheritedMap(ThreadLocal.ThreadLocalMap parentMap) {
                return new ThreadLocal.ThreadLocalMap(parentMap);
            }
    
            // 只允许子类调用此方法,ThreadLocal调用直接抛异常
            T childValue(T parentValue) { throw new UnsupportedOperationException(); }
    
    

    三 静态内部类ThreadLocalMap

    ThreadLocal中的静态内部类ThreadLocalMap,这个类本质上是一个map,和HashMap之类的实现相似,依然是key-value的形式,其中有一个内部类Entry,其中key可以看做是ThreadLocal实例,
    在ThreadLocal中并没有对于ThreadLocalMap的引用,ThreadLocalMap的引用在Thread类中
    每个线程在向ThreadLocal里塞值的时候,其实都是向自己所持有的ThreadLocalMap里塞入数据;
    读的时候同理,首先从自己线程中取出自己持有的ThreadLocalMap,然后再根据ThreadLocal引用作为key取出value,
    基于以上描述,ThreadLocal实现了变量的线程隔离

    Entry

                /**
                 * Entry继承WeakReference,(弱引用:当一个对象仅仅被weak reference指向, 而没有任何其他strong reference指向的时候, 如果GC运行, 那么这个对象就会被回收)
                 * 并且用ThreadLocal作为key.如果key为null
                 * (entry.get() == null)表示key不再被引用,表示ThreadLocal对象被回收
                 * 因此这时候entry也可以从table从清除。
                 */
                static class Entry extends WeakReference<ThreadLocal<?>> {
                    Object value;
                    Entry(ThreadLocal<?> k, Object v) {
                        super(k);
                        value = v;
                    }
                }
    
                // 初始容量
                private static final int INITIAL_CAPACITY = 16;
    
                // 存放数据的数组
                private Entry[] table;
    
                //  数组里面entrys的个数,可以用于判断table当前使用量是否超过负因子
                private int size = 0;
    
                // 进行扩容的阈值,表使用量大于它的时候进行扩容
                private int threshold; // Default to 0
    
                // 设置阈值为参数的三分之二
                private void setThreshold(int len) { threshold = len * 2 / 3;}
    

    核心方法

                /**
                 * ThreadLocalMap使用线性探测法来解决哈希冲突,线性探测法的地址增量di = 1, 2, ... , m-1,其中,i为探测次数。
                 * 该方法一次探测下一个地址,直到有空的地址后插入,若整个空间都找不到空余的地址,则产生溢出。
                 * 假设当前table长度为16,也就是说如果计算出来key的hash值为14,如果table[14]上已经有值,并且其key与当前key不一致,那么就发生了hash冲突,
                 * 这个时候将14加1得到15,取table[15]进行判断,这个时候如果还是冲突会继续下一个,如果是最后的则重新回到0,取table[0],以此类推,直到可以插入。
                 * 可以把table看成一个环形数组
                 */
                // 获取环形数组的下一个索引
                private static int nextIndex(int i, int len) {
                    return ((i + 1 < len) ? i + 1 : 0);
                }
    
                // 获取环形数组的上一个索引
                private static int prevIndex(int i, int len) {
                    return ((i - 1 >= 0) ? i - 1 : len - 1);
                }
    
                // 构造器初始化
                ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
                    // 初始化table
                    table = new Entry[INITIAL_CAPACITY];
                    // 计算索引, & (INITIAL_CAPACITY - 1),这是取模的一种方式,对于2的幂作为模数取模,用此代替%(2^n),目的是均匀分布在数组中
                    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
                    // 设置值
                    table[i] = new Entry(firstKey, firstValue);
                    size = 1;
                    // 设置阈值
                    setThreshold(INITIAL_CAPACITY);
                }
    
                /**
                 * 构建一个包含所有parentMap中Inheritable ThreadLocals的ThreadLocalMap返回
                 * 该函数只被 createInheritedMap() 调用.即只在Thread类的init方法里面调用(init方法调用了createInheritedMap()),目的是将父线程的变量值复制到当前线程中
                 */
                private ThreadLocalMap(ThreadLocalMap parentMap) {
                    Entry[] parentTable = parentMap.table;
                    int len = parentTable.length;
                    setThreshold(len);
                    table = new Entry[len];
                    // 逐一复制 parentMap 的记录到当前线程中
                    for (int j = 0; j < len; j++) {
                        Entry e = parentTable[j];
                        if (e != null) {
                            //此处获取的都是InheritableThreadLocal类,ThreadLocal的子类,用于父线程变量传递的类
                            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                            if (key != null) {
                                // 如果是ThreadLocal,就会抛异常
                                Object value = key.childValue(e.value);
                                Entry c = new Entry(key, value);
                                int h = key.threadLocalHashCode & (len - 1);
                                while (table[h] != null)
                                    h = nextIndex(h, len);
                                table[h] = c;
                                size++;
                            }
                        }
                    }
                }
    
    
                private Entry getEntry(ThreadLocal<?> key) {
                    // 计算数组下标
                    int i = key.threadLocalHashCode & (table.length - 1);
                    Entry e = table[i];
                    if (e != null && e.get() == key)
                        return e;
                    else
                        // e不为空,但是key不相等时候再找
                        return getEntryAfterMiss(key, i, e);
                }
    
                private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
                    Entry[] tab = table;
                    int len = tab.length;
                    // 线性探测法,循环查找
                    while (e != null) {
                        ThreadLocal<?> k = e.get();
                        if (k == key)
                            // 找到就返回
                            return e;
                        if (k == null)
                            // 如果key为null,则清除掉无效的entry
                            expungeStaleEntry(i);
                        else
                            // 环形向后扫描
                            i = nextIndex(i, len);
                        e = tab[i];
                    }
                    return null;
                }
    
    
                private void set(ThreadLocal<?> key, Object value) {
                    Entry[] tab = table;
                    int len = tab.length;
                    int i = key.threadLocalHashCode & (len-1);
                    /**
                     * 根据获取到的索引进行循环,如果当前索引上的table[i]不为空,在没有return的情况下,
                     * 就使用nextIndex()获取下一个地址,即线性探测法。
                     */
                    for (Entry e = tab[i];
                         e != null;
                         e = tab[i = nextIndex(i, len)]) {
                        ThreadLocal<?> k = e.get();
                        if (k == key) {
                            // key相同则更新value,set成功
                            e.value = value;
                            return;
                        }
                        /**
                         * table[i]上的key为空,说明被回收了
                         * 这个时候说明改table[i]可以重新使用,用新的key-value将其替换,并删除其他无效的entry
                         */
                        if (k == null) {
                            replaceStaleEntry(key, value, i);
                            return;
                        }
                    }
    
                    tab[i] = new Entry(key, value);
                    int sz = ++size;
                    if (!cleanSomeSlots(i, sz) && sz >= threshold)
                        rehash();
                }
    
                /**
                 * Remove the entry for key.
                 */
                private void remove(ThreadLocal<?> key) {
                   Entry[] tab = table;
                    int len = tab.length;
                    int i = key.threadLocalHashCode & (len-1);
                    for (Entry e = tab[i];
                         e != null;
                         e = tab[i = nextIndex(i, len)]) {
                        if (e.get() == key) {
                            e.clear();
                            expungeStaleEntry(i);
                            return;
                        }
                    }
                }
    
                //key-value替换 staleSlot位置的值
                private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                               int staleSlot) {
                    Entry[] tab = table;
                    int len = tab.length;
                    Entry e;
                    /**
                     * 根据传入的无效entry的位置(staleSlot),向前扫描,当前已经是无效的,可能前面也有无效的,找到最开始无效的一个进行替换,维护线性探测法
                     * 一段连续的entry(这里的连续是指一段相邻的entry并且table[i] != null),
                     * 直到找到一个无效entry,或者扫描完也没找到
                     */
                    int slotToExpunge = staleSlot;
                    for (int i = prevIndex(staleSlot, len);
                         (e = tab[i]) != null;
                         i = prevIndex(i, len))
                        if (e.get() == null)
                            slotToExpunge = i;
    
                    /**
                     * 向后扫描一段连续的entry
                     */
                    for (int i = nextIndex(staleSlot, len);
                         (e = tab[i]) != null;
                         i = nextIndex(i, len)) {
                        ThreadLocal<?> k = e.get();
    
                        /**
                         * 如果找到了key,直接替换,也就是与table[staleSlot]进行替换,此时staleSlot位置已经替换了对应key的值
                         */
                        if (k == key) {
                            e.value = value;
                            tab[i] = tab[staleSlot];
                            tab[staleSlot] = e;
    
                            //如果向前查找没有找到无效entry,则更新slotToExpunge为当前值i
                            if (slotToExpunge == staleSlot)
                                slotToExpunge = i;
                            // 此时,staleSlot位置已经设置值了,应该从 slotToExpunge 位置开始往后清除
                            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                            return;
                        }
                        /**
                         * 如果向前查找没有找到无效entry,并且当前向后扫描的entry无效,则更新slotToExpunge为当前值i
                         */
                        if (k == null && slotToExpunge == staleSlot)
                            slotToExpunge = i;
                    }
    
                    /**
                     * 如果没有找到key,也就是说key之前不存在table中
                     * 就直接最开始的无效entry——tab[staleSlot]上直接新增即可
                     */
                    tab[staleSlot].value = null;
                    tab[staleSlot] = new Entry(key, value);
                    /**
                     * slotToExpunge != staleSlot,说明存在其他的无效entry需要进行清理。
                     */
                    if (slotToExpunge != staleSlot)
                        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                }
    
                /**
                 * 连续段清除
                 * 根据传入的staleSlot,清理对应的无效entry——table[staleSlot],
                 * 并且根据当前传入的staleSlot,向后扫描一段连续的entry(这里的连续是指一段相邻的entry并且table[i] != null),
                 * 对可能存在hash冲突的entry进行rehash,并且清理遇到的无效entry.
                 * @param staleSlot key为null,需要无效entry所在的table中的索引
                 * @return 返回下一个为空的solt的索引。
                 */
                private int expungeStaleEntry(int staleSlot) {
                    Entry[] tab = table;
                    int len = tab.length;
    
                    tab[staleSlot].value = null;
                    tab[staleSlot] = null;
                    size--;
    
                   Entry e;
                    int i;
                    for (i = nextIndex(staleSlot, len);
                         (e = tab[i]) != null;
                         i = nextIndex(i, len)) {
                        ThreadLocal<?> k = e.get();
                        // key为null,直接清理
                        if (k == null) {
                            e.value = null;
                            tab[i] = null;
                            size--;
                        } else {
                            /**
                             * 计算出来的索引h,与其现在所在位置的索引——i不一致,置空当前的table[i]
                             * 从h开始向后线性探测到第一个空的slot,把当前的entry挪过去。
                             */
                            int h = k.threadLocalHashCode & (len - 1);
                            if (h != i) {
                                tab[i] = null;
                                while (tab[h] != null)
                                    h = nextIndex(h, len);
                                tab[h] = e;
                            }
                        }
                    }
                    return i;
                }
    
                /**
                 * 启发式的扫描清除,扫描次数由传入的参数n决定
                 * @param i 从i向后开始扫描(不包括i,因为索引为i的Slot肯定为null)
                 * @param n 控制扫描次数,正常情况下为 log2(n) ,
                 * 如果找到了无效entry,会将n重置为table的长度len,进行段清除。
                 * map.set()点用的时候传入的是元素个数,replaceStaleEntry()调用的时候传入的是table的长度len
                 */
                private boolean cleanSomeSlots(int i, int n) {
                    boolean removed = false;
                    Entry[] tab = table;
                    int len = tab.length;
                    do {
                        i = nextIndex(i, len);
                       Entry e = tab[i];
                        if (e != null && e.get() == null) {
                            n = len;
                            removed = true;
                            i = expungeStaleEntry(i);
                        }
                    //无符号的右移动,可以用于控制扫描次数在log2(n)
                    } while ( (n >>>= 1) != 0);
                    return removed;
                }
    
                private void rehash() {
                    expungeStaleEntries();
                    /**
                     * threshold = 2/3 * len
                     * 所以threshold - threshold / 4 = 1en/2
                     * 这里主要是因为上面做了一次全清理所以size减小,需要进行判断。
                     * 判断的时候把阈值调低了。
                     */
                    if (size >= threshold - threshold / 4)
                        resize();
                }
    
                /**
                 * 扩容,扩大为原来的2倍(这样保证了长度为2的冥)
                 */
                private void resize() {
                    Entry[] oldTab = table;
                    int oldLen = oldTab.length;
                    int newLen = oldLen * 2;
                    Entry[] newTab = new Entry[newLen];
                    int count = 0;
    
                    for (int j = 0; j < oldLen; ++j) {
                        Entry e = oldTab[j];
                        if (e != null) {
                            ThreadLocal<?> k = e.get();
                            if (k == null) {
                                e.value = null; // Help the GC
                            } else {
                                int h = k.threadLocalHashCode & (newLen - 1);
                                while (newTab[h] != null)
                                    h = nextIndex(h, newLen);
                                newTab[h] = e;
                                count++;
                            }
                        }
                    }
    
                    setThreshold(newLen);
                    size = count;
                    table = newTab;
                }
    
                // 清空全部Entry
                private void expungeStaleEntries() {
                    Entry[] tab = table;
                    int len = tab.length;
                    for (int j = 0; j < len; j++) {
                        Entry e = tab[j];
                        if (e != null && e.get() == null)
                            expungeStaleEntry(j);
                    }
                }
    
  • 相关阅读:
    反向迭代
    c++知识点
    LeetCode-Count Bits
    LeetCode-Perfect Rectangle
    LeetCode-Perfect Squares
    LeetCode-Lexicographical Numbers
    LeetCode-Find Median from Data Stream
    LeetCode-Maximal Square
    LeetCode-Number of Digit One
    LeetCode-Combination Sum IV
  • 原文地址:https://www.cnblogs.com/houzheng/p/12273441.html
Copyright © 2011-2022 走看看