  • C# ConcurrentDictionary implementation

    I have read through the ConcurrentDictionary source code many times, and today I am taking some time to organize my notes. Its implementation is considerably more complex than Dictionary's, although the thread-safety part is fairly straightforward: it is built on the idea of locking. Let's start by looking at the source code.

     public class ConcurrentDictionary<TKey, TValue> : IDictionary<TKey, TValue>, IDictionary, IReadOnlyDictionary<TKey, TValue>
        {
            /// <summary>
            /// Tables that hold the internal state of the ConcurrentDictionary
            ///
            /// Wrapping the three tables in a single object allows us to atomically
            /// replace all tables at once.
            /// </summary>
            private class Tables
            {
                internal readonly Node[] m_buckets; // A singly-linked list for each bucket.
                internal readonly object[] m_locks; // A set of locks, each guarding a section of the table.
                internal volatile int[] m_countPerLock; // The number of elements guarded by each lock.
                internal readonly IEqualityComparer<TKey> m_comparer; // Key equality comparer
    
                internal Tables(Node[] buckets, object[] locks, int[] countPerLock, IEqualityComparer<TKey> comparer)
                {
                    m_buckets = buckets;
                    m_locks = locks;
                    m_countPerLock = countPerLock;
                    m_comparer = comparer;
                }
            }
            
            private const int DEFAULT_CONCURRENCY_MULTIPLIER = 4;
            private const int DEFAULT_CAPACITY = 31;
            private const int MAX_LOCK_NUMBER = 1024;
            // Whether TValue is a type that can be written atomically (i.e., with no danger of torn reads)
            private static readonly bool s_isValueWriteAtomic = IsValueWriteAtomic();

            // Instance state: current table set, lock-growth flag, per-lock element budget, key rehash counter.
            private volatile Tables m_tables;
            private readonly bool m_growLockArray;
            private int m_budget;
            private int m_keyRehashCount;

            public ConcurrentDictionary() : this(DefaultConcurrencyLevel, DEFAULT_CAPACITY, true, EqualityComparer<TKey>.Default) { }
            public ConcurrentDictionary(int concurrencyLevel, int capacity) : this(concurrencyLevel, capacity, false, EqualityComparer<TKey>.Default) { }
            public ConcurrentDictionary(int concurrencyLevel, int capacity, IEqualityComparer<TKey> comparer) : this(concurrencyLevel, capacity, false, comparer){}
            
            internal ConcurrentDictionary(int concurrencyLevel, int capacity, bool growLockArray, IEqualityComparer<TKey> comparer)
            {
                if (concurrencyLevel < 1)
                {
                    throw new ArgumentOutOfRangeException("concurrencyLevel", GetResource("ConcurrentDictionary_ConcurrencyLevelMustBePositive"));
                }
                if (capacity < 0)
                {
                    throw new ArgumentOutOfRangeException("capacity", GetResource("ConcurrentDictionary_CapacityMustNotBeNegative"));
                }
                if (comparer == null) throw new ArgumentNullException("comparer");
    
                // The capacity should be at least as large as the concurrency level. Otherwise, we would have locks that don't guard
                // any buckets.
                if (capacity < concurrencyLevel)
                {
                    capacity = concurrencyLevel;
                }
    
                object[] locks = new object[concurrencyLevel];
                for (int i = 0; i < locks.Length; i++)
                {
                    locks[i] = new object();
                }
    
                int[] countPerLock = new int[locks.Length];
                Node[] buckets = new Node[capacity];
                m_tables = new Tables(buckets, locks, countPerLock, comparer);
    
                m_growLockArray = growLockArray;
                m_budget = buckets.Length / locks.Length;
            }
            
            public TValue this[TKey key]
            {
                get
                {
                    TValue value;
                    if (!TryGetValue(key, out value))
                    {
                        throw new KeyNotFoundException();
                    }
                    return value;
                }
                set
                {
                    if (key == null) throw new ArgumentNullException("key");
                    TValue dummy;
                    TryAddInternal(key, value, true, true, out dummy);
                }
            }
            
            public bool TryGetValue(TKey key, out TValue value)
            {
                if (key == null) throw new ArgumentNullException("key");
                int bucketNo, lockNoUnused;
    
                // We must capture the m_buckets field in a local variable. It is set to a new table on each table resize.
                Tables tables = m_tables;
                IEqualityComparer<TKey> comparer = tables.m_comparer;
                GetBucketAndLockNo(comparer.GetHashCode(key), out bucketNo, out lockNoUnused, tables.m_buckets.Length, tables.m_locks.Length);
    
                Node n = Volatile.Read<Node>(ref tables.m_buckets[bucketNo]);
    
                while (n != null)
                {
                    if (comparer.Equals(n.m_key, key))
                    {
                        value = n.m_value;
                        return true;
                    }
                    n = n.m_next;
                }
    
                value = default(TValue);
                return false;
            }
            
            private bool TryAddInternal(TKey key, TValue value, bool updateIfExists, bool acquireLock, out TValue resultingValue)
            {
                while (true)
                {
                    int bucketNo, lockNo;
                    int hashcode;
    
                    Tables tables = m_tables;
                    IEqualityComparer<TKey> comparer = tables.m_comparer;
                    hashcode = comparer.GetHashCode(key);
                    GetBucketAndLockNo(hashcode, out bucketNo, out lockNo, tables.m_buckets.Length, tables.m_locks.Length);
    
                    bool resizeDesired = false;
                    bool lockTaken = false;
    
                    try
                    {
                        if (acquireLock)
                            Monitor.Enter(tables.m_locks[lockNo], ref lockTaken);
    
                        // If the table just got resized, we may not be holding the right lock, and must retry.
                        // This should be a rare occurence.
                        if (tables != m_tables)
                        {
                            continue;
                        }
    
                        // Try to find this key in the bucket
                        Node prev = null;
                        for (Node node = tables.m_buckets[bucketNo]; node != null; node = node.m_next)
                        {
                            Assert((prev == null && node == tables.m_buckets[bucketNo]) || prev.m_next == node);
                            if (comparer.Equals(node.m_key, key))
                            {
                                // The key was found in the dictionary. If updates are allowed, update the value for that key.
                                // We need to create a new node for the update, in order to support TValue types that cannot
                                // be written atomically, since lock-free reads may be happening concurrently.
                                if (updateIfExists)
                                {
                                    if (s_isValueWriteAtomic)
                                    {
                                        node.m_value = value;
                                    }
                                    else
                                    {
                                        Node newNode = new Node(node.m_key, value, hashcode, node.m_next);
                                        if (prev == null)
                                        {
                                            tables.m_buckets[bucketNo] = newNode;
                                        }
                                        else
                                        {
                                            prev.m_next = newNode;
                                        }
                                    }
                                    resultingValue = value;
                                }
                                else
                                {
                                    resultingValue = node.m_value;
                                }
                                return false;
                            }
                            prev = node;
    
                        }
    
                        // The key was not found in the bucket. Insert the key-value pair.
                        Volatile.Write<Node>(ref tables.m_buckets[bucketNo], new Node(key, value, hashcode, tables.m_buckets[bucketNo]));
                        checked
                        {
                            tables.m_countPerLock[lockNo]++;
                        }
    
                        if (tables.m_countPerLock[lockNo] > m_budget)
                        {
                            resizeDesired = true;
                        }
                    }
                    finally
                    {
                        if (lockTaken)
                            Monitor.Exit(tables.m_locks[lockNo]);
                    }
    
                    if (resizeDesired)
                    {
                        GrowTable(tables, tables.m_comparer, false, m_keyRehashCount);
                    }
    
                    resultingValue = value;
                    return true;
                }
            }
            public bool TryRemove(TKey key, out TValue value)
            {
                if (key == null) throw new ArgumentNullException("key");
    
                return TryRemoveInternal(key, out value, false, default(TValue));
            }
            
            private bool TryRemoveInternal(TKey key, out TValue value, bool matchValue, TValue oldValue)
            {
                while (true)
                {
                    Tables tables = m_tables;
    
                    IEqualityComparer<TKey> comparer = tables.m_comparer;
    
                    int bucketNo, lockNo;
                    GetBucketAndLockNo(comparer.GetHashCode(key), out bucketNo, out lockNo, tables.m_buckets.Length, tables.m_locks.Length);
    
                    lock (tables.m_locks[lockNo])
                    {
                        // If the table just got resized, we may not be holding the right lock, and must retry.
                        // This should be a rare occurence.
                        if (tables != m_tables)
                        {
                            continue;
                        }
    
                        Node prev = null;
                        for (Node curr = tables.m_buckets[bucketNo]; curr != null; curr = curr.m_next)
                        {
                            Assert((prev == null && curr == tables.m_buckets[bucketNo]) || prev.m_next == curr);
    
                            if (comparer.Equals(curr.m_key, key))
                            {
                                if (matchValue)
                                {
                                    bool valuesMatch = EqualityComparer<TValue>.Default.Equals(oldValue, curr.m_value);
                                    if (!valuesMatch)
                                    {
                                        value = default(TValue);
                                        return false;
                                    }
                                }
    
                                if (prev == null)
                                {
                                    Volatile.Write<Node>(ref tables.m_buckets[bucketNo], curr.m_next);
                                }
                                else
                                {
                                    prev.m_next = curr.m_next;
                                }
    
                                value = curr.m_value;
                                tables.m_countPerLock[lockNo]--;
                                return true;
                            }
                            prev = curr;
                        }
                    }
    
                    value = default(TValue);
                    return false;
                }
            }
            private void GrowTable(Tables tables, IEqualityComparer<TKey> newComparer, bool regenerateHashKeys, int rehashCount)
            {
                int locksAcquired = 0;
                try
                {
                    AcquireLocks(0, 1, ref locksAcquired);
    
                    if (regenerateHashKeys && rehashCount == m_keyRehashCount)
                    {
                        tables = m_tables;
                    }
                    else
                    {
                        if (tables != m_tables)
                        {
                            return;
                        }
                        long approxCount = 0;
                        for (int i = 0; i < tables.m_countPerLock.Length; i++)
                        {
                            approxCount += tables.m_countPerLock[i];
                        }
                        if (approxCount < tables.m_buckets.Length / 4)
                        {
                            m_budget = 2 * m_budget;
                            if (m_budget < 0)
                            {
                                m_budget = int.MaxValue;
                            }
    
                            return;
                        }
                    }
                    int newLength = 0;
                    bool maximizeTableSize = false;
                    try
                    {
                        checked
                        {
                            newLength = tables.m_buckets.Length * 2 + 1;
                            while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0)
                            {
                                newLength += 2;
                            }
    
                            Assert(newLength % 2 != 0);
    
                            if (newLength > Array.MaxArrayLength)
                            {
                                maximizeTableSize = true;
                            }
                        }
                    }
                    catch (OverflowException)
                    {
                        maximizeTableSize = true;
                    }
    
                    if (maximizeTableSize)
                    {
                        newLength = Array.MaxArrayLength;
                        m_budget = int.MaxValue;
                    }
    
                    // Now acquire all other locks for the table
                    AcquireLocks(1, tables.m_locks.Length, ref locksAcquired);
    
                    object[] newLocks = tables.m_locks;
    
                    // Add more locks
                    if (m_growLockArray && tables.m_locks.Length < MAX_LOCK_NUMBER)
                    {
                        newLocks = new object[tables.m_locks.Length * 2];
                        Array.Copy(tables.m_locks, newLocks, tables.m_locks.Length);
    
                        for (int i = tables.m_locks.Length; i < newLocks.Length; i++)
                        {
                            newLocks[i] = new object();
                        }
                    }
    
                    Node[] newBuckets = new Node[newLength];
                    int[] newCountPerLock = new int[newLocks.Length];
    
                    for (int i = 0; i < tables.m_buckets.Length; i++)
                    {
                        Node current = tables.m_buckets[i];
                        while (current != null)
                        {
                            Node next = current.m_next;
                            int newBucketNo, newLockNo;
                            int nodeHashCode = current.m_hashcode;
    
                            if (regenerateHashKeys)
                            {
                                // Recompute the hash from the key
                                nodeHashCode = newComparer.GetHashCode(current.m_key);
                            }
    
                            GetBucketAndLockNo(nodeHashCode, out newBucketNo, out newLockNo, newBuckets.Length, newLocks.Length);
    
                            newBuckets[newBucketNo] = new Node(current.m_key, current.m_value, nodeHashCode, newBuckets[newBucketNo]);
    
                            checked
                            {
                                newCountPerLock[newLockNo]++;
                            }
    
                            current = next;
                        }
                    }
    
                    // If this resize regenerated the hashkeys, increment the count
                    if (regenerateHashKeys)
                    {
                        // We use unchecked here because we don't want to throw an exception if 
                        // an overflow happens
                        unchecked
                        {
                            m_keyRehashCount++;
                        }
                    }
    
                    // Adjust the budget
                    m_budget = Math.Max(1, newBuckets.Length / newLocks.Length);
    
                    // Replace tables with the new versions
                    m_tables = new Tables(newBuckets, newLocks, newCountPerLock, newComparer);
                }
                finally
                {
                    // Release all locks that we took earlier
                    ReleaseLocks(0, locksAcquired);
                }
            }
            private void AcquireLocks(int fromInclusive, int toExclusive, ref int locksAcquired)
            {
                Assert(fromInclusive <= toExclusive);
                object[] locks = m_tables.m_locks;
    
                for (int i = fromInclusive; i < toExclusive; i++)
                {
                    bool lockTaken = false;
                    try
                    {
                       Monitor.Enter(locks[i], ref lockTaken);
                    }
                    finally
                    {
                        if (lockTaken)
                        {
                            locksAcquired++;
                        }
                    }
                }
            }
            private void GetBucketAndLockNo(int hashcode, out int bucketNo, out int lockNo, int bucketCount, int lockCount)
            {
                bucketNo = (hashcode & 0x7fffffff) % bucketCount;
                lockNo = bucketNo % lockCount;
                Assert(bucketNo >= 0 && bucketNo < bucketCount);
                Assert(lockNo >= 0 && lockNo < lockCount);
            }
            private static int DefaultConcurrencyLevel
            {
    
                get { return DEFAULT_CONCURRENCY_MULTIPLIER * PlatformHelper.ProcessorCount; }
            }
            private class Node
            {
                internal TKey m_key;
                internal TValue m_value;
                internal volatile Node m_next;
                internal int m_hashcode;
    
                internal Node(TKey key, TValue value, int hashcode, Node next)
                {
                    m_key = key;
                    m_value = value;
                    m_next = next;
                    m_hashcode = hashcode;
                }
            }
            
        }
        
        public static class Volatile
        {
            [ResourceExposure(ResourceScope.None)]
            [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
            [SecuritySafeCritical] //the intrinsic implementation of this method contains unverifiable code
            public static T Read<T>(ref T location) where T : class
            {
                var value = location;
                Thread.MemoryBarrier();
                return value;
            }
            
            [ResourceExposure(ResourceScope.None)]
            [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
            [SecuritySafeCritical] //the intrinsic implementation of this method contains unverifiable code
            public static void Write<T>(ref T location, T value) where T : class
            { 
                Thread.MemoryBarrier();
                location = value;
            }
        }

    Like Dictionary, the ConcurrentDictionary constructor has an int capacity parameter, which controls the size of the initial node array [Node[] buckets = new Node[capacity] and m_tables = new Tables(buckets, locks, countPerLock, comparer);]. The constructor also takes an extra int concurrencyLevel parameter, which controls the degree of concurrency [object[] locks = new object[concurrencyLevel]; for (int i = 0; i < locks.Length; i++) { locks[i] = new object(); }]. When capacity and concurrencyLevel are specified explicitly, the bool growLockArray parameter is false [m_growLockArray = growLockArray;], meaning that when the ConcurrentDictionary resizes, the object[] locks array is not grown. You can think of this as the lock granularity becoming coarser: where previously, say, four keys shared one lock object, eight keys may now map to the same lock object. The m_budget in m_budget = buckets.Length / locks.Length can be understood as how many buckets (keys) are guarded by a single lock object.
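
    To make the bucket/lock relationship concrete, here is a minimal standalone sketch that mirrors the GetBucketAndLockNo calculation; the 64 buckets / 8 locks numbers are just assumed example values, not anything mandated by the source:

    using System;

    class BucketLockDemo
    {
        // Mirrors the private GetBucketAndLockNo logic: bucket from the hash, lock from the bucket.
        static void GetBucketAndLockNo(int hashcode, int bucketCount, int lockCount,
                                       out int bucketNo, out int lockNo)
        {
            bucketNo = (hashcode & 0x7fffffff) % bucketCount;
            lockNo = bucketNo % lockCount;
        }

        static void Main()
        {
            // Assumed example: new ConcurrentDictionary<K,V>(concurrencyLevel: 8, capacity: 64)
            // gives 64 buckets guarded by 8 locks, so m_budget = 64 / 8 = 8 keys per lock.
            const int bucketCount = 64, lockCount = 8;

            GetBucketAndLockNo("hello".GetHashCode(), bucketCount, lockCount,
                               out int bucketNo, out int lockNo);
            Console.WriteLine($"key 'hello' -> bucket {bucketNo}, guarded by lock {lockNo}");
        }
    }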

    Now let's look at TryGetValue. This method is very simple because reads do not take a lock: it first computes the key's hash code, locates the corresponding hash bucket, and reads that bucket's head node [Node n = Volatile.Read<Node>(ref tables.m_buckets[bucketNo])]. A single bucket may hold several entries [while (n != null) { if (comparer.Equals(n.m_key, key)) { value = n.m_value; return true; } n = n.m_next; }], so each hash bucket is a singly linked list of Node objects.
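
    From the caller's side this is just the usual TryGetValue pattern; a small usage sketch (the key names and values are illustrative only):

    using System;
    using System.Collections.Concurrent;

    class TryGetValueDemo
    {
        static void Main()
        {
            var map = new ConcurrentDictionary<string, int>();
            map["apple"] = 3;

            // Lock-free read: hashes the key, picks the bucket, walks its Node list.
            if (map.TryGetValue("apple", out int count))
                Console.WriteLine($"apple -> {count}");

            // A missing key returns false and leaves 'missing' at default(int).
            Console.WriteLine(map.TryGetValue("pear", out int missing)); // False
        }
    }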

    Next, let's look at the more involved TryAddInternal method. It first determines the hash bucket from the key. Both adding and updating require locking, so it calls Monitor.Enter(tables.m_locks[lockNo], ref lockTaken); and releases the lock at the end with Monitor.Exit(tables.m_locks[lockNo]);. When adding an element, it writes the bucket head directly: Volatile.Write<Node>(ref tables.m_buckets[bucketNo], new Node(key, value, hashcode, tables.m_buckets[bucketNo]));. Note the Node constructor: tables.m_buckets[bucketNo] becomes the new node's m_next, so a newly added node is always the first node in the bucket's linked list. After the write, the counter for the corresponding lock object is incremented [tables.m_countPerLock[lockNo]++;], and if that counter exceeds the expected threshold, a resize is requested [if (tables.m_countPerLock[lockNo] > m_budget) { resizeDesired = true; }]. An update likewise starts by finding the matching node (i.e., the key already exists in that bucket). If the value can be written atomically, it is modified in place [if (s_isValueWriteAtomic) { node.m_value = value; }]; otherwise a cloned node replaces the original [Node newNode = new Node(node.m_key, value, hashcode, node.m_next); if (prev == null) { tables.m_buckets[bucketNo] = newNode; } else { prev.m_next = newNode; }]. If it is the first node in the bucket the replacement is simple; otherwise the previous node's m_next is updated.
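
    A short usage sketch showing both paths from the caller's perspective: TryAdd exercises the insert branch, and the indexer setter exercises the update branch (key names and values are only illustrative):

    using System;
    using System.Collections.Concurrent;

    class AddUpdateDemo
    {
        static void Main()
        {
            var map = new ConcurrentDictionary<string, long>();

            // Insert branch: the key is absent, so a new Node is prepended to its bucket's list.
            bool added = map.TryAdd("hits", 1);       // true

            // TryAdd on an existing key returns false and does not touch the stored value.
            bool addedAgain = map.TryAdd("hits", 99); // false

            // Update branch (indexer setter -> TryAddInternal with updateIfExists = true).
            // On a 32-bit process a long cannot be written atomically, so this takes the
            // "clone a replacement Node" path instead of assigning node.m_value in place.
            map["hits"] = 2;

            Console.WriteLine($"{added} {addedAgain} {map["hits"]}"); // True False 2
        }
    }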

    Next is the bucket-array resize, GrowTable. This method is fairly involved and I have not studied it in great depth. Since multiple threads may be running, thread safety has to be handled, which in plain terms means taking locks: first AcquireLocks(0, 1, ref locksAcquired). The bucket array grows to roughly twice its size [newLength = tables.m_buckets.Length * 2 + 1; while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0) { newLength += 2; }]; a small standalone sketch of this length computation follows the excerpt below. Before the actual resize, all the lock objects must be acquired [AcquireLocks(1, tables.m_locks.Length, ref locksAcquired);]. The resize first grows the lock array:

     if (m_growLockArray && tables.m_locks.Length < MAX_LOCK_NUMBER)
                    {
                        newLocks = new object[tables.m_locks.Length * 2];
                        Array.Copy(tables.m_locks, newLocks, tables.m_locks.Length);
    
                        for (int i = tables.m_locks.Length; i < newLocks.Length; i++)
                        {
                            newLocks[i] = new object();
                        }
                    }
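
    The new bucket-array length computation mentioned above can be replayed in isolation; a minimal sketch, assuming the default starting capacity of 31:

    using System;

    class GrowLengthDemo
    {
        // Replays GrowTable's sizing rule: roughly double, then bump by 2
        // until the length is not divisible by 3, 5, or 7.
        static int NextLength(int current)
        {
            int newLength = checked(current * 2 + 1);
            while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0)
                newLength += 2;
            return newLength;
        }

        static void Main()
        {
            int len = 31; // DEFAULT_CAPACITY
            for (int i = 0; i < 5; i++)
            {
                len = NextLength(len);
                Console.Write(len + " "); // 67 137 277 557 1117
            }
            Console.WriteLine();
        }
    }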

    Then comes the resize of the bucket array itself. Resizing can be understood as cloning each existing node into the new array, at the bucket position recomputed for the new length [newBuckets[newBucketNo] = new Node(current.m_key, current.m_value, nodeHashCode, newBuckets[newBucketNo]);]:

     Node[] newBuckets = new Node[newLength];
                    int[] newCountPerLock = new int[newLocks.Length];
    
                    for (int i = 0; i < tables.m_buckets.Length; i++)
                    {
                        Node current = tables.m_buckets[i];
                        while (current != null)
                        {
                            Node next = current.m_next;
                            int newBucketNo, newLockNo;
                            int nodeHashCode = current.m_hashcode;
    
                            if (regenerateHashKeys)
                            {
                                // Recompute the hash from the key
                                nodeHashCode = newComparer.GetHashCode(current.m_key);
                            }
    
                            GetBucketAndLockNo(nodeHashCode, out newBucketNo, out newLockNo, newBuckets.Length, newLocks.Length);
    
                            newBuckets[newBucketNo] = new Node(current.m_key, current.m_value, nodeHashCode, newBuckets[newBucketNo]);
    
                            checked
                            {
                                newCountPerLock[newLockNo]++;
                            }
    
                            current = next;
                        }
                    }

    Having covered resizing, let's finally look at removing an element. It first computes the bucket and lock number from the key [GetBucketAndLockNo(comparer.GetHashCode(key), out bucketNo, out lockNo, tables.m_buckets.Length, tables.m_locks.Length)], then locks the corresponding lock object [lock (tables.m_locks[lockNo])] and walks the linked list in that bucket looking for the key. If the match is the first node in the bucket, the bucket head is written directly [Volatile.Write<Node>(ref tables.m_buckets[bucketNo], curr.m_next)]; otherwise the list is patched with prev.m_next = curr.m_next. Finally, the counter for that lock object is decremented [tables.m_countPerLock[lockNo]--].
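
    Again from the caller's perspective, a brief usage sketch (key and value are illustrative):

    using System;
    using System.Collections.Concurrent;

    class RemoveDemo
    {
        static void Main()
        {
            var map = new ConcurrentDictionary<string, int>();
            map["session"] = 42;

            // Locks only the lock object guarding this key's bucket, unlinks the Node,
            // and decrements that lock's slot in m_countPerLock.
            if (map.TryRemove("session", out int removed))
                Console.WriteLine($"removed value {removed}");

            // Removing a key that is no longer present simply returns false.
            Console.WriteLine(map.TryRemove("session", out _)); // False
        }
    }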

    In an interview I was once asked about the Count property, so let's look at how Count is implemented:

    private void AcquireAllLocks(ref int locksAcquired)
    {
        // First, acquire lock 0
        AcquireLocks(0, 1, ref locksAcquired);
    
        // Now that we have lock 0, the m_locks array will not change (i.e., grow),
        // and so we can safely read m_locks.Length.
        AcquireLocks(1, m_tables.m_locks.Length, ref locksAcquired);
        Assert(locksAcquired == m_tables.m_locks.Length);
    }
    
     private void AcquireLocks(int fromInclusive, int toExclusive, ref int locksAcquired)
    {
        Assert(fromInclusive <= toExclusive);
        object[] locks = m_tables.m_locks;
    
        for (int i = fromInclusive; i < toExclusive; i++)
        {
            bool lockTaken = false;
            try
            {
               Monitor.Enter(locks[i], ref lockTaken);
            }
            finally
            {
                if (lockTaken)
                {
                    locksAcquired++;
                }
            }
        }
    }
    private int GetCountInternal()
    {
        int count = 0;
    
        // Compute the count, we allow overflow
        for (int i = 0; i < m_tables.m_countPerLock.Length; i++)
        {
            count += m_tables.m_countPerLock[i];
        }
    
        return count;
    }
            
    private void ReleaseLocks(int fromInclusive, int toExclusive)
    {
        Assert(fromInclusive <= toExclusive);
    
        for (int i = fromInclusive; i < toExclusive; i++)
        {
            Monitor.Exit(m_tables.m_locks[i]);
        }
    }

    As this shows, Count needs to acquire the lock on every object in m_tables.m_locks. The main reason ConcurrentDictionary outperforms lock + Dictionary is that the lock granularity is much finer, but Count has to acquire all of those locks, so it is comparatively expensive. Likewise the Keys and Values properties (GetKeys()/GetValues()), ToArray(), and IsEmpty behave the same way as Count and need to acquire all of the locks.
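
    If an approximate element count is needed on a hot path, one common workaround is to maintain it yourself with Interlocked instead of calling Count; this is my own sketch, not something from the ConcurrentDictionary source:

    using System.Collections.Concurrent;
    using System.Threading;

    // Sketch: track an approximate count alongside the dictionary so hot paths
    // can avoid Count, which has to acquire every lock in m_locks.
    class CountedMap<TKey, TValue>
    {
        private readonly ConcurrentDictionary<TKey, TValue> _map = new ConcurrentDictionary<TKey, TValue>();
        private int _approxCount;

        public bool TryAdd(TKey key, TValue value)
        {
            bool added = _map.TryAdd(key, value);
            if (added) Interlocked.Increment(ref _approxCount);
            return added;
        }

        public bool TryRemove(TKey key, out TValue value)
        {
            bool removed = _map.TryRemove(key, out value);
            if (removed) Interlocked.Decrement(ref _approxCount);
            return removed;
        }

        // Cheap read; may be momentarily stale relative to _map.Count.
        public int ApproximateCount => _approxCount;
    }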

  • Original article: https://www.cnblogs.com/majiang/p/7883721.html