The Road to LINQ (3): Extending LINQ

    This article looks at extending LINQ from three angles: extending query operators, defining custom query operators, and a simple simulation of LINQ to SQL.

    1. Extending query operators

    In practice, the extension methods in Enumerable or Queryable sometimes don't do what we need, so we have to write query operators of our own. Take the following example:

    var r = Enumerable.Range(1, 10).Zip(Enumerable.Range(11, 5), (s, d) => s + d);
    foreach (var i in r)
    {
        Console.WriteLine(i);
    }
    //output:
    //12
    //14
    //16
    //18
    //20
    

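    Enumerable.Zip applies a specified function to the corresponding elements of two sequences to produce a result sequence. Here it adds the elements at matching positions of [1...10] and [11...15] to form a new sequence. Internally it is implemented as follows: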

    static IEnumerable<TResult> ZipIterator<TFirst, TSecond, TResult>(IEnumerable<TFirst> first, IEnumerable<TSecond> second, Func<TFirst, TSecond, TResult> resultSelector) {
            using (IEnumerator<TFirst> e1 = first.GetEnumerator())
                using (IEnumerator<TSecond> e2 = second.GetEnumerator())
                    while (e1.MoveNext() && e2.MoveNext())
                        yield return resultSelector(e1.Current, e2.Current);
        }
    

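    Clearly, Zip takes the intersection by position. An element is produced only while both sequences still have an element at that position, so the output above is exactly what we should expect. Sometimes, however, we want the first sequence to be the primary one, so that the result always has the same length as the first sequence. Let's write a query operator for that, called LeftZip: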

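        /// <summary>
        /// Merges the right sequence into the left sequence using the specified result selector.
        /// The result always has as many elements as the left sequence; once the right
        /// sequence is exhausted, default(TRight) is passed to the selector.
        /// </summary>
        /// <typeparam name="TLeft">The element type of the left (primary) sequence.</typeparam>
        /// <typeparam name="TRight">The element type of the right sequence.</typeparam>
        /// <typeparam name="TResult">The element type of the result sequence.</typeparam>
        /// <param name="lefts">The primary sequence; it determines the result length.</param>
        /// <param name="rights">The sequence merged into <paramref name="lefts"/>.</param>
        /// <param name="resultSelector">Combines a left element with its matching right element.</param>
        /// <returns>A sequence with one element per element of <paramref name="lefts"/>.</returns>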
        public static IEnumerable<TResult> LeftZip<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
            IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
        {
            if(lefts == null)
                throw new ArgumentNullException("lefts");
            if(rights == null)
                throw new ArgumentNullException("rights");
            if (resultSelector == null)
                throw new ArgumentNullException("resultSelector");
            return LeftZipImpl(lefts, rights, resultSelector);
        }
        /// <summary>
        /// The Implementation of LeftZip
        /// </summary>
        /// <typeparam name="TLeft"></typeparam>
        /// <typeparam name="TRight"></typeparam>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="lefts"></param>
        /// <param name="rights"></param>
        /// <param name="resultSelector"></param>
        /// <returns></returns>
        private static IEnumerable<TResult> LeftZipImpl<TLeft, TRight, TResult>(this IEnumerable<TLeft> lefts,
            IEnumerable<TRight> rights, Func<TLeft, TRight, TResult> resultSelector)
        {
            using (var left = lefts.GetEnumerator())
            {
                using (var right = rights.GetEnumerator())
                {
                    while (left.MoveNext())
                    {
                        if (right.MoveNext())
                        {
                            yield return resultSelector(left.Current, right.Current);
                        }
                        else
                        {
                            do
                            {
                                yield return resultSelector(left.Current, default(TRight));
                            } while (left.MoveNext());
                            yield break;
                        }
                    }
                }
            }
        }
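    LeftZip is split into a public method and a private LeftZipImpl because of how iterator methods work. A method containing yield return runs none of its body until the first MoveNext() call, so argument checks placed inside it would only fire when the caller starts enumerating. The sketch below shows the difference with two hypothetical methods, LazyCheck and EagerCheck; neither is part of the article.

```csharp
using System;
using System.Collections.Generic;

public static class ValidationDemo
{
    // Hypothetical: validation inside the iterator body itself.
    // None of this runs until the caller starts enumerating.
    public static IEnumerable<T> LazyCheck<T>(IEnumerable<T> source)
    {
        if (source == null)
            throw new ArgumentNullException(nameof(source));
        foreach (var item in source)
            yield return item;
    }

    // Hypothetical: the LeftZip/LeftZipImpl pattern. The wrapper is an
    // ordinary method, so its checks run as soon as it is called.
    public static IEnumerable<T> EagerCheck<T>(IEnumerable<T> source)
    {
        if (source == null)
            throw new ArgumentNullException(nameof(source));
        return EagerCheckImpl(source);
    }

    private static IEnumerable<T> EagerCheckImpl<T>(IEnumerable<T> source)
    {
        foreach (var item in source)
            yield return item;
    }

    // Returns true if merely calling the method (without enumerating) throws.
    public static bool ThrowsOnCall(Func<IEnumerable<int>> call)
    {
        try
        {
            call();
            return false;
        }
        catch (ArgumentNullException)
        {
            return true;
        }
    }
}
```

    LazyCheck<int>(null) returns normally and fails only once enumerated. EagerCheck<int>(null) throws immediately, which is the behavior LeftZip gets by validating before it delegates to LeftZipImpl.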
    

    Calling LeftZip:

    var r = Enumerable.Range(1, 10).LeftZip(Enumerable.Range(11, 5), (s, d) => s + d);
    foreach (var i in r)
    {
        Console.WriteLine(i);
    }
    //output:
    //12
    //14
    //16
    //18
    //20
    //6
    //7
    //8
    //9
    //10
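    The trailing 6 through 10 appear only because default(int) is 0, so 6 + 0 = 6. For a reference type the padded right element would be null. With int, a genuine 0 in the right sequence cannot be told apart from padding. One way around this is an overload that takes an explicit fill value. The sketch below shows such an overload; the name LeftZipWithFill and its parameter order are assumptions, not part of the article.

```csharp
using System;
using System.Collections.Generic;

public static class LeftZipFillExtensions
{
    // Hypothetical variant of LeftZip: once `rights` is exhausted, `fill`
    // is passed to the selector instead of default(TRight).
    public static IEnumerable<TResult> LeftZipWithFill<TLeft, TRight, TResult>(
        this IEnumerable<TLeft> lefts, IEnumerable<TRight> rights, TRight fill,
        Func<TLeft, TRight, TResult> resultSelector)
    {
        if (lefts == null) throw new ArgumentNullException(nameof(lefts));
        if (rights == null) throw new ArgumentNullException(nameof(rights));
        if (resultSelector == null) throw new ArgumentNullException(nameof(resultSelector));
        return Impl(lefts, rights, fill, resultSelector);
    }

    private static IEnumerable<TResult> Impl<TLeft, TRight, TResult>(
        IEnumerable<TLeft> lefts, IEnumerable<TRight> rights, TRight fill,
        Func<TLeft, TRight, TResult> resultSelector)
    {
        using (var left = lefts.GetEnumerator())
        using (var right = rights.GetEnumerator())
        {
            bool rightAlive = true;
            while (left.MoveNext())
            {
                // Short-circuit: right.MoveNext() is not called again after it returns false.
                rightAlive = rightAlive && right.MoveNext();
                yield return resultSelector(left.Current, rightAlive ? right.Current : fill);
            }
        }
    }
}
```

    For example, new[] { 1, 2, 3 }.LeftZipWithFill(new[] { "a" }, "-", (n, s) => n + s) yields "1a", "2-", "3-".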
    

    2. Custom query operators

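    Earlier, when implementing enumerators, we saw a self-implemented form. You skip the IEnumerable and IEnumerator interfaces entirely, define one class with a GetEnumerator() method and another with a Current property and a MoveNext() method, and foreach can still iterate over it. We also know that a LINQ query expression is translated into a chain of method calls. Each standard query keyword becomes a call to the method of the same name with its first letter capitalized (where becomes Where, select becomes Select). So if we implement extension methods with the same names as the standard query operators, will ours be executed?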
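    As a reminder of that enumerator pattern: foreach only needs a GetEnumerator() method whose return type has a Current property and a bool MoveNext() method. No interface is required. The sketch below uses a hypothetical Countdown type.

```csharp
using System.Collections.Generic;

// Hypothetical type: implements neither IEnumerable nor IEnumerator,
// yet foreach accepts it because the required members exist.
public class Countdown
{
    private readonly int start;

    public Countdown(int start) { this.start = start; }

    public Enumerator GetEnumerator() => new Enumerator(start);

    public struct Enumerator
    {
        private int current;

        // Starts one above `start` so the first MoveNext() lands on `start`.
        public Enumerator(int start) { current = start + 1; }

        public int Current => current;

        // Counts down by one; returns false after 1 has been produced.
        public bool MoveNext() => --current > 0;
    }
}

public static class CountdownDemo
{
    public static List<int> Collect(int start)
    {
        var result = new List<int>();
        foreach (var n in new Countdown(start)) // binds to Countdown.GetEnumerator()
            result.Add(n);
        return result;
    }
}
```

    CountdownDemo.Collect(3) returns 3, 2, 1.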
    Let's try it: create a static class LinqExtensions and implement a Where extension method:

        /// <summary>
        /// Filters a sequence of values based on a predicate.
        /// </summary>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="source"></param>
        /// <param name="predicate"></param>
        /// <returns></returns>
        public static IEnumerable<TResult> Where<TResult>(this IEnumerable<TResult> source,
            Func<TResult, bool> predicate)
        {
            if (source == null)
                throw new ArgumentNullException("source");
            if (predicate == null)
                throw new ArgumentNullException("predicate");
            return WhereImpl(source, predicate);
        }
        /// <summary>
        /// The implementation of Where
        /// </summary>
        /// <typeparam name="TResult"></typeparam>
        /// <param name="source"></param>
        /// <param name="predicate"></param>
        /// <returns></returns>
        private static IEnumerable<TResult> WhereImpl<TResult>(this IEnumerable<TResult> source,
            Func<TResult, bool> predicate)
        {
            using (var e = source.GetEnumerator())
            {
                while (e.MoveNext())
                {
                    if (predicate(e.Current))
                        yield return e.Current;
                }
            }
        }
    

    The calling code:

    var r = from e in Enumerable.Range(1, 10)
            where e % 2 == 0
            select e;
    foreach (var i in r)
    {
        Console.WriteLine(i);
    }
    //output:
    //2
    //4
    //6
    //8
    //10
    

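    How can we tell that our Where extension method was called? Any of these will show it: stepping through in the debugger, selecting Where in Visual Studio and pressing F12 to go to its definition, or printing something inside the Where implementation. You may also wonder whether the method signature has to match the standard one. It does not. The compiler only rewrites the query syntactically into method calls and then binds them like any other call, so the method names matter but the signatures are never compared with the standard operators. You can even rename the Where extension method to Select and change the query as follows: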

    var r = from e in Enumerable.Range(1, 10)
            //where e % 2 == 0
            select e % 2 == 0;
    //output:
    //2
    //4
    //6
    //8
    //10
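    The self-contained sketch below demonstrates both points. It places LinqExtensions in the same namespace as the query (a hypothetical QueryDemo namespace). C# searches extension methods declared in the enclosing namespace before those imported with using System.Linq, and that is what lets the custom methods take precedence here. A static counter stands in for "print something inside Where".

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

namespace QueryDemo
{
    public static class LinqExtensions
    {
        // Incremented on every call so we can tell which Where was bound.
        public static int WhereCalls;

        public static IEnumerable<T> Where<T>(this IEnumerable<T> source, Func<T, bool> predicate)
        {
            WhereCalls++; // runs at call time: this wrapper is not an iterator
            return Filter(source, predicate);
        }

        // Named Select but filters: query syntax only cares about the method name.
        public static IEnumerable<T> Select<T>(this IEnumerable<T> source, Func<T, bool> predicate)
        {
            return Filter(source, predicate);
        }

        private static IEnumerable<T> Filter<T>(IEnumerable<T> source, Func<T, bool> predicate)
        {
            foreach (var item in source)
                if (predicate(item))
                    yield return item;
        }
    }

    public static class Demo
    {
        // Compiles to Enumerable.Range(1, 10).Where(e => e % 2 == 0), bound to the custom Where.
        public static List<int> WithWhere() =>
            (from e in Enumerable.Range(1, 10) where e % 2 == 0 select e).ToList();

        // Compiles to Enumerable.Range(1, 10).Select(e => e % 2 == 0), bound to the custom Select.
        public static List<int> WithSelect() =>
            (from e in Enumerable.Range(1, 10) select e % 2 == 0).ToList();
    }
}
```

    Both methods return 2, 4, 6, 8, 10. If LinqExtensions were moved to a namespace that is imported with a using directive alongside System.Linq, the two Where methods would compete in the same candidate set instead.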
    

    Finally, an example that combines this with a custom enumerator:

    public class Collection<T>
    {
        private T[] items;
    
        public Collection()
        {
    
        }
    
        public Collection(IEnumerable<T> collection)
        {
            if (collection == null)
                throw new ArgumentNullException("collection");
            // Materialize once: calling Count() and ToArray() separately would
            // enumerate a lazy source (such as WhereImpl's iterator) several times.
            items = collection.ToArray();
        }
    
        public static implicit operator Collection<T>(T[] arr)
        {
            Collection<T> collection = new Collection<T>();
            collection.items = new T[arr.Length];
            Array.Copy(arr, collection.items, arr.Length);
            return collection;
        }
    
        public ItemEnumerator GetEnumerator()
        {
            return new ItemEnumerator(items);
        }
    
        #region Item Enumerator
        public class ItemEnumerator : IDisposable
        {
            private T[] items;
            private int index = -1;
    
            public ItemEnumerator(T[] arr)
            {
                this.items = arr;
            }
            /// <summary>
            /// The Current property
            /// </summary>
            public T Current
            {
                get
                {
                    if (index < 0 || index > items.Length - 1)
                        throw new InvalidOperationException();
                    return items[index];
                }
            }
            /// <summary>
            /// The MoveNext method
            /// </summary>
            /// <returns></returns>
            public bool MoveNext()
            {
                if (index < items.Length - 1)
                {
                    index++;
                    return true;
                }
                else
                {
                    return false;
                }
            }
    
            public void Reset()
            {
                index = -1;
            }
        #region IDisposable members
    
            public void Dispose()
            {
                index = -1;
            }
    
            #endregion
        }
        #endregion
    }
    
    public static class EnumerableExtensions
    {
        public static Collection<T> Where<T>(this Collection<T> source, Func<T, bool> predicate)
        {
            if (source == null)
                throw new ArgumentNullException("source");
            if (predicate == null)
                throw new ArgumentNullException("predicate");
            return WhereImpl(source, predicate).ToCollection();
        }
    
        private static IEnumerable<T> WhereImpl<T>(this Collection<T> source, Func<T, bool> predicate)
        {
            using (var e = source.GetEnumerator())
            {
                while (e.MoveNext())
                {
                    if (predicate(e.Current))
                    {
                        yield return e.Current;
                    }
                }
            }
        }
    
        public static Collection<TResult> Select<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
        {
            if (source == null)
                throw new ArgumentNullException("source");
            if (selector == null)
                throw new ArgumentNullException("selector");
            return SelectImpl(source, selector).ToCollection();
        }
    
        private static IEnumerable<TResult> SelectImpl<T, TResult>(this Collection<T> source, Func<T, TResult> selector)
        {
            using (var e = source.GetEnumerator())
            {
                while (e.MoveNext())
                {
                    yield return selector(e.Current);
                }
            }
        }
    
        public static Collection<T> ToCollection<T>(this IEnumerable<T> source)
        {
            if (source == null)
                throw new ArgumentNullException("source");
            return new Collection<T>(source);
        }
    }
    

    There are two classes: one serves as the data source and the other provides the extension methods. They are used as follows:

    Collection<int> collection = new int[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
    var r = from c in collection
            where c % 2 == 0
            select c;
    foreach (var i in r)
    {
        Console.WriteLine(i);
    }
    
    //output:
    //2
    //4
    //6
    //8
    //10
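    When a select clause merely returns the range variable after another clause, the C# compiler drops it. So this query compiles to collection.Where(c => c % 2 == 0), and the Select extension method in EnumerableExtensions is never called. A projecting select such as select c * 10 would call it. A lone from x in e select x is kept as e.Select(x => x), so a query never hands back the source object itself. The sketch below checks all three cases using a hypothetical Box<T> type whose Where and Select only record that they were called.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical type for illustration: query syntax only needs methods named
// Where and Select; Box<T> records the calls instead of filtering anything.
public class Box<T>
{
    public T Value { get; }
    public List<string> Calls { get; }

    public Box(T value, List<string> calls)
    {
        Value = value;
        Calls = calls;
    }

    public Box<T> Where(Func<T, bool> predicate)
    {
        Calls.Add("Where");
        return this; // no real filtering: only the call sequence matters here
    }

    public Box<TResult> Select<TResult>(Func<T, TResult> selector)
    {
        Calls.Add("Select");
        return new Box<TResult>(selector(Value), Calls);
    }
}

public static class QueryTranslationDemo
{
    // Compiles to box.Where(x => x % 2 == 0): the identity select is dropped.
    public static List<string> WhereThenIdentitySelect()
    {
        var box = new Box<int>(4, new List<string>());
        return (from x in box where x % 2 == 0 select x).Calls;
    }

    // Compiles to box.Where(x => x % 2 == 0).Select(x => x * 10).
    public static List<string> WhereThenProjection()
    {
        var box = new Box<int>(4, new List<string>());
        return (from x in box where x % 2 == 0 select x * 10).Calls;
    }

    // A degenerate query on its own is kept as box.Select(x => x).
    public static List<string> IdentitySelectOnly()
    {
        var box = new Box<int>(4, new List<string>());
        return (from x in box select x).Calls;
    }
}
```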
    

    3. A simple simulation of LINQ to SQL

    In The Road to LINQ (2) we briefly introduced how LINQ to SQL works. Here we build a simple imitation of LINQ to SQL to understand its principles further.

    First, the data source: create a Query class that implements the IQueryable<T> interface:

    public class Query<T> : IQueryable<T>
    {
    #region Fields
    
        private QueryProvider provider;
        private Expression expression;
    
        #endregion
    
    #region Properties
    
    #endregion
    
    #region Constructors
    
        public Query(QueryProvider provider)
        {
            if (provider == null)
                throw new ArgumentNullException("provider");
            this.provider = provider;
            this.expression = Expression.Constant(this);
        }
    
        public Query(QueryProvider provider, Expression expression)
        {
            if (provider == null)
                throw new ArgumentNullException("provider");
            if (expression == null)
                throw new ArgumentNullException("expression");
            if (!typeof(IQueryable<T>).IsAssignableFrom(expression.Type))
                throw new ArgumentOutOfRangeException("expression");
            this.provider = provider;
            this.expression = expression;
        }
        #endregion
    
    #region Methods
    
        public IEnumerator<T> GetEnumerator()
        {
            return ((IEnumerable<T>) this.provider.Execute(this.expression)).GetEnumerator();
        }
    
        #endregion
    
    #region IEnumerable<T> members
    
        IEnumerator<T> IEnumerable<T>.GetEnumerator()
        {
            return this.GetEnumerator();
        }
    
        #endregion
    
    #region IEnumerable members
    
        System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
        {
            return this.GetEnumerator();
        }
    
        #endregion
    
    #region IQueryable members
    
        Type IQueryable.ElementType
        {
            get { return typeof(T); }
        }
    
        Expression IQueryable.Expression
        {
            get { return this.expression; }
        }
    
        IQueryProvider IQueryable.Provider
        {
            get { return this.provider; }
        }
    
        #endregion
    }
    

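    This class is simple: it implements the interface and stores the values passed to its constructor.
    Next, the provider: create a QueryProvider class that implements the IQueryProvider interface: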

    public class QueryProvider:IQueryProvider
    {
    #region Fields
    
        private IDbConnection dbConnection;
    
        #endregion
    
    #region Properties
    
    #endregion
    
    #region Constructors
        public QueryProvider(IDbConnection dbConnection)
        {
            this.dbConnection = dbConnection;
        }
        #endregion
    
    #region Methods
    
    #endregion
    
    #region IQueryProvider members
    
        public IQueryable<TElement> CreateQuery<TElement>(System.Linq.Expressions.Expression expression)
        {
            return new Query<TElement>(this, expression);
        }
    
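    public IQueryable CreateQuery(System.Linq.Expressions.Expression expression)
    {
        // expression.Type is IQueryable<TElement> (or Query<TElement>), not TElement,
        // so the element type has to be extracted from its IEnumerable<> interface.
        var elementType = expression.Type.GetInterfaces()
            .Concat(new[] { expression.Type })
            .First(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>))
            .GetGenericArguments()[0];
        try
        {
            return (IQueryable) Activator.CreateInstance(typeof (Query<>).MakeGenericType(elementType), this, expression);
        }
        catch (TargetInvocationException e)
        {
            throw e.InnerException;
        }
    }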
    
        TResult IQueryProvider.Execute<TResult>(System.Linq.Expressions.Expression expression)
        {
            return (TResult) this.Execute(expression);
        }
    
        object IQueryProvider.Execute(System.Linq.Expressions.Expression expression)
        {
            return this.Execute(expression);
        }
    
        public virtual object Execute(Expression expression)
        {
            if(expression == null)
                throw new ArgumentNullException("expression");
            return ExecuteImpl(expression);
        }
    
    private IEnumerable ExecuteImpl(Expression expression)
    {
        // Simplification: the entity type is hard-coded to Product, and columns are
        // read by ordinal, assuming the table's column order is ID, Name, Type.
        //var type = expression.Type;
        //var entityType = type.GetGenericArguments()[0];
        List<Product> products = new List<Product>();
        QueryTranslator queryTranslator = new QueryTranslator();
        var cmdText = queryTranslator.Translate(expression);
        using (IDbCommand cmd = dbConnection.CreateCommand())
        {
            cmd.CommandText = cmdText;
            using (IDataReader dataReader = cmd.ExecuteReader())
            {
                while (dataReader.Read())
                {
                    Product product = new Product();
                    product.ID = dataReader.GetInt32(0);
                    product.Name = dataReader.GetString(1);
                    product.Type = dataReader.GetInt32(2);
                    products.Add(product);
                }
            }
        }
        return products;
    }
        #endregion
    }
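    To see when each IQueryProvider member is reached, here is a sketch with a hypothetical RecordingProvider and RecordingQuery<T> that execute nothing and only log calls. Composable operators such as Where and Select build a new MethodCallExpression and pass it to CreateQuery<TElement>. Operators that return a single value, such as Count, call Execute<TResult> right away. With the provider above, context.Products.Count() would therefore reach Execute and then fail with the translator's NotSupportedException, because only Where is handled.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical provider: records which IQueryProvider member was invoked
// and for which Queryable method, without executing anything.
public class RecordingProvider : IQueryProvider
{
    public List<string> Log { get; } = new List<string>();

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        Log.Add("CreateQuery:" + ((MethodCallExpression)expression).Method.Name);
        return new RecordingQuery<TElement>(this, expression);
    }

    public IQueryable CreateQuery(Expression expression) =>
        throw new NotSupportedException();

    public TResult Execute<TResult>(Expression expression)
    {
        Log.Add("Execute:" + ((MethodCallExpression)expression).Method.Name);
        return default(TResult); // a real provider would run the query here
    }

    public object Execute(Expression expression) =>
        throw new NotSupportedException();
}

public class RecordingQuery<T> : IQueryable<T>
{
    public RecordingQuery(RecordingProvider provider, Expression expression = null)
    {
        Provider = provider;
        // Same convention as Query<T>: the root is a constant wrapping the query itself.
        Expression = expression ?? Expression.Constant(this);
    }

    public Type ElementType => typeof(T);
    public Expression Expression { get; }
    public IQueryProvider Provider { get; }

    public IEnumerator<T> GetEnumerator() =>
        throw new NotSupportedException("enumeration is not part of this sketch");

    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
```

    Running source.Where(x => x > 0).Select(x => x * 2).Count() against a RecordingQuery<int> logs CreateQuery:Where, CreateQuery:Select and then Execute:Count.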
    

    Next, the query translator: create a QueryTranslator class that derives from the abstract ExpressionVisitor class:

    public class QueryTranslator:ExpressionVisitor
    {
    #region Fields
    
        private StringBuilder sb;
    
        #endregion
    
    #region Properties
    
    #endregion
    
    #region Constructors
        public QueryTranslator()
        {
    
        }
    
        #endregion
    
    #region Methods
    
        public string Translate(Expression expression)
        {
            this.sb = new StringBuilder();
            this.Visit(expression);
            return this.sb.ToString();
        }
    
        private static Expression StripQuotes(Expression e)
        {
            while (e.NodeType == ExpressionType.Quote)
            {
                e = ((UnaryExpression) e).Operand;
            }
            return e;
        }
    
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (node.Method.DeclaringType == typeof (Queryable) &&
                node.Method.Name == "Where")
            {
                sb.Append("SELECT * FROM (");
                this.Visit(node.Arguments[0]);
                sb.Append(") AS T WHERE ");
                LambdaExpression lambda = (LambdaExpression) StripQuotes(node.Arguments[1]);
                this.Visit(lambda.Body);
                return node;
            }
            throw new NotSupportedException(string.Format("The Method '{0}' is not supported", node.Method.Name));
        }
    
        protected override Expression VisitBinary(BinaryExpression node)
        {
            sb.Append("(");
            this.Visit(node.Left);
            switch (node.NodeType)
            {
                case ExpressionType.Equal:
                    sb.Append(" = ");
                    break;
                case ExpressionType.NotEqual:
                    sb.Append(" <> ");
                    break;
                case ExpressionType.GreaterThan:
                    sb.Append(" > ");
                    break;
                case ExpressionType.GreaterThanOrEqual:
                    sb.Append(" >= ");
                    break;
                case ExpressionType.LessThan:
                    sb.Append(" < ");
                    break;
                case ExpressionType.LessThanOrEqual:
                    sb.Append(" <= ");
                    break;
                default:
                    throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", node.NodeType));
            }
            this.Visit(node.Right);
            sb.Append(")");
            return node;
        }
    
        protected override Expression VisitConstant(ConstantExpression node)
        {
            IQueryable q = node.Value as IQueryable;
            if (q != null)
            {
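            var metaTable = DataContext.MetaTables.FirstOrDefault(f => f.Type == q.ElementType);
            if (metaTable == null)
                throw new NotSupportedException(string.Format("No table mapping for '{0}'", q.ElementType));
            sb.Append("SELECT * FROM ");
            sb.Append(metaTable.TableName);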
                return node;
            }
            else if(node.Value == null)
            {
                sb.Append("NULL");
            }
            else
            {
                switch (Type.GetTypeCode(node.Value.GetType()))
                {
                    case TypeCode.Boolean:
                        sb.Append(((bool) node.Value) ? 1 : 0);
                        break;
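                case TypeCode.String:
                    // Double embedded single quotes so the SQL literal stays well-formed.
                    sb.AppendFormat("'{0}'", ((string) node.Value).Replace("'", "''"));
                    break;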
                    case TypeCode.Object:
                        throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", node.Value));
                    default:
                        sb.Append(node.Value);
                        break;
                }
            }
            return node;
        }
    
        protected override Expression VisitMember(MemberExpression node)
        {
            if (node.Expression != null && node.Expression.NodeType == ExpressionType.Parameter)
            {
                sb.Append(node.Member.Name);
                return node;
            }
            throw new NotSupportedException(string.Format("The member '{0}' is not supported", node.Member.Name));
        }
    
        #endregion
    }
    

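    The class overrides the relevant Visit methods to walk the expression tree with the Visitor pattern. It appends a SQL fragment for each node type it supports and throws NotSupportedException for everything else.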
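    The tree this visitor walks can be built without a database by using the built-in LINQ to Objects provider (AsQueryable). The root is a MethodCallExpression for Queryable.Where. Its first argument is a ConstantExpression holding the source IQueryable, and its second is the predicate wrapped in a Quote node, which is why StripQuotes is needed. The sketch also shows a limitation. A captured local variable appears as a member access on a compiler-generated closure constant, not on the lambda parameter, so the VisitMember above would throw NotSupportedException for it. Compiling such a subtree with Expression.Lambda(...).Compile() is one common workaround. The Evaluate helper below is an assumption and not part of the article.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class ExpressionShapeDemo
{
    public static MethodCallExpression WhereCall(IQueryable<int> query) =>
        (MethodCallExpression)query.Expression;

    // Removes Quote wrappers, like QueryTranslator.StripQuotes.
    public static LambdaExpression Predicate(MethodCallExpression whereCall)
    {
        Expression e = whereCall.Arguments[1];
        while (e.NodeType == ExpressionType.Quote)
            e = ((UnaryExpression)e).Operand;
        return (LambdaExpression)e;
    }

    // Hypothetical helper: evaluates a subtree that does not depend on the
    // lambda parameter (e.g. a captured local) by compiling it into a delegate.
    public static object Evaluate(Expression e) =>
        Expression.Lambda(e).Compile().DynamicInvoke();
}
```

    For source.Where(x => x > threshold) with int threshold = 1, the predicate body's right side is a MemberExpression whose Expression is a ConstantExpression (the closure), and Evaluate on it returns 1.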
    Finally, the DataContext implementation:

    public class DataContext : IDisposable
    {
    #region Fields
    
        private IDbConnection dbConnection;
        private static List<MetaTable> metaTables; 
        #endregion
    
    #region Properties
    
        public TextWriter Log { get; set; }
    
        public IDbConnection DbConnection
        {
            get { return this.dbConnection; }
        }
    
        public static List<MetaTable> MetaTables
        {
            get { return metaTables; }
        }
        #endregion
    
    #region Constructors
    public DataContext(string connString)
    {
        if (connString == null)
            throw new ArgumentNullException("connString");
            dbConnection = new SqlConnection(connString);
            dbConnection.Open();
            InitTables();
        }
        #endregion
    
    #region Methods
    
        private void InitTables()
        {
            metaTables = new List<MetaTable>();
            var props = this.GetType().GetProperties(BindingFlags.Public | BindingFlags.Instance);
            foreach (var prop in props)
            {
                var propType = prop.PropertyType;
                if (propType.IsGenericType && propType.GetGenericTypeDefinition() == typeof (Query<>))
                {
                    var entityType = propType.GetGenericArguments()[0];
                    var entityAttr = entityType.GetCustomAttribute<MappingAttribute>(true);
                    if (entityAttr != null)
                    {
                        var metaTable = new MetaTable();
                        metaTable.Type = entityType;
                        metaTable.TableName = entityAttr.Name;
                        metaTable.MappingAttribute = entityAttr;
                        var columnProps = entityType.GetProperties(BindingFlags.Public | BindingFlags.Instance);
                        foreach (var columnProp in columnProps)
                        {
                            var columnPropAttr = columnProp.GetCustomAttribute<MappingAttribute>(true);
                            if (columnPropAttr != null)
                            {
                                MetaColumn metaColumn = new MetaColumn();
                                metaColumn.MappingAttribute = columnPropAttr;
                                metaColumn.ColumnName = columnPropAttr.Name;
                                metaColumn.PropertyInfo = columnProp;
                                metaTable.MetaColumns.Add(metaColumn);
                            }
                        }
                        metaTables.Add(metaTable);
                    }
                }
            }
        }
        #endregion
    
    #region IDisposable members
    
        protected virtual void Dispose(bool disposing)
        {
            if (!disposing) return;
            if (dbConnection != null)
                dbConnection.Close();
        }
    
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }
    
        #endregion
    }
    [Database(Name = "IT_Company")]
    public class QueryDataContext : DataContext
    {
        public QueryDataContext(string connString)
            : base(connString)
        {
            QueryProvider provider = new QueryProvider(DbConnection);
            Products = new Query<Product>(provider);
        }
        public Query<Product> Products
        {
            get;
            set;
        }
    }
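    The article does not show Product, MappingAttribute, MetaTable or MetaColumn. Below are hypothetical minimal definitions consistent with how the code above uses them. The table name "Products" and the exact member types are assumptions, and the Database attribute on QueryDataContext is not covered. MetaTable must initialize MetaColumns itself, since InitTables calls metaTable.MetaColumns.Add(...) without creating the list.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical: one attribute used both for the table (on the class)
// and for the columns (on properties); Name holds the database name.
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Property, Inherited = true)]
public class MappingAttribute : Attribute
{
    public string Name { get; set; }
}

public class MetaColumn
{
    public MappingAttribute MappingAttribute { get; set; }
    public string ColumnName { get; set; }
    public PropertyInfo PropertyInfo { get; set; }
}

public class MetaTable
{
    public Type Type { get; set; }
    public string TableName { get; set; }
    public MappingAttribute MappingAttribute { get; set; }
    // Must be initialized: InitTables adds to it without creating it.
    public List<MetaColumn> MetaColumns { get; } = new List<MetaColumn>();
}

// Hypothetical entity. ExecuteImpl reads columns by ordinal, so the
// database table is assumed to list ID, Name, Type in that order.
[Mapping(Name = "Products")]
public class Product
{
    [Mapping(Name = "ID")]
    public int ID { get; set; }

    [Mapping(Name = "Name")]
    public string Name { get; set; }

    [Mapping(Name = "Type")]
    public int Type { get; set; }
}
```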
    

    Usage:

    class Program
    {
        private static readonly string connString =
            "Data Source=.;Initial Catalog=IT_Company;Persist Security Info=True;User ID=sa;Password=123456";
        static void Main(string[] args)
        {
            using (var context = new QueryDataContext(connString))
            {
                var query = from product in context.Products
                    where product.Type == 1
                    select product;
                foreach (var product in query)
                {
                    Console.WriteLine(product.Name);
                }
                Console.ReadKey();
            }
        }
    }
    //output:
    //MG500
    //MG1000
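    The translation step can also be checked without a SQL Server instance. The sketch below is a condensed, hypothetical MiniTranslator that follows the same rules as QueryTranslator: only Where, the six comparison operators, members of the lambda parameter, and constants. Unlike QueryTranslator, it takes the table name as an argument instead of reading DataContext.MetaTables, and it runs against an in-memory AsQueryable() source with a stand-in Item class. Assuming the product table is mapped to the name Products, a query like the one above would come out as SELECT * FROM (SELECT * FROM Products) AS T WHERE (Type = 1).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Text;

// Hypothetical condensed version of QueryTranslator for experimentation.
public class MiniTranslator : ExpressionVisitor
{
    private readonly StringBuilder sb = new StringBuilder();
    private readonly string tableName;

    private MiniTranslator(string tableName) { this.tableName = tableName; }

    public static string Translate(Expression expression, string tableName)
    {
        var translator = new MiniTranslator(tableName);
        translator.Visit(expression);
        return translator.sb.ToString();
    }

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.DeclaringType != typeof(Queryable) || node.Method.Name != "Where")
            throw new NotSupportedException(node.Method.Name);
        sb.Append("SELECT * FROM (");
        Visit(node.Arguments[0]);                 // the source: a constant IQueryable
        sb.Append(") AS T WHERE ");
        Expression predicate = node.Arguments[1];
        while (predicate.NodeType == ExpressionType.Quote)
            predicate = ((UnaryExpression)predicate).Operand;
        Visit(((LambdaExpression)predicate).Body);
        return node;
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        string op;
        switch (node.NodeType)
        {
            case ExpressionType.Equal: op = " = "; break;
            case ExpressionType.NotEqual: op = " <> "; break;
            case ExpressionType.GreaterThan: op = " > "; break;
            case ExpressionType.GreaterThanOrEqual: op = " >= "; break;
            case ExpressionType.LessThan: op = " < "; break;
            case ExpressionType.LessThanOrEqual: op = " <= "; break;
            default: throw new NotSupportedException(node.NodeType.ToString());
        }
        sb.Append("(");
        Visit(node.Left);
        sb.Append(op);
        Visit(node.Right);
        sb.Append(")");
        return node;
    }

    protected override Expression VisitConstant(ConstantExpression node)
    {
        if (node.Value is IQueryable)
            sb.Append("SELECT * FROM ").Append(tableName);
        else if (node.Value == null)
            sb.Append("NULL");
        else if (node.Value is bool b)
            sb.Append(b ? 1 : 0);
        else if (node.Value is string s)
            sb.Append('\'').Append(s.Replace("'", "''")).Append('\'');
        else
            sb.Append(node.Value);
        return node;
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression?.NodeType != ExpressionType.Parameter)
            throw new NotSupportedException(node.Member.Name);
        sb.Append(node.Member.Name);
        return node;
    }
}

// Hypothetical in-memory stand-in for the Product entity.
public class Item
{
    public int Type { get; set; }
    public string Name { get; set; }
}
```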
Original article: https://www.cnblogs.com/jellochen/p/the-extension-of-linq.html