  温故而知新,再探ConcurrentHashMap




     // Creates the map: validates the arguments, rounds the segment count up to a power of
     // two, derives the per-segment table capacity, and eagerly builds only segments[0].
     // Throws IllegalArgumentException when loadFactor is not > 0 (NaN included, since the
     // test is !(loadFactor > 0)), initialCapacity is negative, or concurrencyLevel is <= 0.
     1 public ConcurrentHashMap(int initialCapacity,
     2                              float loadFactor, int concurrencyLevel) {
     3         if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
     4             throw new IllegalArgumentException();
     5         if (concurrencyLevel > MAX_SEGMENTS) // oversized values are clamped, not rejected
     6             concurrencyLevel = MAX_SEGMENTS;
     7         // Find power-of-two sizes best matching arguments
     8         int sshift = 0; // ends up as log2(ssize)
     9         int ssize = 1; // segment count: smallest power of two >= concurrencyLevel
    10         while (ssize < concurrencyLevel) {
    11             ++sshift;
    12             ssize <<= 1;
    13         }
    14         this.segmentShift = 32 - sshift; // put() shifts the hash right by this, keeping its top sshift bits
    15         this.segmentMask = ssize - 1; // mask applied after that shift to get the segment index
    16         if (initialCapacity > MAXIMUM_CAPACITY)
    17             initialCapacity = MAXIMUM_CAPACITY;
    18         int c = initialCapacity / ssize; // requested entries per segment ...
    19         if (c * ssize < initialCapacity) // ... rounded up when the division truncated
    20             ++c;
    21         int cap = MIN_SEGMENT_TABLE_CAPACITY; // per-segment table length: doubled until it reaches c
    22         while (cap < c)
    23             cap <<= 1;
    24         // create segments and segments[0]
    25         Segment<K,V> s0 =
    26             new Segment<K,V>(loadFactor, (int)(cap * loadFactor), // second argument is the resize threshold
    27                              (HashEntry<K,V>[])new HashEntry[cap]);
    28         Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; // every slot other than 0 stays null here
    29         UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
    30         this.segments = ss;
    31     }


     // Stores value under key. Throws NullPointerException for a null value; otherwise picks
     // a segment from the hash, creates that segment if its slot is still null, and delegates
     // to the segment's put with onlyIfAbsent = false, returning that call's result.
     1 public V put(K key, V value) {
     2         Segment<K,V> s;
     3         if (value == null)
     4             throw new NullPointerException();
     5         int hash = hash(key); // hash() is defined outside this excerpt
     6         int j = (hash >>> segmentShift) & segmentMask; // segment index from the high bits of the hash
     7         if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
     8              (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
     9             s = ensureSegment(j);
    10         return s.put(key, hash, value, false);
    11     }



     // Returns the segment at index k, creating and installing it first if the slot is empty.
     // A new segment copies its table length and load factor from segments[0]; installation
     // uses compare-and-swap, so if another thread fills the slot first, that segment is returned.
     1 private Segment<K,V> ensureSegment(int k) {
     2         final Segment<K,V>[] ss = this.segments;
     3         long u = (k << SSHIFT) + SBASE; // raw offset
     4         Segment<K,V> seg;
     5         if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // volatile read of the array slot
     6             Segment<K,V> proto = ss[0]; // use segment 0 as prototype
     7             int cap = proto.table.length;
     8             float lf = proto.loadFactor;
     9             int threshold = (int)(cap * lf);
    10             HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap]; // allocated before the recheck below
    11             if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
    12                 == null) { // recheck
    13                 Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
    14                 while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
    15                        == null) { // keep going until the slot is non-null, whoever filled it
    16                     if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s)) // on success seg already holds s
    17                         break;
    18                 }
    19             }
    20         }
    21         return seg;
    22     }

    我们再看segments,它是final的,所以数组引用本身对多线程是可见的,但是它的元素却不是。把数组(或对象)声明为final或volatile,只能保证引用值的可见性,对数组元素(或对象属性)的修改并不保证对其他线程可见,这可以说是Java的一个设计缺陷。为了弥补这个缺陷,上面ensureSegment的第5行用了 UNSAFE.getObjectVolatile(ss, u),以volatile语义读取数组元素来补坑。

    // The segment array: the reference is final, while individual slots are filled in later by ensureSegment.
    final Segment<K,V>[] segments


     // Segment-level put. Acquires the lock (tryLock first, scanAndLockForPut on failure),
     // then either updates an existing entry for key or links a new node at the head of the
     // bucket. Returns the previous value, or null if the key was not present.
     1 final V put(K key, int hash, V value, boolean onlyIfAbsent) {
     2             HashEntry<K,V> node = tryLock() ? null : // lock taken immediately: no pre-built node
     3                 scanAndLockForPut(key, hash, value);
     4             V oldValue;
     5             try {
     6                 HashEntry<K,V>[] tab = table;
     7                 int index = (tab.length - 1) & hash; // bucket index
     8                 HashEntry<K,V> first = entryAt(tab, index); // head of the bucket chain
     9                 for (HashEntry<K,V> e = first;;) {
    10                     if (e != null) { // still walking the chain
    11                         K k;
    12                         if ((k = e.key) == key ||
    13                             (e.hash == hash && key.equals(k))) { // key already present
    14                             oldValue = e.value;
    15                             if (!onlyIfAbsent) { // overwrite only when allowed
    16                                 e.value = value;
    17                                 ++modCount;
    18                             }
    19                             break;
    20                         }
    21                         e = e.next;
    22                     }
    23                     else { // end of chain reached: key is absent, insert a node
    24                         if (node != null)
    25                             node.setNext(first); // node came from scanAndLockForPut; link it in front of first
    26                         else
    27                             node = new HashEntry<K,V>(hash, key, value, first);
    28                         int c = count + 1;
    29                         if (c > threshold && tab.length < MAXIMUM_CAPACITY)
    30                             rehash(node); // rehash is handed the node to insert
    31                         else
    32                             setEntryAt(tab, index, node); // node becomes the new bucket head
    33                         ++modCount;
    34                         count = c;
    35                         oldValue = null;
    36                         break;
    37                     }
    38                 }
    39             } finally {
    40                 unlock(); // always released, on every exit path of the try block
    41             }
    42             return oldValue;
    43         }

    我们接下来看tryLock失败之后会发生什么。我们知道,CHM的读取操作是不加锁的,而写入操作一定要加锁。但是,真的是一上来就阻塞式地加锁吗?仔细观察scanAndLockForPut函数,发现不是这样的。scanAndLockForPut函数里有个循环,以自旋的方式不断调用tryLock()尝试获取锁,如果不成功就继续循环。但是会一直这样循环下去吗?也不是。可以看下面的代码,有个retries变量,定位节点的阶段结束后,每循环一次,retries就加1。从19行到22行可知,当超过MAX_SCAN_RETRIES次之后,就会直接调用lock(),自旋锁变成了重量级锁。而且我们可以看到,在自旋的过程中,scanAndLockForPut并不是什么都没干,它顺便在做一件事:沿着hash桶的链表查找key。如果遍历完链表也没有找到,就预先创建新节点并赋给node;如果找到了key相等的节点,则停止遍历(此时不会创建节点,node仍为null)。并且在这个过程中,每隔一次((retries & 1) == 0,即retries为偶数时;不明白为什么要在偶数次检查)检查一次该hash桶的头节点是否已经变了,如果变了,则需要重新遍历(把retries置为-1,使得重新进入retries < 0的逻辑),并重新开始计数retries。也就是说,如果有其他线程改变了这个hash桶,则会再继续自旋等待MAX_SCAN_RETRIES次。

     // Called by put when its first tryLock() fails. Spins on tryLock() while walking the
     // bucket for hash; if the walk reaches the end of the chain it speculatively creates the
     // node to insert. Returns that node, or null if none was created. The loop is left only
     // after tryLock() returns true or lock() has been called.
     1 private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
     2             HashEntry<K,V> first = entryForHash(this, hash);
     3             HashEntry<K,V> e = first;
     4             HashEntry<K,V> node = null;
     5             int retries = -1; // negative while locating node
     6             while (!tryLock()) {
     7                 HashEntry<K,V> f; // to recheck first below
     8                 if (retries < 0) { // locating phase: advance one chain step per failed tryLock()
     9                     if (e == null) {
    10                         if (node == null) // speculatively create node
    11                             node = new HashEntry<K,V>(hash, key, value, null);
    12                         retries = 0; // chain exhausted: start counting retries
    13                     }
    14                     else if (key.equals(e.key))
    15                         retries = 0; // key found: stop scanning, node is left as it is
    16                     else
    17                         e = e.next;
    18                 }
    19                 else if (++retries > MAX_SCAN_RETRIES) {// after more than MAX_SCAN_RETRIES counted retries, stop spinning and call lock() directly
    20                     lock();
    21                     break;
    22                 }
    23                 else if ((retries & 1) == 0 && // only on even retry counts
    24                          (f = entryForHash(this, hash)) != first) {
    25                     e = first = f; // re-traverse if entry changed
    26                     retries = -1; // back to the locating phase; the retry count starts over
    27                 }
    28             }
    29             return node;
    30         }


     // Segment-level put (annotated copy). Updates the existing entry for key, or links a new
     // node at the head of its bucket; returns the previous value, or null if the key was absent.
     1 final V put(K key, int hash, V value, boolean onlyIfAbsent) {
     2             // Try the lock first; on failure enter scanAndLockForPut, which keeps retrying tryLock() in a loop.
     3             HashEntry<K,V> node = tryLock() ? null :
     4                 scanAndLockForPut(key, hash, value);
     5             V oldValue;
     6             try {
     7                 HashEntry<K,V>[] tab = table;
     8                 // compute the bucket index
     9                 int index = (tab.length - 1) & hash;
    10                 // fetch the first node of the bucket
    11                 HashEntry<K,V> first = entryAt(tab, index);
    12                 // walk the HashEntry chain of the bucket
    13                 for (HashEntry<K,V> e = first;;) {
    14                     if (e != null) {// chain not exhausted yet
    15                         K k;
    16                         // same key found: replace the value unless onlyIfAbsent is set, then stop the traversal
    17                         if ((k = e.key) == key ||
    18                             (e.hash == hash && key.equals(k))) {
    19                             oldValue = e.value;
    20                             if (!onlyIfAbsent) {
    21                                 e.value = value;
    22                                 ++modCount;
    23                             }
    24                             break;
    25                         }
    26                         e = e.next;
    27                     }
    28                     else {// end of chain reached: if scanAndLockForPut already produced a node, point its next at the current first
    29                         if (node != null)
    30                             node.setNext(first);
    31                         else// scanAndLockForPut produced no node, so create one; its next is likewise first
    32                             node = new HashEntry<K,V>(hash, key, value, first);
    33                         int c = count + 1;
    34                         // threshold exceeded (and table below MAXIMUM_CAPACITY): resize first; rehash also inserts node
    35                         if (c > threshold && tab.length < MAXIMUM_CAPACITY)
    36                             rehash(node);
    37                         else// otherwise store node in the table; a newly inserted node always sits at the head of the chain
    38                             setEntryAt(tab, index, node);
    39                         ++modCount;
    40                         count = c;
    41                         oldValue = null;
    42                         break;
    43                     }
    44                 }
    45             } finally {
    46                 unlock();
    47             }
    48             return oldValue;
    49         }



     // Doubles the table, redistributes every existing entry into the new table, inserts
     // node at the head of its new bucket, and finally publishes the new table.
     1 private void rehash(HashEntry<K,V> node) {
     2             /*
     3              * Reclassify nodes in each list to new table.  Because we
     4              * are using power-of-two expansion, the elements from
     5              * each bin must either stay at same index, or move with a
     6              * power of two offset. We eliminate unnecessary node
     7              * creation by catching cases where old nodes can be
     8              * reused because their next fields won't change.
     9              * Statistically, at the default threshold, only about
    10              * one-sixth of them need cloning when a table
    11              * doubles. The nodes they replace will be garbage
    12              * collectable as soon as they are no longer referenced by
    13              * any reader thread that may be in the midst of
    14              * concurrently traversing table. Entry accesses use plain
    15              * array indexing because they are followed by volatile
    16              * table write.
    17              */
    18             HashEntry<K,V>[] oldTable = table;
    19             int oldCapacity = oldTable.length;
    20             int newCapacity = oldCapacity << 1;
    21             // recompute the resize threshold for the doubled capacity
    22             threshold = (int)(newCapacity * loadFactor);
    23             // allocate the new table
    24             HashEntry<K,V>[] newTable =
    25                 (HashEntry<K,V>[]) new HashEntry[newCapacity];
    26             // recompute the index mask: because the table length is a power of two (2^k), the low k bits of the hash are the bucket index
    27             int sizeMask = newCapacity - 1;
    28             // process every bucket of the old table
    29             for (int i = 0; i < oldCapacity ; i++) {
    30                 HashEntry<K,V> e = oldTable[i];
    31                 if (e != null) {
    32                     // bucket is non-empty, so it has to be processed
    33                     HashEntry<K,V> next = e.next;
    34                     // compute this node's bucket index in the new table
    35                     int idx = e.hash & sizeMask;
    36                     // bucket holds a single node: move it to the new bucket as is
    37                     if (next == null)   //  Single node on list
    38                         newTable[idx] = e;
    39                     // otherwise the bucket's chain has to be walked
    40                     else { // Reuse consecutive sequence at same slot
    41                         HashEntry<K,V> lastRun = e;
    42                         int lastIdx = idx;
    43                         // find the node from which every remaining node maps to the same new index (lastRun); that tail can be linked into the new table in one piece
    44                         // capacity is always a power of two and doubles on each resize, so the index mask (sizeMask) gains one bit; a node's index
    45                         // in the new table is therefore either its old index or its old index plus oldCapacity, depending on whether that new bit of the hash is 1.
    47                         for (HashEntry<K,V> last = next;
    48                              last != null;
    49                              last = last.next) {
    50                             int k = last.hash & sizeMask;
    51                             if (k != lastIdx) {
    52                                 lastIdx = k;
    53                                 lastRun = last;
    54                             }
    55                         }
    56                         // link the whole reusable tail into its new bucket without copying
    57                         newTable[lastIdx] = lastRun;
    58                         // Clone remaining nodes
    59                         // copy the remaining nodes (all at the front of the old chain) into their new buckets one by one
    60                         // as before, each one is inserted at the head of its bucket
    61                         for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
    62                             V v = p.value;
    63                             int h = p.hash;
    64                             int k = h & sizeMask;
    65                             HashEntry<K,V> n = newTable[k];
    66                             newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
    67                         }
    68                     }
    69                 }
    70             }
    71             // once the entries are moved, insert the node being put into the new table
    72             int nodeIndex = node.hash & sizeMask; // add the new node
    73             node.setNext(newTable[nodeIndex]);
    74             newTable[nodeIndex] = node;
    75             table = newTable;
    76         }




    // Conditional remove: returns true only when value is non-null, the segment selected by
    // the key's hash exists, and that segment's remove(key, hash, value) returns non-null.
    1 public boolean remove(Object key, Object value) {
    2         int hash = hash(key); // computed before the null check on value
    3         Segment<K,V> s;
    4         return value != null && (s = segmentForHash(hash)) != null && // short-circuits: no segment call for a null value or missing segment
    5             s.remove(key, hash, value) != null;
    6     }


     // Segment-level remove. Under the lock, looks for key in its bucket and unlinks the entry
     // when value is null or matches the stored value. Returns the removed value, or null if
     // nothing was removed.
     1 final V remove(Object key, int hash, Object value) {
     2             if (!tryLock())
     3             // spin while trying to acquire the lock
     4                 scanAndLock(key, hash);
     5             V oldValue = null;
     6             // the lock is held from here on
     7             try {
     8                 HashEntry<K,V>[] tab = table;
     9                 int index = (tab.length - 1) & hash;
    10                 // fetch the head node of the bucket
    11                 HashEntry<K,V> e = entryAt(tab, index);
    12                 HashEntry<K,V> pred = null;
    13                 while (e != null) {
    14                     K k;
    15                     HashEntry<K,V> next = e.next;
    16                     // the node for this key has been found
    17                     if ((k = e.key) == key ||
    18                         (e.hash == hash && key.equals(k))) {
    19                         V v = e.value;
    20                         // remove it if the values match, or if the passed value is null (meaning: remove whatever is mapped to the key)
    21                         if (value == null || value == v || value.equals(v)) {
    22                         // the matched node is the bucket head: store its successor in the table
    23                             if (pred == null)
    24                                 setEntryAt(tab, index, next);
    25                             else
    26                             // otherwise point pred at next, which unlinks the current node
    27                                 pred.setNext(next);
    28                             ++modCount;
    29                             --count;
    30                             oldValue = v;
    31                         }
    32                         // the key was found, so stop scanning whether or not the entry was removed
    33                         break;
    34                     }
    35                     // move on to the next node
    36                     pred = e;
    37                     e = next;
    38                 }
    39             } finally {
    40                 unlock();
    41             }
    42             return oldValue;
    43         }


     1 /**
     2          * Scans for a node containing the given key while trying to
     3          * acquire lock for a remove or replace operation. Upon
     4          * return, guarantees that lock is held.  Note that we must
     5          * lock even if the key is not found, to ensure sequential
     6          * consistency of updates.
     7          */
     8         private void scanAndLock(Object key, int hash) {
     9             // similar to but simpler than scanAndLockForPut
    10             HashEntry<K,V> first = entryForHash(this, hash);
    11             HashEntry<K,V> e = first;
    12             int retries = -1; // negative while still walking the chain
    13             while (!tryLock()) {
    14                 HashEntry<K,V> f;
    15                 if (retries < 0) {
    16                     if (e == null || key.equals(e.key))
    17                         retries = 0; // chain exhausted or key found: start counting retries
    18                     else
    19                         e = e.next;
    20                 }
    21                 else if (++retries > MAX_SCAN_RETRIES) { // spun more than MAX_SCAN_RETRIES times: call lock() and leave the loop
    22                     lock();
    23                     break;
    24                 }
    25                 else if ((retries & 1) == 0 && // only on even retry counts
    26                          (f = entryForHash(this, hash)) != first) {
    27                     e = first = f; // bucket head changed: walk the chain again from the new head
    28                     retries = -1;
    29                 }
    30             }
    31         }


    // Clears the map by calling clear() on every segment that exists; null slots are skipped.
    1 public void clear() {
    2         final Segment<K,V>[] segments = this.segments; // local copy of the field, shadows it below
    3         for (int j = 0; j < segments.length; ++j) {
    4             Segment<K,V> s = segmentAt(segments, j);
    5             if (s != null) // slot never created: nothing to clear
    6                 s.clear();
    7         }
    8     }


     // Segment-level clear: under the lock, nulls out every bucket of the table, bumps
     // modCount and resets count to 0.
     1 final void clear() {
     2             lock(); // calls lock() directly, with no tryLock()/spin phase as in put and remove
     3             try {
     4                 HashEntry<K,V>[] tab = table;
     5                 for (int i = 0; i < tab.length ; i++)
     6                     setEntryAt(tab, i, null); // drop the whole chain of bucket i
     7                 ++modCount;
     8                 count = 0;
     9             } finally {
    10                 unlock();
    11             }
    12         }



     // Returns true if no segment holds entries. Takes no locks: it makes up to two passes
     // over the segments and uses the per-segment modCounts to detect changes between them.
     1 public boolean isEmpty() {
     2         /*
     3          * Sum per-segment modCounts to avoid mis-reporting when
     4          * elements are concurrently added and removed in one segment
     5          * while checking another, in which case the table was never
     6          * actually empty at any point. (The sum ensures accuracy up
     7          * through at least 1<<31 per-segment modifications before
     8          * recheck.)  Methods size() and containsValue() use similar
     9          * constructions for stability checks.
    10          */
    11         long sum = 0L; // running total of modCounts seen in the first pass
    12         final Segment<K,V>[] segments = this.segments;
    13         for (int j = 0; j < segments.length; ++j) {
    14             Segment<K,V> seg = segmentAt(segments, j);
    15             if (seg != null) {
    16                 if (seg.count != 0) // any non-empty segment settles the answer
    17                     return false;
    18                 sum += seg.modCount;
    19             }
    20         }
    21         if (sum != 0L) { // recheck unless no modifications
    22             for (int j = 0; j < segments.length; ++j) {
    23                 Segment<K,V> seg = segmentAt(segments, j);
    24                 if (seg != null) {
    25                     if (seg.count != 0)
    26                         return false;
    27                     sum -= seg.modCount; // second pass subtracts what the first pass added
    28                 }
    29             }
    30             if (sum != 0L) // the modCount totals of the two passes differ
    31                 return false;
    32         }
    33         return true;
    34     }


     // Returns the number of entries, capped at Integer.MAX_VALUE on overflow. Sums the
     // segment counts without locking until two consecutive passes see the same modCount
     // total; after RETRIES_BEFORE_LOCK retries it locks every segment first.
     1 public int size() {
     2         // Try a few times to get accurate count. On failure due to
     3         // continuous async changes in table, resort to locking.
     4         final Segment<K,V>[] segments = this.segments;
     5         int size;
     6         boolean overflow; // true if size overflows 32 bits
     7         long sum;         // sum of modCounts
     8         long last = 0L;   // previous sum
     9         int retries = -1; // first iteration isn't retry
    10         try {
    11             for (;;) {
    12                 if (retries++ == RETRIES_BEFORE_LOCK) { // true on exactly one pass; retries exceeds the limit afterwards
    13                     for (int j = 0; j < segments.length; ++j)
    14                         ensureSegment(j).lock(); // force creation
    15                 }
    16                 sum = 0L;
    17                 size = 0;
    18                 overflow = false;
    19                 for (int j = 0; j < segments.length; ++j) {
    20                     Segment<K,V> seg = segmentAt(segments, j);
    21                     if (seg != null) {
    22                         sum += seg.modCount;
    23                         int c = seg.count;
    24                         if (c < 0 || (size += c) < 0) // negative count or wrapped int total
    25                             overflow = true;
    26                     }
    27                 }
    28                 if (sum == last) // same modCount total as the previous pass (last starts at 0): accept this pass
    29                     break;
    30                 last = sum;
    31             }
    32         } finally {
    33             if (retries > RETRIES_BEFORE_LOCK) { // holds only once the locking pass above has been entered
    34                 for (int j = 0; j < segments.length; ++j)
    35                     segmentAt(segments, j).unlock();
    36             }
    37         }
    38         return overflow ? Integer.MAX_VALUE : size;
    39     }
