zoukankan      html  css  js  c++  java
  • 自己写一个类似Dapper的ORM框架(c#)

    Dapper 本质上就是一组针对 Connection 的扩展方法，这里也用同样的方式来实现。本文代码为练习反射而写，属原创。

    使用的技术：泛型、反射、表达式树等。

    客户端调用:

            static void Main(string[] args)
            {
                // NOTE(review): demo connection string with hard-coded "sa" credentials; load it from
                // configuration for anything beyond a local experiment.
                // SqlConnection is IDisposable, so the using block closes and releases it when the demo is done.
                using (var connection = new SqlConnection("Data Source=.;User Id=sa;Password=123456;Database=fanDB;"))
                {
                    // Insert: a single entity, then a list (the ORM issues one INSERT per item).
                    connection.Insert(new Person() { Name = "fan11", Age = 1 });
                    connection.Insert(new List<Person> {
                        new Person() { Name = "fan432", Age = 24 },
                        new Person() { Name = "fan", Age = 4 }
                    });

                    // Delete: by key value, or by an entity whose ID property is set.
                    connection.Delete<Person>(5);
                    connection.Delete(new Person() { ID = 6 });

                    // Update: writes every non-ID column of the row whose ID is 17.
                    connection.Update(new Person() { ID = 17, Name = "fanfan", Age = 18 });

                    // Select: the predicate is translated into a parameterized WHERE clause.
                    // Note that && binds tighter than ||, so "Age > 3" only constrains the EndsWith term.
                    var list = connection.Select<Person>(p => p.Name == "fan" || p.Name.Contains("fan1") || p.Name.StartsWith("fan") || p.Name.EndsWith("fan") && p.Age > 3);
                }

                Console.ReadKey();
            }

    ORM:

        /// <summary>
        /// Minimal Dapper-style CRUD helpers implemented as extension methods on <see cref="SqlConnection"/>.
        /// Mapping conventions used by every method:
        /// <list type="bullet">
        /// <item>table name = entity type name (<c>typeof(T).Name</c>);</item>
        /// <item>every public non-indexer property is a column of the same name;</item>
        /// <item>the property named <c>ID</c> is the key: it appears in the DELETE/UPDATE WHERE clause and is
        /// never inserted or updated (presumably an identity column; confirm against the schema).</item>
        /// </list>
        /// Each method opens the connection if it is not already open and leaves it open afterwards;
        /// the caller owns the connection and is responsible for disposing it.
        /// </summary>
        public static class ORM
        {
            private const string ID_NAME = "ID";

            // SQL templates are filled with string.Format instead of chained Replace calls. With Replace,
            // text substituted by an earlier step (e.g. a parameter named "@WHERE_FLAG" in the SET list)
            // could be matched again by a later placeholder and corrupt the statement.
            private const string INSERT_SQL = "INSERT INTO {0}({1}) VALUES({2})";
            private const string INSERT_DEFAULT_VALUES_SQL = "INSERT INTO {0} DEFAULT VALUES";
            private const string SELECT_SQL = "SELECT * FROM {0} WHERE {1}";
            private const string DELETE_SQL = "DELETE FROM {0} WHERE {1}";
            private const string UPDATE_SQL = "UPDATE {0} SET {1} WHERE {2}";

            // Per-type property list, computed once per entity type.
            private static readonly ConcurrentDictionary<Type, PropertyInfo[]> PROPERTIES_CACHE = new ConcurrentDictionary<Type, PropertyInfo[]>();

            /// <summary>
            /// Inserts <paramref name="entity"/> into the table named after <typeparamref name="T"/>.
            /// Every column except <c>ID</c> is written; <c>null</c> property values are sent as SQL NULL.
            /// </summary>
            /// <returns>The number of rows affected.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
            public static int Insert<T>(this SqlConnection connection, T entity)
            {
                var columns = GetColumnInfos(entity).Where(c => c.Name != ID_NAME).ToList();
                var tableName = QuoteName(typeof(T).Name);

                // An entity that only has the key column would otherwise produce the invalid "() VALUES()".
                string sql = columns.Count == 0
                    ? string.Format(INSERT_DEFAULT_VALUES_SQL, tableName)
                    : string.Format(INSERT_SQL,
                        tableName,
                        string.Join(",", columns.Select(c => QuoteName(c.Name))),
                        string.Join(",", columns.Select(c => "@" + c.Name)));

                return ExecuteNonQuery(connection, sql, columns.Select(c => ToParameter(c)).ToArray());
            }

            /// <summary>
            /// Inserts each entity of <paramref name="list"/> with its own INSERT statement.
            /// No transaction is used, so a failure part-way leaves the earlier rows inserted.
            /// </summary>
            /// <returns>The total number of rows affected.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
            public static int Insert<T>(this SqlConnection connection, List<T> list)
            {
                if (list == null)
                {
                    throw new ArgumentNullException(nameof(list));
                }

                int result = 0;
                foreach (var entity in list)
                {
                    result += connection.Insert(entity);
                }
                return result;
            }

            /// <summary>
            /// Returns every row of the table named after <typeparamref name="T"/> that matches
            /// <paramref name="whereExp"/>, mapped column-by-name onto new <typeparamref name="T"/> instances.
            /// </summary>
            /// <exception cref="ArgumentNullException"><paramref name="whereExp"/> is null.</exception>
            public static List<T> Select<T>(this SqlConnection connection, Expression<Func<T, bool>> whereExp) where T : new()
            {
                if (whereExp == null)
                {
                    throw new ArgumentNullException(nameof(whereExp));
                }

                // WhereBuilder stores the lambda's parameters in an instance field while translating, so a
                // shared static instance races when Select runs on several threads; use a fresh one per call.
                var wherePart = new WhereBuilder('[', ']').ToSql<T>(whereExp);
                var paras = wherePart.Parameters
                    .Select(p => new SqlParameter("@" + p.Key, p.Value ?? DBNull.Value))
                    .ToArray();
                string sql = string.Format(SELECT_SQL, QuoteName(typeof(T).Name), wherePart.Sql);

                var list = new List<T>();
                OpenConnection(connection);
                using (var command = CreateCommand(connection, sql, paras))
                using (var reader = command.ExecuteReader())
                {
                    while (reader.Read())
                    {
                        list.Add(ReaderToEntity<T>(reader));
                    }
                }
                return list;
            }

            /// <summary>Deletes the row of <typeparamref name="T"/>'s table whose <c>ID</c> equals <paramref name="ID"/>.</summary>
            /// <returns>The number of rows affected.</returns>
            public static int Delete<T>(this SqlConnection connection, int ID)
            {
                string sql = string.Format(DELETE_SQL,
                    QuoteName(typeof(T).Name),
                    $"{QuoteName(ID_NAME)}=@{ID_NAME}");
                var paras = new[] { new SqlParameter("@" + ID_NAME, ID) };
                return ExecuteNonQuery(connection, sql, paras);
            }

            /// <summary>Deletes the row identified by the <c>ID</c> property of <paramref name="entity"/>.</summary>
            /// <returns>The number of rows affected.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
            /// <exception cref="InvalidOperationException">The entity type has no <c>ID</c> property.</exception>
            public static int Delete<T>(this SqlConnection connection, T entity)
            {
                if (entity == null)
                {
                    throw new ArgumentNullException(nameof(entity));
                }

                var idProperty = entity.GetType().GetProperty(ID_NAME);
                if (idProperty == null)
                {
                    throw new InvalidOperationException($"Type {entity.GetType().Name} has no {ID_NAME} property.");
                }

                // The key is expected to be an int (unboxing cast, as before).
                int id = (int)idProperty.GetValue(entity);
                return connection.Delete<T>(id);
            }

            /// <summary>
            /// Updates every non-<c>ID</c> column of the row whose <c>ID</c> matches <paramref name="entity"/>'s
            /// <c>ID</c> property. <c>null</c> property values are written as SQL NULL.
            /// </summary>
            /// <returns>The number of rows affected.</returns>
            /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
            public static int Update<T>(this SqlConnection connection, T entity)
            {
                var columns = GetColumnInfos(entity);
                var setClause = string.Join(",", columns
                    .Where(c => c.Name != ID_NAME)
                    .Select(c => $"{QuoteName(c.Name)}=@{c.Name}"));
                string sql = string.Format(UPDATE_SQL,
                    QuoteName(typeof(T).Name),
                    setClause,
                    $"{QuoteName(ID_NAME)}=@{ID_NAME}");

                // All columns (including ID) are bound: ID feeds the WHERE clause.
                return ExecuteNonQuery(connection, sql, columns.Select(c => ToParameter(c)).ToArray());
            }

            /// <summary>Materializes the current reader row into a new <typeparamref name="T"/>.</summary>
            private static T ReaderToEntity<T>(SqlDataReader reader) where T : new()
            {
                // Box once so that SetValue mutates the same instance even when T is a struct.
                object entity = new T();
                foreach (var propertyInfo in GetPropertys<T>())
                {
                    if (!propertyInfo.CanWrite)
                    {
                        continue; // get-only properties cannot be populated
                    }
                    var value = reader[propertyInfo.Name];
                    // The reader reports SQL NULL as DBNull.Value, which cannot be assigned to a CLR property;
                    // null sets reference types to null and value types to their default.
                    propertyInfo.SetValue(entity, value is DBNull ? null : value);
                }
                return (T)entity;
            }

            /// <summary>Cached public properties of <typeparamref name="T"/>, excluding indexers.</summary>
            private static PropertyInfo[] GetPropertys<T>()
            {
                // Indexers need index arguments for GetValue/SetValue, so they can never map to a column.
                return PROPERTIES_CACHE.GetOrAdd(typeof(T), t => t.GetProperties()
                    .Where(p => p.GetIndexParameters().Length == 0)
                    .ToArray());
            }

            /// <summary>Reads every readable mapped property of <paramref name="entity"/> into a <see cref="ColumnInfo"/>.</summary>
            /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
            private static List<ColumnInfo> GetColumnInfos<T>(T entity)
            {
                if (entity == null)
                {
                    throw new ArgumentNullException(nameof(entity));
                }

                return GetPropertys<T>()
                    .Where(prop => prop.CanRead)
                    .Select(prop => new ColumnInfo(prop.Name, prop.PropertyType.FullName, prop.GetValue(entity)))
                    .ToList();
            }

            /// <summary>Binds a column value as "@Name"; null becomes DBNull so SqlClient sends SQL NULL.</summary>
            private static SqlParameter ToParameter(ColumnInfo column)
            {
                return new SqlParameter("@" + column.Name, column.Value ?? DBNull.Value);
            }

            /// <summary>Wraps an identifier in [ ] (escaping any ']') so reserved words can be used as names.</summary>
            private static string QuoteName(string name)
            {
                return "[" + name.Replace("]", "]]") + "]";
            }

            /// <summary>Builds a text command on <paramref name="connection"/> with the given parameters.</summary>
            private static SqlCommand CreateCommand(SqlConnection connection, string sql, SqlParameter[] parameters)
            {
                var command = connection.CreateCommand();
                command.CommandType = CommandType.Text;
                command.CommandText = sql;
                command.Parameters.AddRange(parameters);
                return command;
            }

            /// <summary>Opens the connection if needed and runs <paramref name="sql"/>, returning the rows affected.</summary>
            private static int ExecuteNonQuery(SqlConnection connection, string sql, SqlParameter[] parameters)
            {
                OpenConnection(connection);
                using (var command = CreateCommand(connection, sql, parameters))
                {
                    return command.ExecuteNonQuery();
                }
            }

            /// <summary>Maps a CLR type full name to a <see cref="DbType"/> (currently unused).</summary>
            private static DbType GetDbType(string typeName)
            {
                DbType type = DbType.String;
                switch (typeName)
                {
                    case "System.String":
                        type = DbType.String; break;
                    case "System.Int32":
                        type = DbType.Int32; break;
                    case "System.Decimal":
                        type = DbType.Decimal; break;
                        // Extend with other types (Guid, DateTime, ...) as needed.
                }

                return type;
            }

            /// <summary>Opens <paramref name="connection"/> unless it is already open.</summary>
            private static void OpenConnection(IDbConnection connection)
            {
                if (connection.State != ConnectionState.Open)
                {
                    connection.Open();
                }
            }
        }
        /// <summary>
        /// Snapshot of one entity property used for SQL generation: the column name,
        /// the full name of the property's CLR type, and the value read from the entity.
        /// </summary>
        public class ColumnInfo
        {
            /// <summary>Column name (identical to the property name).</summary>
            public string Name { get; set; }

            /// <summary>Full name of the property's CLR type, e.g. "System.Int32".</summary>
            public string TypeName { get; set; }

            /// <summary>Value read from the entity; may be null.</summary>
            public object Value { get; set; }

            /// <summary>Creates a column snapshot from its three parts.</summary>
            public ColumnInfo(string name, string typeName, object value)
            {
                Name = name;
                TypeName = typeName;
                Value = value;
            }
        }

    WhereBuilder：将表达式树转换为 SQL 的 WHERE 子句（这部分代码取自第三方）。

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Reflection;
    using System.Runtime.CompilerServices;
    using System.Text;
    
    /// <summary>
    /// Generates the SQL of a WHERE condition from a LINQ predicate expression tree.
    /// Properties of the lambda parameter become quoted column names; constants and captured
    /// values become numbered parameters (<c>@1</c>, <c>@2</c>, ...) collected in the returned <see cref="WherePart"/>.
    /// </summary>
    /// <remarks>
    /// ToSql stores the lambda's parameters in an instance field while translating,
    /// so a single instance must not translate several expressions concurrently.
    /// </remarks>
    public class WhereBuilder
    {
        // string.Contains/StartsWith/EndsWith(string) are translated to LIKE; resolve them once instead of per call.
        private static readonly MethodInfo StringContainsMethod = typeof(string).GetMethod("Contains", new[] { typeof(string) });
        private static readonly MethodInfo StringStartsWithMethod = typeof(string).GetMethod("StartsWith", new[] { typeof(string) });
        private static readonly MethodInfo StringEndsWithMethod = typeof(string).GetMethod("EndsWith", new[] { typeof(string) });

        private readonly char _columnBeginChar = '[';
        private readonly char _columnEndChar = ']';
        // Parameters of the lambda being translated; a member accessed directly on one of them is a column.
        private System.Collections.ObjectModel.ReadOnlyCollection<ParameterExpression> expressParameterNameCollection;

        /// <summary>Quotes column names with the same character on both sides.</summary>
        /// <param name="columnChar">Quote character (default <c>`</c>).</param>
        public WhereBuilder(char columnChar = '`')
        {
            this._columnBeginChar = this._columnEndChar = columnChar;
        }

        /// <summary>Quotes column names with separate opening and closing characters.</summary>
        /// <param name="columnBeginChar">Character written before a column name (default <c>[</c>).</param>
        /// <param name="columnEndChar">Character written after a column name (default <c>]</c>).</param>
        public WhereBuilder(char columnBeginChar = '[', char columnEndChar = ']')
        {
            this._columnBeginChar = columnBeginChar;
            this._columnEndChar = columnEndChar;
        }

        /// <summary>
        /// Translates a predicate into SQL, numbering parameters from 1.
        /// </summary>
        /// <typeparam name="T">Entity type of the lambda parameter.</typeparam>
        /// <param name="expression">Predicate to translate.</param>
        /// <returns>The SQL fragment and its parameter values.</returns>
        public WherePart ToSql<T>(Expression<Func<T, bool>> expression)
        {
            var i = 1;
            return this.ToSql<T>(ref i, expression);
        }

        /// <summary>
        /// Translates a predicate into SQL, numbering parameters from <paramref name="i"/>.
        /// </summary>
        /// <typeparam name="T">Entity type of the lambda parameter.</typeparam>
        /// <param name="i">Seed for parameter numbering; advanced past the last parameter used.</param>
        /// <param name="expression">Predicate to translate.</param>
        /// <returns>The SQL fragment and its parameter values.</returns>
        public WherePart ToSql<T>(ref int i, Expression<Func<T, bool>> expression)
        {
            if (expression.Parameters.Count > 0)
            {
                this.expressParameterNameCollection = expression.Parameters;
            }
            return Recurse(ref i, expression.Body, isUnary: true);
        }

        /// <summary>
        /// Translates one expression node (recursively) into a <see cref="WherePart"/>.
        /// </summary>
        /// <param name="i">Next parameter number; incremented for every parameter emitted.</param>
        /// <param name="expression">Node to translate.</param>
        /// <param name="isUnary">True when the node stands alone as a condition, so a bare bool must become "= ...".</param>
        /// <param name="prefix">Text prepended to string values (used for LIKE wildcards).</param>
        /// <param name="postfix">Text appended to string values (used for LIKE wildcards).</param>
        /// <exception cref="NotSupportedException">The node type, operator or method call cannot be translated.</exception>
        private WherePart Recurse(ref int i, Expression expression, bool isUnary = false, string prefix = null, string postfix = null)
        {
            // Unary operators
            if (expression is UnaryExpression)
            {
                var unary = (UnaryExpression)expression;
                if (unary.NodeType == ExpressionType.Convert)
                {
                    // A conversion wrapped around a column, e.g. an enum property comparison
                    // m.Status == Status.Active compiles to Convert(m.Status) == 1. The operand references
                    // the lambda parameter and cannot be evaluated client-side, so translate the column itself.
                    bool operandIsColumn = false;
                    this.IsContainsParameterExpress(unary.Operand, ref operandIsColumn);
                    if (operandIsColumn)
                    {
                        return Recurse(ref i, unary.Operand, isUnary, prefix, postfix);
                    }

                    // Otherwise evaluate the converted value now, e.g. m.Birthday = DateTime.Now
                    var value = GetValue(expression);
                    if (value is string)
                    {
                        value = prefix + (string)value + postfix;
                    }
                    return WherePart.IsParameter(i++, value);
                }

                // e.g. !m.IsDeleted -> (NOT ([IsDeleted] = @1))
                return WherePart.Concat(NodeTypeToString(unary.NodeType), Recurse(ref i, unary.Operand, true));
            }

            // Binary operators
            if (expression is BinaryExpression)
            {
                var body = (BinaryExpression)expression;
                if (body.NodeType == ExpressionType.Equal || body.NodeType == ExpressionType.NotEqual)
                {
                    // "column = NULL" is never true under SQL null semantics; a comparison with a null
                    // literal has to be written IS NULL / IS NOT NULL.
                    var nullOperator = body.NodeType == ExpressionType.Equal ? "IS" : "IS NOT";
                    if (IsNullConstant(body.Right))
                    {
                        return WherePart.Concat(Recurse(ref i, body.Left), nullOperator, WherePart.IsSql("NULL"));
                    }
                    if (IsNullConstant(body.Left))
                    {
                        return WherePart.Concat(Recurse(ref i, body.Right), nullOperator, WherePart.IsSql("NULL"));
                    }
                }

                var left = Recurse(ref i, body.Left);
                var @operator = BinaryOperatorToString(body);
                var right = Recurse(ref i, body.Right);
                return WherePart.Concat(left, @operator, right);
            }

            // Constants, e.g. the right-hand side of m.ID = 123
            if (expression is ConstantExpression)
            {
                var constant = (ConstantExpression)expression;
                var value = constant.Value;
                if (value is int)
                {
                    // Integers are inlined directly; they cannot carry SQL injection.
                    return WherePart.IsSql(value.ToString());
                }
                if (value is string)
                {
                    value = prefix + (string)value + postfix;
                }
                if (value is bool && isUnary)
                {
                    // A bare boolean condition such as "m => true" becomes "@n = 1".
                    return WherePart.Concat(WherePart.IsParameter(i++, value), "=", WherePart.IsSql("1"));
                }
                return WherePart.IsParameter(i++, value);
            }

            // Member access: a column (property of the lambda parameter) or a captured value
            if (expression is MemberExpression)
            {
                var member = (MemberExpression)expression;
                bool isContainsParameterExpress = false;
                this.IsContainsParameterExpress(member, ref isContainsParameterExpress);

                if (!isContainsParameterExpress)
                {
                    // Captured local, static member, etc.: evaluate now and send as a parameter.
                    var value = GetValue(member);
                    if (value is string)
                    {
                        value = prefix + (string)value + postfix;
                    }
                    return WherePart.IsParameter(i++, value);
                }

                if (member.Member is PropertyInfo)
                {
                    var colName = member.Member.Name;
                    if (isUnary && member.Type == typeof(bool))
                    {
                        // A bare bool column used as a condition: [Flag] = @n (true)
                        return WherePart.Concat(Recurse(ref i, expression), "=", WherePart.IsParameter(i++, true));
                    }
                    return WherePart.IsSql(string.Format("{0}{1}{2}", this._columnBeginChar, colName, this._columnEndChar));
                }

                // A field (not a property) of the lambda parameter cannot be evaluated client-side.
                throw new NotSupportedException($"Only properties of the lambda parameter can be mapped to columns: {expression}");
            }

            // Method calls
            if (expression is MethodCallExpression)
            {
                var methodCall = (MethodCallExpression)expression;
                // Does the call involve the lambda parameter (directly or through a nested expression)?
                bool isContainsParameterExpress = false;
                this.IsContainsParameterExpress(methodCall, ref isContainsParameterExpress);

                if (!isContainsParameterExpress)
                {
                    // No column involved (e.g. DateTime.Now.AddDays(-1)): evaluate now and send as a parameter.
                    var value = GetValue(expression);
                    if (value is string)
                    {
                        value = prefix + (string)value + postfix;
                    }
                    return WherePart.IsParameter(i++, value);
                }

                // LIKE queries
                if (methodCall.Method == StringContainsMethod)
                {
                    return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], prefix: "%", postfix: "%"));
                }
                if (methodCall.Method == StringStartsWithMethod)
                {
                    return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], postfix: "%"));
                }
                if (methodCall.Method == StringEndsWithMethod)
                {
                    return WherePart.Concat(Recurse(ref i, methodCall.Object), "LIKE", Recurse(ref i, methodCall.Arguments[0], prefix: "%"));
                }

                // IN queries: ids.Contains(m.ID) (instance or Enumerable.Contains extension) -> [ID] IN (@1,@2,...)
                if (methodCall.Method.Name == "Contains")
                {
                    Expression collection;
                    Expression property;
                    if (methodCall.Method.IsDefined(typeof(ExtensionAttribute)) && methodCall.Arguments.Count == 2)
                    {
                        collection = methodCall.Arguments[0];
                        property = methodCall.Arguments[1];
                    }
                    else if (!methodCall.Method.IsDefined(typeof(ExtensionAttribute)) && methodCall.Arguments.Count == 1)
                    {
                        collection = methodCall.Object;
                        property = methodCall.Arguments[0];
                    }
                    else
                    {
                        throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
                    }
                    var values = (IEnumerable)GetValue(collection);
                    return WherePart.Concat(Recurse(ref i, property), "IN", WherePart.IsCollection(ref i, values));
                }

                throw new NotSupportedException("Unsupported method call: " + methodCall.Method.Name);
            }

            // Object construction, e.g. new DateTime(2018, 10, 31): evaluate and send as a parameter
            if (expression is NewExpression)
            {
                var member = (NewExpression)expression;
                var value = GetValue(member);
                if (value is string)
                {
                    value = prefix + (string)value + postfix;
                }
                return WherePart.IsParameter(i++, value);
            }

            throw new NotSupportedException("Unsupported expression: " + expression.GetType().Name);
        }

        /// <summary>
        /// True when <paramref name="expression"/> is a <c>null</c> constant, possibly wrapped in conversions
        /// (e.g. a null compared with a nullable property).
        /// </summary>
        private static bool IsNullConstant(Expression expression)
        {
            while (expression.NodeType == ExpressionType.Convert)
            {
                expression = ((UnaryExpression)expression).Operand;
            }
            return expression is ConstantExpression && ((ConstantExpression)expression).Value == null;
        }

        /// <summary>
        /// SQL text for a binary operator. On bool operands C#'s non-short-circuit &amp; and | are logical
        /// operators, which SQL spells AND / OR (SQL's &amp; and | are bitwise integer operators).
        /// </summary>
        private static string BinaryOperatorToString(BinaryExpression binary)
        {
            if (binary.Type == typeof(bool))
            {
                if (binary.NodeType == ExpressionType.And)
                {
                    return "AND";
                }
                if (binary.NodeType == ExpressionType.Or)
                {
                    return "OR";
                }
            }
            return NodeTypeToString(binary.NodeType);
        }

        /// <summary>
        /// Determines whether the expression references the lambda parameter (e.g. the "m" in m =&gt; m.ID == 1).
        /// Sets <paramref name="result"/> to true when it does; never resets it to false.
        /// </summary>
        /// <param name="expression">Expression to inspect; only member accesses and method calls are examined.</param>
        /// <param name="result">Set to true when a reference to the lambda parameter is found.</param>
        private void IsContainsParameterExpress(Expression expression, ref bool result)
        {
            if (this.expressParameterNameCollection != null && this.expressParameterNameCollection.Count > 0 && expression != null)
            {
                if (expression is MemberExpression)
                {
                    if (this.expressParameterNameCollection.Contains(((MemberExpression)expression).Expression))
                    {
                        result = true;
                    }
                }
                else if (expression is MethodCallExpression)
                {
                    MethodCallExpression methodCallExpression = (MethodCallExpression)expression;

                    if (methodCallExpression.Object != null)
                    {
                        if (methodCallExpression.Object is MethodCallExpression)
                        {
                            // Example 1: m.ID.ToString().Contains("123")
                            this.IsContainsParameterExpress(methodCallExpression.Object, ref result);
                        }
                        else if (methodCallExpression.Object is MemberExpression)
                        {
                            // Example 2: m.Name.Contains("123")
                            MemberExpression memberExpression = (MemberExpression)methodCallExpression.Object;
                            if (memberExpression.Expression != null && this.expressParameterNameCollection.Contains(memberExpression.Expression))
                            {
                                result = true;
                            }
                        }
                    }
                    // Example 3: int[] ids = new int[] { 1, 2, 3 }; ids.Contains(m.ID)
                    if (result == false && methodCallExpression.Arguments != null && methodCallExpression.Arguments.Count > 0)
                    {
                        foreach (Expression express in methodCallExpression.Arguments)
                        {
                            if (express is MemberExpression || express is MethodCallExpression)
                            {
                                this.IsContainsParameterExpress(express, ref result);
                            }
                            else if (this.expressParameterNameCollection.Contains(express))
                            {
                                result = true;
                                break;
                            }
                        }
                    }
                }
            }
        }

        /// <summary>Compiles and runs <paramref name="member"/> as a parameterless lambda and returns its value.</summary>
        private static object GetValue(Expression member)
        {
            // source: http://stackoverflow.com/a/2616980/291955
            var objectMember = Expression.Convert(member, typeof(object));
            var getterLambda = Expression.Lambda<Func<object>>(objectMember);
            var getter = getterLambda.Compile();
            return getter();
        }

        /// <summary>Maps an expression node type to its SQL operator text.</summary>
        /// <exception cref="NotSupportedException">The node type has no SQL equivalent here.</exception>
        private static string NodeTypeToString(ExpressionType nodeType)
        {
            switch (nodeType)
            {
                case ExpressionType.Add:
                    return "+";
                case ExpressionType.And:
                    return "&";
                case ExpressionType.AndAlso:
                    return "AND";
                case ExpressionType.Divide:
                    return "/";
                case ExpressionType.Equal:
                    return "=";
                case ExpressionType.ExclusiveOr:
                    return "^";
                case ExpressionType.GreaterThan:
                    return ">";
                case ExpressionType.GreaterThanOrEqual:
                    return ">=";
                case ExpressionType.LessThan:
                    return "<";
                case ExpressionType.LessThanOrEqual:
                    return "<=";
                case ExpressionType.Modulo:
                    return "%";
                case ExpressionType.Multiply:
                    return "*";
                case ExpressionType.Negate:
                    return "-";
                case ExpressionType.Not:
                    return "NOT";
                case ExpressionType.NotEqual:
                    return "<>";
                case ExpressionType.Or:
                    return "|";
                case ExpressionType.OrElse:
                    return "OR";
                case ExpressionType.Subtract:
                    return "-";
            }
            throw new NotSupportedException($"Unsupported node type: {nodeType}");
        }
    }
    
    /// <summary>
    /// A fragment of a WHERE clause: SQL text together with the values of the parameters it references.
    /// </summary>
    public class WherePart
    {
        /// <summary>SQL text; parameter placeholders are written as <c>@name</c>.</summary>
        public string Sql { get; set; }

        /// <summary>Parameter values keyed by placeholder name without the leading <c>@</c>.</summary>
        public Dictionary<string, object> Parameters { get; set; } = new Dictionary<string, object>();

        /// <summary>Wraps literal SQL that carries no parameters.</summary>
        public static WherePart IsSql(string sql)
        {
            var part = new WherePart();
            part.Parameters = new Dictionary<string, object>();
            part.Sql = sql;
            return part;
        }

        /// <summary>A single parameter placeholder <c>@count</c> bound to <paramref name="value"/>.</summary>
        public static WherePart IsParameter(int count, object value)
        {
            var name = count.ToString();
            var part = new WherePart();
            part.Parameters.Add(name, value);
            part.Sql = "@" + name;
            return part;
        }

        /// <summary>
        /// A parenthesized placeholder list for an IN clause, one parameter per item, numbered from
        /// <paramref name="countStart"/> (which is advanced past the last number used).
        /// An empty sequence yields "(null)".
        /// </summary>
        public static WherePart IsCollection(ref int countStart, IEnumerable values)
        {
            var collected = new Dictionary<string, object>();
            var placeholders = new List<string>();
            foreach (var item in values)
            {
                var key = countStart.ToString();
                collected.Add(key, item);
                placeholders.Add("@" + key);
                countStart++;
            }

            var inner = placeholders.Count == 0 ? "null" : string.Join(",", placeholders);
            return new WherePart
            {
                Parameters = collected,
                Sql = "(" + inner + ")"
            };
        }

        /// <summary>Prefix operator applied to <paramref name="operand"/>; the parameter dictionary is shared, not copied.</summary>
        public static WherePart Concat(string @operator, WherePart operand)
        {
            var sharedParameters = operand.Parameters;
            return new WherePart
            {
                Parameters = sharedParameters,
                Sql = "(" + @operator + " " + operand.Sql + ")"
            };
        }

        /// <summary>
        /// Infix combination "(left operator right)" whose parameters are the union of both sides.
        /// Identical key/value pairs are merged; the same key bound to two different values throws.
        /// </summary>
        public static WherePart Concat(WherePart left, string @operator, WherePart right)
        {
            var merged = Enumerable.Union(left.Parameters, right.Parameters)
                                   .ToDictionary(pair => pair.Key, pair => pair.Value);
            return new WherePart
            {
                Parameters = merged,
                Sql = string.Join(" ", "(" + left.Sql, @operator, right.Sql + ")")
            };
        }
    }
    View Code
  • 相关阅读:
    python语言程序设计部分习题
    Python基础:Python运行的两种基本方式
    python简介及详细安装方法
    MTBF平均故障间隔时间(转)
    SSH远程登录配置文件sshd_config详解
    SSH服务详解(转)
    GCC编译之后的代码信息
    移动设备识别ID
    STM32CubeMX自建MDK工程的基本步骤
    职位英文缩写
  • 原文地址:https://www.cnblogs.com/fanfan-90/p/12788422.html
Copyright © 2011-2022 走看看