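This article explains how Thread, ThreadLocal, and ThreadLocalMap relate to each other in multithreaded programming, and then walks through the source code of InheritableThreadLocal. After that, we analyze some incorrect usage examples and explore TransmittableThreadLocal and its context propagation mechanism.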
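ThreadLocal

Each Thread object holds a field named threadLocals of type ThreadLocalMap. ThreadLocalMap is a map structure whose keys are ThreadLocal references (held as weak references) and whose values are the thread-local variables.

The threadLocals field is actually maintained by ThreadLocal. We never touch the map directly. Instead, we read and write thread-local variables through a ThreadLocal instance, which looks up the map of the current thread.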
```java
public class Thread implements Runnable {
    // Per-thread map, created lazily by ThreadLocal
    ThreadLocal.ThreadLocalMap threadLocals = null;

    // Called before the thread actually exits
    private void exit() {
        threadLocals = null;
    }
}
```
```java
public class ThreadLocal<T> {
    // The map lives on the Thread; ThreadLocal only looks it up
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);   // the key is this ThreadLocal instance
        else
            createMap(t, value);    // lazily create the current thread's map
    }

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T) e.value;
                return result;
            }
        }
        return setInitialValue();   // falls back to initialValue()
    }

    public void remove() {
        ThreadLocalMap m = getMap(Thread.currentThread());
        if (m != null)
            m.remove(this);
    }
}
```
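The following demo shows that one ThreadLocal instance gives each thread its own independent slot. It is a minimal sketch using only the JDK. The class and method names are illustrative and do not come from the original code. The worker thread starts with no value, writes its own value, and the main thread's value stays untouched.

```java
import java.util.concurrent.atomic.AtomicReference;

public class ThreadLocalIsolationDemo {
    // One ThreadLocal instance, but each thread stores its value in its own threadLocals map
    private static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    /** Returns {mainValue, workerValueBeforeSet, workerValueAfterSet}. */
    public static String[] run() throws InterruptedException {
        CONTEXT.set("main");
        AtomicReference<String> before = new AtomicReference<>();
        AtomicReference<String> after = new AtomicReference<>();
        Thread worker = new Thread(() -> {
            before.set(CONTEXT.get()); // the worker's map has no entry yet -> null
            CONTEXT.set("worker");     // writes into the worker's own map only
            after.set(CONTEXT.get());
            CONTEXT.remove();          // clean up before the thread ends
        });
        worker.start();
        worker.join();                 // join() also makes the worker's writes visible here
        String mainValue = CONTEXT.get(); // unaffected by the worker's set()
        CONTEXT.remove();
        return new String[] {mainValue, before.get(), after.get()};
    }

    public static void main(String[] args) throws InterruptedException {
        String[] r = run();
        System.out.println(r[0] + " " + r[1] + " " + r[2]); // main null worker
    }
}
```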
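InheritableThreadLocal

InheritableThreadLocal is a subclass of ThreadLocal that adds inheritance for thread-local variables. When a new thread is created, it receives an initial copy of every InheritableThreadLocal value held by its parent (the thread that creates it). The copy is made in the Thread constructor, not when start() is called. Each value passes through childValue(), which by default returns the parent's value unchanged, so parent and child share the same object reference.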
```java
public class Thread implements Runnable {
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

    public Thread() {
        init(null, null, "Thread-" + nextThreadNum(), 0);
    }

    private void init(ThreadGroup g, Runnable target, String name, long stackSize) {
        init(g, target, name, stackSize, null, true);
    }

    private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        // The thread running the constructor (not the one calling start()) is the parent
        Thread parent = currentThread();
        // ... other initialization omitted ...
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            // Copy the parent's map; each value goes through childValue()
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    }

    private void exit() {
        inheritableThreadLocals = null;
    }
}
```
```java
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
        return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
```
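Because childValue() is a protected hook, a subclass can control what the child receives. The sketch below is an illustration with made-up names, not code from the article. It overrides childValue() to append a suffix. The override runs in the parent thread while the child Thread object is constructed.

```java
import java.util.concurrent.atomic.AtomicReference;

public class ChildValueDemo {
    // childValue() is invoked once per entry when createInheritedMap copies the parent's map
    private static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<String>() {
        @Override
        protected String childValue(String parentValue) {
            return parentValue + "-copy";
        }
    };

    public static String readInChild() throws InterruptedException {
        CONTEXT.set("parent");
        AtomicReference<String> seen = new AtomicReference<>();
        Thread child = new Thread(() -> seen.set(CONTEXT.get())); // the copy is made here
        child.start();
        child.join();
        CONTEXT.remove();
        return seen.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(readInChild()); // parent-copy
    }
}
```

If values are mutable objects, returning a defensive copy from childValue() stops the child from modifying the parent's object.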
Incorrect examples

Read the following code, predict what it prints, and explain why.
```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class SimpleThreadLocalPollutionExample {
    private final static ThreadLocal<String> USERNAME_THREADLOCAL = ThreadLocal.withInitial(() -> null);

    // Stores the username only when the password is correct
    private static void login(String username, String password) {
        if ("123456".equals(password)) {
            USERNAME_THREADLOCAL.set(username);
        }
    }

    private static String getCurrentUsername() {
        return USERNAME_THREADLOCAL.get();
    }

    public static void main(String[] args) throws InterruptedException {
        // A single worker thread, reused by every task
        ExecutorService executorService = Executors.newFixedThreadPool(1);
        // LinkedHashMap keeps a predictable iteration order
        Map<String, String> userList = new LinkedHashMap<>();
        userList.put("zhangsan", "123456");
        userList.put("lisi", "123456");
        userList.put("wangwu", "111111");
        for (Map.Entry<String, String> entry : userList.entrySet()) {
            executorService.submit(() -> {
                login(entry.getKey(), entry.getValue());
                String username;
                if ((username = getCurrentUsername()) != null) {
                    System.out.println("hello, " + username);
                }
            });
        }
        executorService.shutdown(); // let the JVM exit once the tasks finish
    }
}
```
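If ThreadLocal's remove() is not called promptly, the result can be memory leaks and thread-local variable pollution.

The pool has a single thread, so all three tasks run on the same worker. zhangsan and lisi log in successfully, and each of their tasks calls set(). wangwu's password is wrong, so set() is never called. getCurrentUsername() then returns whatever the previous task left behind. If an earlier task on that thread logged in successfully, the third greeting repeats that user's name. For example, when entries are iterated in insertion order, the output ends with a second "hello, lisi".

The memory leak comes from the way ThreadLocalMap holds its entries. The key (the ThreadLocal) is only weakly referenced, but the entry holds the value strongly, and the map itself is strongly reachable from the Thread. In a long-lived pool thread, the value therefore cannot be garbage-collected until remove() is called or the stale entry is expunged. Pollution typically happens with thread pools: threads are reused, so a later task reads stale data left by an earlier one.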
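The usual fix is to pair every set() with a remove() in a finally block, so the reused worker starts each task clean. The sketch below rewrites the pollution example that way. The class name and the runAll() method are made up for this illustration, and greetings are collected in a list instead of being printed so the result can be checked. Only zhangsan and lisi are greeted. Without the finally block, a third greeting would appear.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class ThreadLocalCleanupExample {
    private static final ThreadLocal<String> USERNAME = new ThreadLocal<>();

    private static void login(String username, String password) {
        if ("123456".equals(password)) {
            USERNAME.set(username);
        }
    }

    public static List<String> runAll() throws InterruptedException {
        Map<String, String> users = new LinkedHashMap<>(); // fixed iteration order
        users.put("zhangsan", "123456");
        users.put("lisi", "123456");
        users.put("wangwu", "111111");
        List<String> greetings = Collections.synchronizedList(new ArrayList<>());
        ExecutorService pool = Executors.newFixedThreadPool(1); // one reused worker
        for (Map.Entry<String, String> e : users.entrySet()) {
            pool.submit(() -> {
                try {
                    login(e.getKey(), e.getValue());
                    String name = USERNAME.get();
                    if (name != null) {
                        greetings.add("hello, " + name);
                    }
                } finally {
                    USERNAME.remove(); // clear before the worker picks up the next task
                }
            });
        }
        pool.shutdown();
        pool.awaitTermination(10, TimeUnit.SECONDS);
        return greetings;
    }

    public static void main(String[] args) throws InterruptedException {
        runAll().forEach(System.out::println);
    }
}
```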
```java
public class InheritableThreadLocalExample {
    private static final InheritableThreadLocal<String> CONTEXT_HOLDER = new InheritableThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        Thread childThread = new Thread(() -> {
            System.out.println("Child thread value: " + CONTEXT_HOLDER.get());
        });
        CONTEXT_HOLDER.set("Parent thread value");
        childThread.start();
        childThread.join();
    }
}
```
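With InheritableThreadLocal, the copy is taken once, when the child thread is created, and later changes in the parent do not affect the child's value. That is why the example above prints "Child thread value: null": childThread was created before the parent called set(). Thread pools make this worse. Worker threads are created once and then reused, so a task sees the values inherited when its worker was created, not the values present when the task was submitted. For this reason, TransmittableThreadLocal is commonly used to pass context into thread pools.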
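The thread-pool problem can be reproduced with the JDK alone. The sketch below relies on ThreadPoolExecutor's documented behavior: it creates worker threads lazily, on the submitting thread, inside execute(). So the first submission creates the single worker, and that worker inherits whatever the submitter holds at that moment. The second task reuses the same worker and still sees the old value. The class and method names are illustrative.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class InheritableInPoolExample {
    private static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<>();

    /** Returns {valueSeenByFirstTask, valueSeenBySecondTask}. */
    public static String[] run() throws InterruptedException, ExecutionException {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<String> read = CONTEXT::get;
        try {
            CONTEXT.set("request-A");
            // The first submit creates the worker from this thread, so it inherits "request-A"
            String first = pool.submit(read).get();
            CONTEXT.set("request-B");
            // Same worker, reused: its inherited copy was fixed when it was created
            String second = pool.submit(read).get();
            return new String[] {first, second};
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
    }

    public static void main(String[] args) throws Exception {
        String[] r = run();
        System.out.println(r[0] + ", " + r[1]); // request-A, request-A
    }
}
```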
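TransmittableThreadLocal

TransmittableThreadLocal (TTL for short) is part of Alibaba's open-source transmittable-thread-local library and extends Java's InheritableThreadLocal. Ordinary ThreadLocal values are not shared between threads: each thread has its own independent copy. TTL instead transmits the values held by the submitting thread when a task is wrapped (typically at submission) to the thread that executes the task. This works even when the executing thread is a reused pool thread.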
For context to pass correctly from one thread to another, the task or the executor must be decorated. There are three common approaches, shown below:
```java
// Approach 1: decorate each task with TtlRunnable (TtlRunnable is in package com.alibaba.ttl)
TransmittableThreadLocal<String> ttl = new TransmittableThreadLocal<>();
ttl.set("hello, ttl");

Executor executor = Executors.newFixedThreadPool(1);
executor.execute(TtlRunnable.get(() -> System.out.println(ttl.get())));
```
```java
// Approach 2: decorate the thread pool (TtlExecutors is in package com.alibaba.ttl.threadpool)
TransmittableThreadLocal<String> ttl = new TransmittableThreadLocal<>();
ttl.set("hello, ttl");

Executor ttlExecutor = TtlExecutors.getTtlExecutor(Executors.newFixedThreadPool(1));
ttlExecutor.execute(() -> System.out.println(ttl.get()));
```
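Approach 3: use the Java Agent, which decorates thread pools directly without any change to business code. Note, however, that it may stop working when several agents are used together.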
```
-javaagent:xx/xx/transmittable-thread-local-<version>.jar
```
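TIPS: All of the approaches above rely on the decorator pattern. A decorator wraps an object behind the same interface and dynamically attaches extra responsibilities to it. Its focus is on extending and enhancing functionality.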
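A minimal decorator sketch follows. The names are hypothetical and none of this is TTL code. CountingRunnable keeps the Runnable interface and adds behavior around its delegate. CountingExecutor decorates an Executor by decorating every task it receives, which is the same shape TtlExecutors uses with TtlRunnable.

```java
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;

public class DecoratorDemo {
    // Decorator: same interface as the wrapped object, extra behavior around the delegate
    static final class CountingRunnable implements Runnable {
        private final Runnable delegate;
        private final AtomicInteger counter;

        CountingRunnable(Runnable delegate, AtomicInteger counter) {
            this.delegate = delegate;
            this.counter = counter;
        }

        @Override
        public void run() {
            counter.incrementAndGet(); // added responsibility
            delegate.run();            // original behavior
        }
    }

    // Decorating an Executor by decorating each task it is given
    static final class CountingExecutor implements Executor {
        private final Executor delegate;
        final AtomicInteger counter = new AtomicInteger();

        CountingExecutor(Executor delegate) {
            this.delegate = delegate;
        }

        @Override
        public void execute(Runnable command) {
            delegate.execute(new CountingRunnable(command, counter));
        }
    }

    public static int runTasks(int n) {
        Executor direct = Runnable::run; // runs each task in the calling thread
        CountingExecutor counting = new CountingExecutor(direct);
        for (int i = 0; i < n; i++) {
            counting.execute(() -> { });
        }
        return counting.counter.get();
    }

    public static void main(String[] args) {
        System.out.println(runTasks(3)); // 3
    }
}
```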
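Understanding the TTL source code comes down to two mechanisms: how context is stored and read, and how context is transmitted between threads. We analyze them one at a time.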
How context is stored and read

```java
public final T get() {
    T value = super.get();
    // Register this TTL in holder when it has a value (or when null values are not ignored)
    if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
    return value;
}

public final void set(T value) {
    if (!disableIgnoreNullValueSemantics && null == value) {
        // By default, setting null is treated as remove()
        remove();
    } else {
        super.set(value);
        addThisToHolder();
    }
}

public final void remove() {
    removeThisFromHolder();
    super.remove();
}
```
TTL's get(), set(), and remove() add one piece of bookkeeping on top of the inherited ThreadLocal behavior, and it lives in addThisToHolder() and removeThisFromHolder(). Let's look at those next.
```java
@SuppressWarnings("unchecked")
private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        // WeakHashMap is used as a set: the value is always null
        holder.get().put((TransmittableThreadLocal<Object>) this, null);
    }
}

private void removeThisFromHolder() {
    holder.get().remove(this);
}

private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
        new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
            @Override
            protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                return new WeakHashMap<>();
            }

            @Override
            protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(
                    WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                // The child thread gets its own copy of the parent's registry
                return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
            }
        };
```
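holder is an InheritableThreadLocal whose value type is WeakHashMap<TransmittableThreadLocal<Object>, ?>. It overrides initialValue() and childValue(), so every thread starts with an empty map and a child thread receives its own copy of the parent's map. Note that holder does not store the context values themselves. super.set() stores them in the thread's ThreadLocalMap, as for any other ThreadLocal. holder only records which TTL instances have a value in the current thread, so they can be enumerated later during capture.

WeakHashMap is a map whose keys are weakly referenced. TTL uses it as a set: the key is the TTL instance itself and the value is always null. The weak keys prevent memory leaks. Once a TTL instance is no longer referenced anywhere else, its entry can be garbage-collected instead of being kept alive by holder.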
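The registry idea can be sketched with plain JDK classes. TrackedThreadLocal below is a hypothetical class, not part of TTL, and it simplifies one thing: its holder is a plain ThreadLocal, whereas TTL's is an InheritableThreadLocal with a copying childValue(). The value lives in the ordinary ThreadLocalMap through super.set(). The per-thread WeakHashMap only records which instances are in use, so captureAll() can enumerate them.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;

public class TrackedThreadLocal<T> extends ThreadLocal<T> {
    // Per-thread registry used as a set: key = instance (weakly held), value always null
    private static final ThreadLocal<WeakHashMap<TrackedThreadLocal<?>, Object>> HOLDER =
            ThreadLocal.withInitial(WeakHashMap::new);

    @Override
    public void set(T value) {
        super.set(value);             // the value itself goes into the thread's ThreadLocalMap
        HOLDER.get().put(this, null); // only the registration goes into the holder
    }

    @Override
    public void remove() {
        HOLDER.get().remove(this);
        super.remove();
    }

    /** Copies the current thread's value of every registered instance. */
    public static Map<TrackedThreadLocal<?>, Object> captureAll() {
        Map<TrackedThreadLocal<?>, Object> snapshot = new HashMap<>();
        for (TrackedThreadLocal<?> tl : HOLDER.get().keySet()) {
            snapshot.put(tl, tl.get());
        }
        return snapshot;
    }

    public static void main(String[] args) {
        TrackedThreadLocal<String> user = new TrackedThreadLocal<>();
        TrackedThreadLocal<Integer> tenant = new TrackedThreadLocal<>();
        user.set("alice");
        tenant.set(42);
        System.out.println(captureAll().size()); // 2
    }
}
```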
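Transmitting context between threads

This part of the code is more involved, so we first introduce three core types inside TransmittableThreadLocal: Snapshot, Transmitter, and Transmittee. Together they capture, replay, and restore the context tracked by holder: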
```java
private static class Snapshot {
    // One captured (or backup) value per registered Transmittee
    final HashMap<Transmittee<Object, Object>, Object> transmittee2Value;

    public Snapshot(HashMap<Transmittee<Object, Object>, Object> transmittee2Value) {
        this.transmittee2Value = transmittee2Value;
    }
}
```
```java
public static class Transmitter {
    // Registered transmission strategies; by default ttlTransmittee and threadLocalTransmittee
    private static final Set<Transmittee<Object, Object>> transmitteeSet = new CopyOnWriteArraySet<>();

    static {
        registerTransmittee(ttlTransmittee);
        registerTransmittee(threadLocalTransmittee);
    }

    @SuppressWarnings("unchecked")
    public static <C, B> boolean registerTransmittee(@NonNull Transmittee<C, B> transmittee) {
        return transmitteeSet.add((Transmittee<Object, Object>) transmittee);
    }

    // Submitting thread: ask every transmittee for its current context
    public static Object capture() {
        final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(transmitteeSet.size());
        for (Transmittee<Object, Object> transmittee : transmitteeSet) {
            try {
                transmittee2Value.put(transmittee, transmittee.capture());
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.capture for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
        return new Snapshot(transmittee2Value);
    }

    // Executing thread: apply the captured context and return backups for restore()
    public static Object replay(@NonNull Object captured) {
        final Snapshot capturedSnapshot = (Snapshot) captured;
        final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(capturedSnapshot.transmittee2Value.size());
        for (Map.Entry<Transmittee<Object, Object>, Object> entry : capturedSnapshot.transmittee2Value.entrySet()) {
            Transmittee<Object, Object> transmittee = entry.getKey();
            try {
                Object transmitteeCaptured = entry.getValue();
                transmittee2Value.put(transmittee, transmittee.replay(transmitteeCaptured));
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.replay for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
        return new Snapshot(transmittee2Value);
    }

    // Executing thread: clear the context and return backups for restore()
    public static Object clear() {
        final HashMap<Transmittee<Object, Object>, Object> transmittee2Value = new HashMap<>(transmitteeSet.size());
        for (Transmittee<Object, Object> transmittee : transmitteeSet) {
            try {
                transmittee2Value.put(transmittee, transmittee.clear());
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.clear for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
        return new Snapshot(transmittee2Value);
    }

    // Executing thread: put back the values saved by replay() or clear()
    public static void restore(@NonNull Object backup) {
        for (Map.Entry<Transmittee<Object, Object>, Object> entry : ((Snapshot) backup).transmittee2Value.entrySet()) {
            Transmittee<Object, Object> transmittee = entry.getKey();
            try {
                Object transmitteeBackup = entry.getValue();
                transmittee.restore(transmitteeBackup);
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "exception when Transmitter.restore for transmittee " + transmittee +
                            "(class " + transmittee.getClass().getName() + "), just ignored; cause: " + t, t);
                }
            }
        }
    }
}
```
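Transmitter manages the transmitteeSet container. Each of its methods iterates over the container and calls the corresponding method on every element. Note that transmitteeSet does not hold the context values themselves. It holds Transmittee strategies, and the values captured from the submitting thread end up in the transmittee2Value map of the returned Snapshot.

Next, let's look at the element type, the Transmittee interface, and at the two elements registered by default, ttlTransmittee and threadLocalTransmittee. More can be added through registerTransmittee(). ttlTransmittee handles TransmittableThreadLocal values (the instances tracked by holder). threadLocalTransmittee handles plain ThreadLocal values that have been explicitly registered with Transmitter.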
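```java
public interface Transmittee<C, B> {
    C capture();                     // submitting thread: capture the current context
    B replay(@NonNull C captured);   // executing thread: apply the captured context, return a backup
    B clear();                       // executing thread: clear the context, return a backup
    void restore(@NonNull B backup); // executing thread: put the backup back
}
```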
```java
private static final Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>> ttlTransmittee =
        new Transmittee<HashMap<TransmittableThreadLocal<Object>, Object>, HashMap<TransmittableThreadLocal<Object>, Object>>() {
            @Override
            public HashMap<TransmittableThreadLocal<Object>, Object> capture() {
                // Copy the value of every TTL registered in holder
                final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>(holder.get().size());
                for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                    ttl2Value.put(threadLocal, threadLocal.copyValue());
                }
                return ttl2Value;
            }

            @Override
            public HashMap<TransmittableThreadLocal<Object>, Object> clear() {
                // Clearing is replaying an empty capture
                return replay(new HashMap<>(0));
            }

            @Override
            public void restore(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
                doExecuteCallback(false);
                for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                    TransmittableThreadLocal<Object> threadLocal = iterator.next();
                    // Drop TTLs that did not exist in this thread before replay()
                    if (!backup.containsKey(threadLocal)) {
                        iterator.remove();
                        threadLocal.superRemove();
                    }
                }
                setTtlValuesTo(backup);
            }

            @Override
            public HashMap<TransmittableThreadLocal<Object>, Object> replay(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
                final HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>(holder.get().size());
                for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                    TransmittableThreadLocal<Object> threadLocal = iterator.next();
                    // Back up the executing thread's current value
                    backup.put(threadLocal, threadLocal.get());
                    // Drop TTLs that were not captured from the submitting thread
                    if (!captured.containsKey(threadLocal)) {
                        iterator.remove();
                        threadLocal.superRemove();
                    }
                }
                setTtlValuesTo(captured);
                doExecuteCallback(true);
                return backup;
            }
        };

private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}
```
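The other element, threadLocalTransmittee, follows the same structure: it iterates over the registered ThreadLocal instances and reads or writes their values.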
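To show the Transmittee contract in isolation, here is a sketch of a transmittee for a single plain ThreadLocal. ContextTransmittee is a hypothetical interface that only mirrors the shape of TTL's Transmittee<C, B>. It is not the library type, and the demo calls it by hand rather than through Transmitter. capture() runs in the submitting thread. replay() installs the captured value in the worker and returns the worker's previous value as a backup, and restore() puts that backup back afterwards.

```java
import java.util.concurrent.atomic.AtomicReference;

public class SingleThreadLocalTransmittee {
    // Hypothetical interface with the same shape as TTL's Transmittee<C, B>
    interface ContextTransmittee<C, B> {
        C capture();            // submitting thread
        B replay(C captured);   // executing thread, before the task; returns a backup
        B clear();              // like replay() with an empty context
        void restore(B backup); // executing thread, after the task
    }

    static final ThreadLocal<String> USER = new ThreadLocal<>();

    static final ContextTransmittee<String, String> USER_TRANSMITTEE = new ContextTransmittee<String, String>() {
        @Override public String capture() { return USER.get(); }

        @Override public String replay(String captured) {
            String backup = USER.get();
            setOrRemove(captured);
            return backup;
        }

        @Override public String clear() { return replay(null); }

        @Override public void restore(String backup) { setOrRemove(backup); }
    };

    private static void setOrRemove(String value) {
        if (value == null) USER.remove(); else USER.set(value);
    }

    /** Runs one task on a worker via capture/replay/restore; returns {seenInTask, workerAfter}. */
    public static String[] demo() throws InterruptedException {
        USER.set("alice");
        String captured = USER_TRANSMITTEE.capture(); // in the submitting thread
        AtomicReference<String> seen = new AtomicReference<>();
        AtomicReference<String> after = new AtomicReference<>();
        Thread worker = new Thread(() -> {
            USER.set("stale");                        // leftover from an earlier task
            String backup = USER_TRANSMITTEE.replay(captured);
            try {
                seen.set(USER.get());                 // the task sees the submitter's value
            } finally {
                USER_TRANSMITTEE.restore(backup);     // the worker gets its own value back
            }
            after.set(USER.get());
        });
        worker.start();
        worker.join();
        USER.remove();
        return new String[] {seen.get(), after.get()};
    }

    public static void main(String[] args) throws InterruptedException {
        String[] r = demo();
        System.out.println(r[0] + " " + r[1]); // alice stale
    }
}
```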
With the responsibilities of these core classes in mind, let's step through how TtlExecutors decorates the thread pool in the earlier example.
```java
Executor ttlExecutor = TtlExecutors.getTtlExecutor(Executors.newFixedThreadPool(1));
```
```java
public final class TtlExecutors {
    public static Executor getTtlExecutor(Executor executor) {
        // No decoration if the agent is loaded, the executor is null, or it is already decorated
        if (TtlAgent.isTtlAgentLoaded() || null == executor || executor instanceof TtlEnhanced) {
            return executor;
        }
        return new ExecutorTtlWrapper(executor, true);
    }
}

class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
    // fields (executor, idempotent) and constructor omitted

    @Override
    public void execute(@NonNull Runnable command) {
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }
}
```
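Decorating the thread pool still comes down to decorating each task. ExecutorTtlWrapper wraps every Runnable with TtlRunnable, and the ExecutorService wrappers likewise wrap Callable tasks with TtlCallable. Let's dig into TtlRunnable.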
```java
public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    @Override
    public void run() {
        final Object captured = capturedRef.get();
        // With releaseTtlValueReferenceAfterRun, the reference is cleared so the wrapper runs at most once
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }

        final Object backup = replay(captured); // Transmitter.replay
        try {
            runnable.run();
        } finally {
            restore(backup);                    // Transmitter.restore
        }
    }
}
```
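TtlRunnable therefore adds three functions on top of Runnable. It captures a snapshot of the submitting thread's data, replays that snapshot into the executing thread, and restores the executing thread's own data from the backup once the task finishes.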
The snapshot is captured when the TtlRunnable is created, that is, in the submitting thread:
```java
private final AtomicReference<Object> capturedRef;

public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
    if (null == runnable) return null;

    if (runnable instanceof TtlEnhanced) {
        // Avoid decorating the same task twice
        if (idempotent) return (TtlRunnable) runnable;
        else throw new IllegalStateException("Already TtlRunnable!");
    }
    return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
}

private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    // Transmitter.capture() runs here, in the thread that creates the wrapper
    this.capturedRef = new AtomicReference<>(capture());
    // other field assignments omitted
}
```
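The other two functions call Transmitter's static methods directly. replay() runs in the executing thread before runnable.run(). It loads the captured values into that thread and returns a backup of the values the thread held before. restore() runs in the finally block and puts that backup back. This matters in thread pools, where threads are reused, because it keeps one task's context from leaking into later tasks.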
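The whole TtlRunnable and ExecutorTtlWrapper flow can be condensed into a JDK-only sketch for a single plain ThreadLocal. ContextRunnable and wrap() are hypothetical names, and this is a simplified analogue, not the library's implementation. It captures in the constructor, which runs in the submitting thread. In run() it backs up and replays before the task, then restores in finally. The AtomicReference plus compareAndSet reproduces the "run at most once" behavior of releaseTtlValueReferenceAfterRun.

```java
import java.util.Arrays;
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicReference;

public class ContextRunnableDemo {
    // Plain ThreadLocal, so pool threads do not inherit it on their own
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    // Hypothetical, simplified analogue of TtlRunnable for one ThreadLocal
    static final class ContextRunnable implements Runnable {
        // Box so that a null context value can still be captured
        private static final class Captured {
            final String value;
            Captured(String value) { this.value = value; }
        }

        private final Runnable delegate;
        private final boolean releaseAfterRun;
        private final AtomicReference<Captured> capturedRef;

        ContextRunnable(Runnable delegate, boolean releaseAfterRun) {
            this.delegate = delegate;
            this.releaseAfterRun = releaseAfterRun;
            // capture: runs in the thread that creates the wrapper (the submitter)
            this.capturedRef = new AtomicReference<>(new Captured(CONTEXT.get()));
        }

        @Override
        public void run() {
            Captured captured = capturedRef.get();
            if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
                throw new IllegalStateException("context reference is released after run");
            }
            String backup = CONTEXT.get(); // replay, step 1: back up the executing thread's value
            apply(captured.value);         // replay, step 2: install the captured value
            try {
                delegate.run();
            } finally {
                apply(backup);             // restore the executing thread's previous value
            }
        }

        private static void apply(String value) {
            if (value == null) CONTEXT.remove(); else CONTEXT.set(value);
        }
    }

    // Executor decorator in the spirit of ExecutorTtlWrapper: wraps every task on submission
    static Executor wrap(Executor executor) {
        return command -> executor.execute(new ContextRunnable(command, false));
    }

    /** Returns {seenByTask1, seenByTask2, workerValueAfterwards}. */
    public static String[] demo() throws Exception {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Executor contextPool = wrap(pool);
        String[] seen = new String[2];
        try {
            CONTEXT.set("A");
            contextPool.execute(() -> seen[0] = CONTEXT.get());
            CONTEXT.set("B");
            contextPool.execute(() -> seen[1] = CONTEXT.get());
            // Unwrapped probe on the same worker: its own value was restored to null
            Callable<String> probe = CONTEXT::get;
            String after = pool.submit(probe).get();
            return new String[] {seen[0], seen[1], after};
        } finally {
            CONTEXT.remove();
            pool.shutdown();
        }
    }

    /** A wrapper created with releaseAfterRun = true refuses to run a second time. */
    public static boolean secondRunRejected() {
        ContextRunnable once = new ContextRunnable(() -> { }, true);
        once.run();
        try {
            once.run();
            return false;
        } catch (IllegalStateException expected) {
            return true;
        }
    }

    public static void main(String[] args) throws Exception {
        System.out.println(Arrays.toString(demo())); // [A, B, null]
        System.out.println(secondRunRejected());     // true
    }
}
```

Unlike the InheritableThreadLocal pool example, each task here sees the value present at submission ("A", then "B"). The worker is left clean afterwards.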