  • Building Your Own ORM Framework, Part 2: The Secrets of IQueryable

    The title of my previous post was carelessly chosen and drew some criticism, so I've learned my lesson this time.

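    In the previous post I covered how to parse an Expression into the corresponding SQL statement, some IQueryable concepts, and the idea behind the framework we're building, but I never tied these pieces together and put them to use. This post gets even more hardcore: it shows how IQueryable implements deferred loading (deferred execution) in real use and, building on last time's Expression parsing, we'll implement our own "deferred loading".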

    If you're not yet familiar with parsing Expressions or with the basic concepts of IQueryable, read my previous post first.

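    Let's start with some groundwork. We define an IDataBase interface, which could hold query, delete, update and insert methods. To save time we'll only define a query method and a delete method, plus one method that returns an IQueryable<T> instance.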

       public interface IDataBase
        {
            List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere);
            int Remove<T>(Expression<Func<T, bool>> lambdawhere);
            IQueryable<T> Source<T>();
        }
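    To see what this contract asks for without a database, here is a minimal sketch of a hypothetical InMemoryDataBase (not part of the framework). It keeps entities in lists, compiles the lambda for FindAs and Remove, and uses LINQ to Objects' AsQueryable() as the IQueryable<T> source in place of the SqlQuery class built below.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public interface IDataBase
{
    List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere);
    int Remove<T>(Expression<Func<T, bool>> lambdawhere);
    IQueryable<T> Source<T>();
}

// Hypothetical in-memory implementation: one List<object> "table" per entity type.
public class InMemoryDataBase : IDataBase
{
    private readonly Dictionary<Type, List<object>> _tables = new Dictionary<Type, List<object>>();

    public void Add<T>(T entity)
    {
        List<object> table;
        if (!_tables.TryGetValue(typeof(T), out table))
        {
            table = new List<object>();
            _tables[typeof(T)] = table;
        }
        table.Add(entity);
    }

    private List<object> Table<T>()
    {
        List<object> table;
        return _tables.TryGetValue(typeof(T), out table) ? table : new List<object>();
    }

    public List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere)
    {
        // Compile the expression tree into a delegate and filter in memory.
        Func<T, bool> predicate = lambdawhere == null ? (x => true) : lambdawhere.Compile();
        return Table<T>().Cast<T>().Where(predicate).ToList();
    }

    public int Remove<T>(Expression<Func<T, bool>> lambdawhere)
    {
        Func<T, bool> predicate = lambdawhere.Compile();
        return Table<T>().RemoveAll(o => predicate((T)o));
    }

    public IQueryable<T> Source<T>()
    {
        // LINQ to Objects' EnumerableQuery<T> stands in for SqlQuery<T> here.
        return Table<T>().Cast<T>().AsQueryable();
    }
}
```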

    Next, add a DBSql class that implements the IDataBase interface above. This class is responsible for operations against a SQL database.

     public class DBSql : IDataBase
        {
            public List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere)
            {
                throw new NotImplementedException();
            }
    
            public int Remove<T>(Expression<Func<T, bool>> lambdawhere)
            {
                throw new NotImplementedException();
            }
    
            public IQueryable<T> Source<T>()
            {
                throw new NotImplementedException();
            }
        }

    IQueryable<T>

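    A reader's comment on the previous post explained IQueryable very well: "IQueryable only stores the conditions and does not run immediately, which is what makes deferred loading possible." But how does it store the conditions, and how does the deferred loading work?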
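    The built-in LINQ to Objects provider behaves the same way, which makes the idea easy to observe. In this sketch (DeferredDemo and IsLarge are illustrative names), Where only records the predicate in an expression tree; the predicate actually runs when the query is enumerated.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class DeferredDemo
{
    // Counts how many times the predicate has actually been evaluated.
    public static int Evaluations;

    public static bool IsLarge(int n)
    {
        Evaluations++;
        return n > 1;
    }

    public static IQueryable<int> BuildQuery()
    {
        IQueryable<int> source = new List<int> { 1, 2, 3 }.AsQueryable();
        // Where only stores the condition in the query's expression tree; IsLarge is not called here.
        return source.Where(n => IsLarge(n));
    }
}
```

    Enumerating the same query a second time would call IsLarge three more times: a deferred query is re-run on every enumeration rather than cached.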

    To supply the object that public IQueryable<T> Source<T>() returns, we create a SqlQuery class that implements IQueryable<T>.

       public class SqlQuery<T> : IQueryable<T>
        {
            public IEnumerator<T> GetEnumerator()
            {
                throw new NotImplementedException();
            }
    
            System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
            {
                throw new NotImplementedException();
            }
    
            public Type ElementType
            {
                get { throw new NotImplementedException(); }
            }
    
            public Expression Expression
            {
                get { throw new NotImplementedException(); }
            }
    
            public IQueryProvider Provider
            {
                get { throw new NotImplementedException(); }
            }
        }

    This should all look familiar, right?

    GetEnumerator() comes from IEnumerable<T>; it is what lets us use foreach. There is a generic and a non-generic version, hence two methods.

    ElementType returns the type of the elements the query produces when it is executed, i.e. typeof(T).

    Expression stores the query conditions, in the form of an expression tree.
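    As a point of reference, the IQueryable<T> that LINQ to Objects returns from AsQueryable() follows these conventions: ElementType is the element type T, and a root query's Expression is a ConstantExpression whose Value is the query object itself. RootQueryDemo below is only an illustrative helper name.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class RootQueryDemo
{
    // Returns a root (unfiltered) query over an in-memory array.
    public static IQueryable<string> Root()
    {
        return new[] { "a", "b" }.AsQueryable();
    }

    // True when the query is a root query whose expression wraps the query object itself.
    public static bool IsSelfConstant(IQueryable query)
    {
        var constant = query.Expression as ConstantExpression;
        return constant != null && ReferenceEquals(constant.Value, query);
    }
}
```

    The SqlQuery<T> constructor later in this post uses the same Expression.Constant(this) idea for its root expression.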

    IQueryProvider, literally "query provider", is responsible for creating queries and executing them. We write a SqlProvider class to implement it.

      public class SqlProvider<T> : IQueryProvider
        {
    
            public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
            {
                throw new NotImplementedException();
            }
    
            public IQueryable CreateQuery(Expression expression)
            {
                throw new NotImplementedException();
            }
    
            public TResult Execute<TResult>(Expression expression)
            {
                throw new NotImplementedException();
            }
    
            public object Execute(Expression expression)
            {
                throw new NotImplementedException();
            }
        }
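    The split between the two method pairs can be observed with a toy provider. In this sketch, RecordingQuery<T> and RecordingProvider are hypothetical classes that only log which method LINQ calls: Queryable.Where goes through CreateQuery and runs nothing, while Queryable.Count (like enumeration) goes through Execute. A real provider would translate and run the expression inside Execute instead of returning default(TResult).

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical query type: it only stores an expression and a provider.
public class RecordingQuery<T> : IQueryable<T>
{
    public RecordingQuery(RecordingProvider provider, Expression expression = null)
    {
        Provider = provider;
        // A root query wraps itself in a ConstantExpression.
        Expression = expression ?? Expression.Constant(this);
    }

    public Type ElementType { get { return typeof(T); } }
    public Expression Expression { get; }
    public IQueryProvider Provider { get; }

    public IEnumerator<T> GetEnumerator()
    {
        // Enumeration is the point where the provider is asked to execute.
        IEnumerable<T> result = Provider.Execute<IEnumerable<T>>(Expression);
        return (result ?? Enumerable.Empty<T>()).GetEnumerator();
    }

    IEnumerator IEnumerable.GetEnumerator() { return GetEnumerator(); }
}

public class RecordingProvider : IQueryProvider
{
    public List<string> Calls { get; } = new List<string>();

    // Used by Queryable.Where, Select, OrderBy...: build a new query, run nothing.
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        Calls.Add("CreateQuery");
        return new RecordingQuery<TElement>(this, expression);
    }

    public IQueryable CreateQuery(Expression expression)
    {
        throw new NotSupportedException();
    }

    // Used by enumeration and by Queryable.Count, First, Any...: run the query.
    public TResult Execute<TResult>(Expression expression)
    {
        Calls.Add("Execute");
        return default(TResult); // placeholder: nothing is actually executed
    }

    public object Execute(Expression expression)
    {
        throw new NotSupportedException();
    }
}
```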

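    CreateQuery builds a new query from an expression (this is where the conditions get recorded), while Execute actually runs the query that an expression describes.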

    Normally we write something like this:

        IQueryable<Staff> query = /* some data source */;

        query = query.Where(x => x.Name == "123");

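    At this point Where does not run anything. It builds a new expression that combines the previous query's Expression property with the lambda passed to Where (x => x.Name == "123"), namely a call to Queryable.Where whose first argument is the old expression and whose second argument is the quoted lambda, and then passes it to the CreateQuery method of the Provider property. We can change our code as follows to check that this is really what happens.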
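    You can also check the shape Where produces without writing a provider, using LINQ to Objects. The WhereShape helper below is an illustrative name; it returns filtered.Expression, which should be a call to Queryable.Where whose first argument is the source query's expression (a ConstantExpression here) and whose second argument is the lambda wrapped in a Quote node.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

public static class WhereShape
{
    // Returns the call expression that Queryable.Where builds around the source query.
    public static MethodCallExpression Inspect()
    {
        IQueryable<string> source = new[] { "a", "b" }.AsQueryable();
        IQueryable<string> filtered = source.Where(name => name == "a");
        // filtered.Expression is Queryable.Where(<source.Expression>, Quote(name => name == "a"))
        return (MethodCallExpression)filtered.Expression;
    }
}
```

    The Execute method in the complete listing relies on exactly this shape: it reads Arguments[0] as the previous query and unwraps Arguments[1] (a UnaryExpression) to get the lambda.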

       public class DBSql : IDataBase
        {
            public IQueryable<T> Source<T>()
            {
                return new SqlQuery<T>();
            }
    
            public List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere)
            {
                throw new NotImplementedException();
            }
    
            public int Remove<T>(Expression<Func<T, bool>> lambdawhere)
            {
                throw new NotImplementedException();
            }
        }
    
        public class SqlQuery<T> : IQueryable<T>
        {
    
            private Expression _expression;
            private IQueryProvider _provider;
    
            public SqlQuery()
            {
                _provider = new SqlProvider<T>();
                _expression = Expression.Constant(this);
            }
    
            public SqlQuery(Expression expression, IQueryProvider provider)
            {
                _expression = expression;
                _provider = provider;
            }
    
            public IEnumerator<T> GetEnumerator()
            {
                throw new NotImplementedException();
            }
    
            System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
            {
                throw new NotImplementedException();
            }
    
            public Type ElementType
            {
                get { return typeof(T); }
            }
    
            public Expression Expression
            {
                get { return _expression; }
            }
    
            public IQueryProvider Provider
            {
                get { return _provider; }
            }
        }
    
        public class SqlProvider<T> : IQueryProvider
        {
            public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
            {
                IQueryable<TElement> query = new SqlQuery<TElement>(expression, this);
                return query;
            }
    
            public IQueryable CreateQuery(Expression expression)
            {
                throw new NotImplementedException();
            }
    
            public TResult Execute<TResult>(Expression expression)
            {
                throw new NotImplementedException();
            }
    
            public object Execute(Expression expression)
            {
                throw new NotImplementedException();
            }
        }
    
         public class Staff
            {
                public int ID { get; set; }
                public string Code { get; set; }
                public string Name { get; set; }
                public DateTime? Birthday { get; set; }
                public bool Deletion { get; set; }
            }
    
            static void Main(string[] args)
            {
                IDataBase db = new DBSql();
                IQueryable<Staff> query = db.Source<Staff>();
                string name = "张三";
                Expression express = null;
                query = query.Where(x => x.Name == "赵建华");
                express = query.Expression;
                query = query.Where(x => x.Name == name);
                express = query.Expression;
            }

    Set a breakpoint in this method of SqlProvider<T>:

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)

    Every call to query.Where stops there, and each time the incoming Expression combines the previous query's expression with the new condition.

    By now the idea that "IQueryable only stores the conditions" should be clear.
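    Because each Where nests the previous expression as its first argument, a chain like query.Where(a).Where(b) can be unwound by following Arguments[0] until the root constant is reached. A sketch of that walk, with WhereChain.CollectPredicates as a hypothetical helper that only understands Where calls:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

public static class WhereChain
{
    // Walks Where(Where(root, a), b) from the outside in and returns [a, b] in source order.
    public static List<LambdaExpression> CollectPredicates(Expression expression)
    {
        var predicates = new List<LambdaExpression>();
        var call = expression as MethodCallExpression;
        while (call != null && call.Method.Name == "Where")
        {
            // Arguments[1] is Quote(lambda); Arguments[0] is the query it filters.
            var quote = (UnaryExpression)call.Arguments[1];
            predicates.Insert(0, (LambdaExpression)quote.Operand);
            call = call.Arguments[0] as MethodCallExpression;
        }
        return predicates;
    }
}
```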

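    So what about deferred loading? When does the loading actually happen? When we run foreach, or call ToList/ToArray. What does that remind you of? GetEnumerator(). Inside GetEnumerator() we call the provider's Execute(Expression), which parses the Expression, generates the SQL statement, creates entity instances via reflection and returns them one by one. Done! The full code follows. I've also reworked the Expression-parsing class; this version is even more hardcore.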
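    Before the full listing, here is a minimal sketch of just the reflection step, assuming a Dictionary<string, object> stands in for one SqlDataReader row (EntityMapper is a hypothetical name, and Staff mirrors the entity used in this post). Each writable public property is matched to a column of the same name, and DBNull becomes a CLR null, the same rule the FindAs method applies.

```csharp
using System;
using System.Collections.Generic;

public class Staff
{
    public int ID { get; set; }
    public string Code { get; set; }
    public string Name { get; set; }
    public DateTime? Birthday { get; set; }
    public bool Deletion { get; set; }
}

public static class EntityMapper
{
    // Builds one entity from a row; a Dictionary stands in for SqlDataReader here (assumption).
    public static T Map<T>(IDictionary<string, object> row) where T : class, new()
    {
        T entity = new T();
        foreach (var property in typeof(T).GetProperties())
        {
            object value;
            if (!property.CanWrite || !row.TryGetValue(property.Name, out value))
                continue; // no matching column for this property
            // A database NULL arrives as DBNull.Value and must become a CLR null.
            property.SetValue(entity, value is DBNull ? null : value, null);
        }
        return entity;
    }
}
```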

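      using System;
      using System.Collections.Generic;
      using System.Data.SqlClient;
      using System.Linq;
      using System.Linq.Expressions;

      public class ResolveExpression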
        {
            public Dictionary<string, object> Argument;
            public string SqlWhere;
            public SqlParameter[] Paras;
            private int index = 0;
            /// <summary>
            /// Parses a lambda and generates the SQL WHERE condition (SqlWhere) and its parameters (Paras)
            /// </summary>
            /// <param name="expression">The predicate lambda to translate</param>
            public void ResolveToSql(Expression expression)
            {
                this.index = 0;
                this.Argument = new Dictionary<string, object>();
                this.SqlWhere = Resolve(expression);
                this.Paras = Argument.Select(x => new SqlParameter(x.Key, x.Value)).ToArray();
            }
    
            private object GetValue(Expression expression)
            {
                if (expression is ConstantExpression)
                    return (expression as ConstantExpression).Value;
                if (expression is UnaryExpression)
                {
                    UnaryExpression unary = expression as UnaryExpression;
                    LambdaExpression lambda = Expression.Lambda(unary.Operand);
                    Delegate fn = lambda.Compile();
                    return fn.DynamicInvoke(null);
                }
                if (expression is MemberExpression)
                {
                    MemberExpression member = expression as MemberExpression;
                    string name = member.Member.Name;
                    var constant = member.Expression as ConstantExpression;
                    if (constant == null)
                        throw new Exception("Error while reading the value of " + member);
                    return constant.Value.GetType().GetFields().First(x => x.Name == name).GetValue(constant.Value);
                }
                throw new Exception("Cannot get the value of " + expression);
            }
    
            private string Resolve(Expression expression)
            {
                if (expression is LambdaExpression)
                {
                    LambdaExpression lambda = expression as LambdaExpression;
                    expression = lambda.Body;
                    return Resolve(expression);
                }
                if (expression is BinaryExpression)// binary operators
                {
                    BinaryExpression binary = expression as BinaryExpression;
                    if (binary.Left is MemberExpression)
                    {
                        object value = GetValue(binary.Right);
                        return ResolveFunc(binary.Left, value, binary.NodeType);
                    }
                    if (binary.Left is MethodCallExpression && (binary.Right is UnaryExpression || binary.Right is MemberExpression))
                    {
                        object value = GetValue(binary.Right);
                        return ResolveLinqToObject(binary.Left, value, binary.NodeType);
                    }
                }
                if (expression is UnaryExpression)// unary operators, e.g. !x.Deletion
                {
                    UnaryExpression unary = expression as UnaryExpression;
                    if (unary.Operand is MethodCallExpression)
                    {
                        return ResolveLinqToObject(unary.Operand, false);
                    }
                    if (unary.Operand is MemberExpression)
                    {
                        return ResolveFunc(unary.Operand, false, ExpressionType.Equal);
                    }
                }
                if (expression is MethodCallExpression)// (extension) method calls such as Contains
                {
                    return ResolveLinqToObject(expression, true);
                }
                if (expression is MemberExpression)// bare bool properties, e.g. x.Deletion
                {
                    return ResolveFunc(expression, true, ExpressionType.Equal);
                }
                var body = expression as BinaryExpression;
                if (body == null)
                    throw new Exception("Cannot parse " + expression);
                var Operator = GetOperator(body.NodeType);
                var Left = Resolve(body.Left);
                var Right = Resolve(body.Right);
                string Result = string.Format("({0} {1} {2})", Left, Operator, Right);
                return Result;
            }
    
            /// <summary>
            /// Maps an ExpressionType to the corresponding SQL operator
            /// </summary>
            /// <param name="expressiontype">The node type of a binary expression</param>
            /// <returns>The SQL operator text</returns>
            private string GetOperator(ExpressionType expressiontype)
            {
                switch (expressiontype)
                {
                    case ExpressionType.And:
                        return "and";
                    case ExpressionType.AndAlso:
                        return "and";
                    case ExpressionType.Or:
                        return "or";
                    case ExpressionType.OrElse:
                        return "or";
                    case ExpressionType.Equal:
                        return "=";
                    case ExpressionType.NotEqual:
                        return "<>";
                    case ExpressionType.LessThan:
                        return "<";
                    case ExpressionType.LessThanOrEqual:
                        return "<=";
                    case ExpressionType.GreaterThan:
                        return ">";
                    case ExpressionType.GreaterThanOrEqual:
                        return ">=";
                    default:
                        throw new Exception(string.Format("Operator {0} is not supported!", expressiontype));
                }
            }
    
    
            private string ResolveFunc(Expression left, object value, ExpressionType expressiontype)
            {
                string Name = (left as MemberExpression).Member.Name;
                string Operator = GetOperator(expressiontype);
                string Value = value.ToString();
                string CompName = SetArgument(Name, Value);
                string Result = string.Format("({0} {1} {2})", Name, Operator, CompName);
                return Result;
            }
    
            private string ResolveLinqToObject(Expression expression, object value, ExpressionType? expressiontype = null)
            {
                var MethodCall = expression as MethodCallExpression;
                var MethodName = MethodCall.Method.Name;
                switch (MethodName)// this switch could also be replaced by dispatching via reflection
                {
                    case "Contains":
                        if (MethodCall.Object != null)
                            return Like(MethodCall);
                        return In(MethodCall, value);
                    case "Count":
                        return Len(MethodCall, value, expressiontype.Value);
                    case "LongCount":
                        return Len(MethodCall, value, expressiontype.Value);
                    default:
                        throw new Exception(string.Format("Method {0} is not supported in queries!", MethodName));
                }
            }
    
            private string SetArgument(string name, string value)
            {
                name = "@" + name;
                string temp = name;
                while (Argument.ContainsKey(temp))
                {
                    temp = name + index;
                    index = index + 1;
                }
                Argument[temp] = value;
                return temp;
            }
    
            private string In(MethodCallExpression expression, object isTrue)
            {
                var Argument1 = expression.Arguments[0];
                var Argument2 = expression.Arguments[1] as MemberExpression;
                var fieldValue = GetValue(Argument1);
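                object[] array = ((System.Collections.IEnumerable)fieldValue).Cast<object>().ToArray(); // "as object[]" would be null for int[] and other value-type arrays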
                List<string> SetInPara = new List<string>();
                for (int i = 0; i < array.Length; i++)
                {
                    string Name_para = "InParameter" + i;
                    string Value = array[i].ToString();
                    string Key = SetArgument(Name_para, Value);
                    SetInPara.Add(Key);
                }
                string Name = Argument2.Member.Name;
                string Operator = Convert.ToBoolean(isTrue) ? "in" : " not in";
                string CompName = string.Join(",", SetInPara);
                string Result = string.Format("{0} {1} ({2})", Name, Operator, CompName);
                return Result;
            }
    
            private string Like(MethodCallExpression expression)
            {
                Expression argument = expression.Arguments[0];
                object Temp_Vale = GetValue(argument);
                string Value = string.Format("%{0}%", Temp_Vale);
                string Name = (expression.Object as MemberExpression).Member.Name;
                string CompName = SetArgument(Name, Value);
                string Result = string.Format("{0} like {1}", Name, CompName);
                return Result;
            }
    
            private string Len(MethodCallExpression expression, object value, ExpressionType expressiontype)
            {
                object Name = (expression.Arguments[0] as MemberExpression).Member.Name;
                string Operator = GetOperator(expressiontype);
                string CompName = SetArgument(Name.ToString(), value.ToString());
                string Result = string.Format("len({0}){1}{2}", Name, Operator, CompName);
                return Result;
            }
    
        }
      public interface IDataBase
        {
            List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere);
            int Remove<T>(Expression<Func<T, bool>> lambdawhere);
            IQueryable<T> Source<T>();
        }
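    using System;
    using System.Collections.Generic;
    using System.Data.SqlClient;
    using System.Linq;
    using System.Linq.Expressions;

    namespace Data.DataBase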
    {
        public class DBSql : IDataBase
        {
            private readonly static string ConnectionString = @"Data Source=.;Initial Catalog=btmmcms-Standard;Persist Security Info=True;User ID=sa;Password=sa;";
    
            public IQueryable<T> Source<T>()
            {
                return new SqlQuery<T>();
            }
    
            public List<T> FindAs<T>(Expression<Func<T, bool>> lambdawhere)
            {
                using (SqlConnection Conn = new SqlConnection(ConnectionString))
                using (SqlCommand Command = new SqlCommand())
                {
                    Command.Connection = Conn;
                    string sql = string.Format("select * from {0}", typeof(T).Name);
                    if (lambdawhere != null)
                    {
                        ResolveExpression resolve = new ResolveExpression();
                        resolve.ResolveToSql(lambdawhere);
                        sql = string.Format("{0} where {1}", sql, resolve.SqlWhere);
                        Command.Parameters.AddRange(resolve.Paras);
                    }
                    // Print the SQL here for testing purposes
                    Console.WriteLine(sql);
                    Command.CommandText = sql;
                    Conn.Open();
                    List<T> ListEntity = new List<T>();
                    // The using blocks close the reader and the connection even when an exception is thrown
                    using (SqlDataReader dataReader = Command.ExecuteReader())
                    {
                        while (dataReader.Read())
                        {
                            // Create the entity via its parameterless constructor, then fill properties by column name
                            T Entity = Activator.CreateInstance<T>();
                            foreach (var item in typeof(T).GetProperties())
                            {
                                var value = dataReader[item.Name];
                                // A database NULL arrives as DBNull.Value; store it as a CLR null
                                item.SetValue(Entity, value is DBNull ? null : value, null);
                            }
                            ListEntity.Add(Entity);
                        }
                    }
                    // An empty result is returned as null (SqlQuery<T>.GetEnumerator checks for this)
                    return ListEntity.Count == 0 ? null : ListEntity;
                }
            }
    
            public int Remove<T>(Expression<Func<T, bool>> lambdawhere)
            {
                throw new NotImplementedException();
            }
        }
    
        public class SqlQuery<T> : IQueryable<T>
        {
    
            private Expression _expression;
            private IQueryProvider _provider;
    
            public SqlQuery()
            {
                _provider = new SqlProvider<T>();
                _expression = Expression.Constant(this);
            }
    
            public SqlQuery(Expression expression, IQueryProvider provider)
            {
                _expression = expression;
                _provider = provider;
            }
    
            public IEnumerator<T> GetEnumerator()
            {
                var result = _provider.Execute<List<T>>(_expression);
                if (result == null)
                    yield break;
                foreach (var item in result)
                {
                    yield return item;
                }
            }
    
            System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
            {
                return GetEnumerator();
            }
    
            public Type ElementType
            {
                get { return typeof(T); }
            }
    
            public Expression Expression
            {
                get { return _expression; }
            }
    
            public IQueryProvider Provider
            {
                get { return _provider; }
            }
        }
    
        public class SqlProvider<T> : IQueryProvider
        {
    
            public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
            {
                IQueryable<TElement> query = new SqlQuery<TElement>(expression, this);
                return query;
            }
    
            public IQueryable CreateQuery(Expression expression)
            {
                throw new NotImplementedException();
            }
    
            public TResult Execute<TResult>(Expression expression)
            {
                // The expression looks like Where(Where(Constant(query), Quote(a)), Quote(b)):
                // walk Arguments[0] from the outside in and AND all the Where predicates together.
                // Only Where calls are supported here.
                MethodCallExpression methodCall = expression as MethodCallExpression;
                Expression<Func<T, bool>> result = null;
                while (methodCall != null)
                {
                    Expression method = methodCall.Arguments[0];
                    Expression lambda = methodCall.Arguments[1];
                    LambdaExpression right = (lambda as UnaryExpression).Operand as LambdaExpression;
                    if (result == null)
                    {
                        result = Expression.Lambda<Func<T, bool>>(right.Body, right.Parameters);
                    }
                    else
                    {
                        Expression left = result.Body;
                        Expression temp = Expression.AndAlso(right.Body, left);
                        result = Expression.Lambda<Func<T, bool>>(temp, result.Parameters);
                    }
                    methodCall = method as MethodCallExpression;
                }
                // With no Where at all, result stays null and FindAs selects the whole table
                var source = new DBSql().FindAs<T>(result);
                // TResult is List<T> when called from SqlQuery<T>.GetEnumerator
                return (TResult)(object)source;
            }
    
            public object Execute(Expression expression)
            {
                throw new NotImplementedException();
            }
        }
    }
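    One caveat about Execute above: every Where lambda declares its own ParameterExpression (both named x, but different objects), and the merged body is attached to the parameter list of only one of them. That is harmless for SQL generation, because ResolveExpression only reads member names, but compiling the merged lambda would fail. If you ever need a compilable combined predicate, a common technique is to rebind the parameters with an ExpressionVisitor. A sketch, with PredicateCombiner and ParameterReplacer as hypothetical names:

```csharp
using System;
using System.Linq.Expressions;

public static class PredicateCombiner
{
    // Rewrites every occurrence of one ParameterExpression into another expression.
    private sealed class ParameterReplacer : ExpressionVisitor
    {
        private readonly ParameterExpression _from;
        private readonly Expression _to;

        public ParameterReplacer(ParameterExpression from, Expression to)
        {
            _from = from;
            _to = to;
        }

        protected override Expression VisitParameter(ParameterExpression node)
        {
            return node == _from ? _to : base.VisitParameter(node);
        }
    }

    // Combines two predicates with && (AndAlso), rebinding the second onto the first's parameter.
    public static Expression<Func<T, bool>> AndAlso<T>(Expression<Func<T, bool>> left, Expression<Func<T, bool>> right)
    {
        ParameterExpression parameter = left.Parameters[0];
        Expression rightBody = new ParameterReplacer(right.Parameters[0], parameter).Visit(right.Body);
        return Expression.Lambda<Func<T, bool>>(Expression.AndAlso(left.Body, rightBody), parameter);
    }
}
```

    The combined body has node type AndAlso, which GetOperator already translates to "and", so the result can still be fed to ResolveExpression.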

    Done. Now change the connection string to point at your own database, add an entity class that matches a table (as below), and it's ready to use.

       class Program
        {
            public class Staff
            {
                public int ID { get; set; }
                public string Code { get; set; }
                public string Name { get; set; }
                public DateTime? Birthday { get; set; }
                public bool Deletion { get; set; }
            }
    
            static void Main(string[] args)
            {
                IDataBase db = new DBSql();
                IQueryable<Staff> query = db.Source<Staff>();
                query = query.Where(x => x.Name == "张三");
                foreach (var item in query)
                {
                    Console.WriteLine(item.Name);
                }
            }
        }
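    Note when the SQL actually runs. SqlQuery<T>.GetEnumerator is an iterator method (it uses yield), so calling GetEnumerator() only creates the enumerator; Execute, and with it the database query, runs on the first MoveNext(), and runs again for every new foreach over the same query. A sketch of that behaviour with a hypothetical LazySource<T>, where a fetch delegate plays the role of Provider.Execute:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for SqlQuery<T>: _fetch plays the role of Provider.Execute.
public class LazySource<T>
{
    private readonly Func<List<T>> _fetch;

    public int FetchCount { get; private set; }

    public LazySource(Func<List<T>> fetch)
    {
        _fetch = fetch;
    }

    public IEnumerator<T> GetEnumerator()
    {
        // Because this is an iterator method, none of this body runs until the first MoveNext().
        FetchCount++;
        List<T> result = _fetch();
        if (result == null)
            yield break;
        foreach (T item in result)
            yield return item;
    }
}
```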

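    Simple, right? When the foreach starts, FindAs prints the generated statement, select * from Staff where (Name = @Name), and 张三 is sent as the @Name parameter.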

    There's a lot to take in here, but if you work through it step by step, I believe it will help you a great deal!

  • Original article: https://www.cnblogs.com/zhangxiaolei521/p/5552175.html