线程安全-ThreadLocal
前言
共享资源被多个线程同时访问可能出现不安全的事情。线程安全一直是很重要的事情,没处理好线程安全的问题可能导致错误甚至很难复现排查。
常见的解决办法有定义不可变(immutable)变量:例如Java中的String类型、Guava库里的ImmutableCollections。还有就是对访问共享资源的线程加上锁:例如JUC包下的ReentrantLock和JDK自带的synchronized关键字,或者借助Unsafe类来实现的自旋锁。还有一种就是保证所有执行的任务都一个线程里:比如多进程单线程模型部署的verticle(基于EventLoop),ThreadLocal等都是这个原理。
ThreadLocal简介
将每一个线程中的变量存在ThreadLocal里,而这些变量只属于一个线程,因此每次访问变量都是单线程的,于是解决了线程安全的问题。
基本使用
ThreadLocal通常也就是调用他的get,set方法。
public class ThreadLocalTest2 {
private static final ThreadLocal<Integer> tl = new ThreadLocal<>();
public static void main(String[] args) throws InterruptedException {
// 主线程设置值
tl.set(1);
// 主线程获取值
System.out.println("主线程get:" + tl.get());
Runnable task = () -> {
// 子线程获取值
System.out.println("子线程get:" + tl.get());
};
// 开启子线程
new Thread(task).start();
// 等待父子进程都结束
Thread.sleep(10000);
}
}
主线程get:1
子线程get:null
在主线程中往ThreadLocal中set设置了一个值,主线程能得到,子线程却无法得到,说明了ThreadLocal里存放的值只和当前线程绑定。
基本原理
核心分析
看看为啥能这么神奇,往同一个ThreadLocal变量set。结果却是只有set的线程能获取到结果,而别的线程却得不到结果。
当一个线程调用ThreadLocal.set方法的时候
public void set(T value) {
// 获取当前调用set方法的线程t
Thread t = Thread.currentThread();
// 获取当前线程t的map
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
可以看到getMap方法会获取到当前线程t内部的threadLocals这个map变量了。
public class Thread implements Runnable {
ThreadLocal.ThreadLocalMap threadLocals = null;
}
接下来看下这个map是啥
public class ThreadLocal<T> {
static class ThreadLocalMap {
// 这个entry是个弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
// map的key是一个ThreadLocal,value是具体存放在ThreadLocal内部set进去的值
super(k);
value = v;
}
}
// ThreadLocal内的map
private Entry[] table;
}
}
这个map其实就是一个Entry构成的数组,map的key是ThreadLocal,value是具体值的结构。
由此可以看出当一个线程往ThreadLocal里set值的时候,会先找到当前线程的map放入,
所以多个线程不管往ThreadLocal里set多少次,都会对应到所属线程中,保证了线程安全,
本质思是将每个线程会用到的资源都存到当前线程的上下文中去。
其他原理
set()
继续回到set方法,当获取到map不为空时调用map.set(this, value)
private void set(ThreadLocal<?> key, Object value) {
// 获取Entry的数组
Entry[] tab = table;
// 获取数组长度,注意,null也算入统计
int len = tab.length;
/* 计算当前key在map中的位置,threadLocalHashCode是个魔数,减少hash冲突的可能(不写数学原理了),
每次newThreadLocal时候会自增一倍, 不过一般是只会new一次,然后引用一个Map,new ThreadLocal<Map<Object, Object>>()这样。
*/
int i = key.threadLocalHashCode & (len-1);
// 这个循环主要是在数组中找到第一个为null的位置(线性探测法)
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
// nextIndex(i, len)这个方法是个循环遍历数组过程,可以看出Entry[]数组是个循环的(节约空间)
ThreadLocal<?> k = e.get();
// 将WeakReference中引用的ThreadLocal这个key和当前遍历到的比较,如果相同就直接覆盖值
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// 将探测到的坐标i,设置为最新的值
tab[i] = new Entry(key, value);
int sz = ++size;
/*
cleanSomeSlots清理过期位置的元素
这里的意思是如果没有清空元素并且当前size大于等于扩容阈值,就要rehash扩容
只要清空过或者小于阈值都不会rehash扩容
*/
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
这就是set方法的主要流程了:通过线性探测法在循环数组中找到合适的位置插入即可。
接下来看看rehash这个方法
private void rehash() {
// 方法一
expungeStaleEntries();
// Use lower threshold for doubling to avoid hysteresis(用低扩容阈值避免延迟)
// 方法二
if (size >= threshold - threshold / 4)
resize();
}
方法一
/**
* Expunge all stale entries in the table.
*/
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
// 对每个WeakReference引用为null的元素进行清理
expungeStaleEntry(j);
}
}
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// expunge entry at staleSlot
// 1.设置Entry[]数组中当前位置的元素(entry)的value为null
tab[staleSlot].value = null;
// 2.设置Entry[]数组中当前位置为null
tab[staleSlot] = null;
/*
3.这里其实应该还要设置WeakReference引用的ThreadLocal为null(Entry.clear()方法),才能彻底断开引用,避免内存泄露, 但是现在调用expungeStaleEntry()是从e != null && e.get() == null这个判断条件进来的,所以
e.get() == null代表着WeakReference的引用为空。
*/
size--;
// Rehash until we encounter null
Entry e;
int i;
// 遍历连续不为null的元素
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
// 通过2次设置null清理垃圾
e.value = null;
tab[i] = null;
size--;
} else {
/*
因为采用线性探测法解决hash冲突,所以这里的位置i有可能不是通过idx=k.threadLocalHashCode & (len - 1)
直接到的,而是将idx加上一定的偏移量得到的。
所以这里要比较h != i成立说明这个元素之前是冲突的,现在将其放到尽可能不冲突的位置,下次调用 ThreadLocal.get()方法的时候可以减少线性搜索的时间复杂度。
*/
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
方法二
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
// 2倍扩容
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
// 将老的元素一个一个地拷贝到新数组
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
// 这里还是用线性探测法找到新数组中的位置
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
/*
注意:这里设置新的扩容阈值threshold=len * 2 / 3,并且在调用该方法前会判断size >= threshold - threshold / 4
所以实际扩容的阈值是 newLen * (2/3) * (1 - 1/4) = newLen * 0.5
就是超过一半就扩容,印证了作者那句话"Use lower threshold for doubling to avoid hysteresis"
*/
setThreshold(newLen);
size = count;
// 更改table的引用为最新
table = newTab;
}