
    Extending EF Core with an Update method (implementing UPDATE User SET Id = Id + 1)

    Preface

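    1. When updating with EF Core, you usually have to query the data first and then change the relevant fields. For batch updates this is cumbersome and inefficient.
    2. The EFCore.Extentions project on GitHub makes batch updates fairly convenient, but it still cannot express a self-referencing update such as UPDATE User SET Id = Id + 1.
    3. This article extends EF Core with exactly this kind of self-referencing update.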

    How it works

    1. First, get the SQL statement that EF Core generates for the IQueryable
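            // EF Core 2.1/2.2 only: reads private fields of the Relinq-based query pipeline via reflection.
            // These fields and types no longer exist in EF Core 3.0 and later.
            using System.Linq;
            using System.Reflection;
            using Microsoft.EntityFrameworkCore.Query;            // RelationalQueryModelVisitor
            using Microsoft.EntityFrameworkCore.Query.Internal;   // QueryCompiler, EntityQueryProvider, QueryModelGenerator
            using Microsoft.EntityFrameworkCore.Storage;          // Database, IDatabase, DatabaseDependencies

            private static readonly TypeInfo QueryCompilerTypeInfo = typeof(QueryCompiler).GetTypeInfo();
            private static readonly FieldInfo QueryCompilerField = typeof(EntityQueryProvider).GetTypeInfo().DeclaredFields.First(x => x.Name == "_queryCompiler");
            private static readonly FieldInfo QueryModelGeneratorField = QueryCompilerTypeInfo.DeclaredFields.First(x => x.Name == "_queryModelGenerator");
            private static readonly FieldInfo DataBaseField = QueryCompilerTypeInfo.DeclaredFields.Single(x => x.Name == "_database");
            private static readonly PropertyInfo DatabaseDependenciesField = typeof(Database).GetTypeInfo().DeclaredProperties.Single(x => x.Name == "Dependencies");

            /// <summary>
            /// Converts the query into the SQL statement EF Core would generate for it.
            /// </summary>
            /// <typeparam name="TEntity">Entity type of the query.</typeparam>
            /// <param name="query">The LINQ query to translate.</param>
            /// <returns>The generated SQL text.</returns>
            internal static string ToSql<TEntity>(this IQueryable<TEntity> query) where TEntity : class
            {
                // Provider -> QueryCompiler -> QueryModelGenerator: parse the LINQ expression into a QueryModel.
                var queryCompiler = (QueryCompiler)QueryCompilerField.GetValue(query.Provider);
                var modelGenerator = (QueryModelGenerator)QueryModelGeneratorField.GetValue(queryCompiler);
                var queryModel = modelGenerator.ParseQuery(query.Expression);

                // QueryCompiler -> IDatabase -> DatabaseDependencies: create a relational query model visitor.
                var database = (IDatabase)DataBaseField.GetValue(queryCompiler);
                var databaseDependencies = (DatabaseDependencies)DatabaseDependenciesField.GetValue(database);
                var queryCompilationContext = databaseDependencies.QueryCompilationContextFactory.Create(false);
                var modelVisitor = (RelationalQueryModelVisitor)queryCompilationContext.CreateQueryModelVisitor();

                // Compiling the query model fills modelVisitor.Queries; the first SELECT expression prints as SQL.
                modelVisitor.CreateQueryExecutor<TEntity>(queryModel);
                string sql = modelVisitor.Queries.First().ToString();
                return sql;
            }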
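    2. Cut off everything in that query before FROM, keep the remainder and the table alias, and later splice them into the UPDATE statement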
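            // Returns (the SQL from FROM onward, the table alias).
            // Relies on EF Core 2.x formatting: the SQL starts with "SELECT [alias]." and FROM begins on the second line.
            public static (string, string) GetBatchSql<T>(IQueryable<T> query) where T : class, new()
            {
                string sqlQuery = query.ToSql();

                // "SELECT [" is 8 characters long, so the alias runs from index 8 up to the first "]".
                string tableAlias = sqlQuery.Substring(8, sqlQuery.IndexOf("]") - 8);

                // The first line is the SELECT list; keep everything from the first line break on (FROM, WHERE, ...).
                int indexFROM = sqlQuery.IndexOf(Environment.NewLine);
                string sql = sqlQuery.Substring(indexFROM);

                // Curly brackets have to be escaped for EF Core's raw-SQL execution:
                // https://github.com/aspnet/EntityFrameworkCore/issues/8820
                sql = sql.Replace("{", "{{").Replace("}", "}}");
                return (sql, tableAlias);
            }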
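The slicing done by GetBatchSql can be tried without a database on a hand-written SQL string shaped like EF Core 2.x output. This is a minimal sketch, not the author's code: `BatchSqlSketch` and `Split` are hypothetical names, and the line break is a parameter (defaulting to `Environment.NewLine`, as in the original) so the behaviour can be checked with `"\n"`.

```csharp
using System;

// Hypothetical helper reproducing the string slicing of GetBatchSql on plain SQL text.
public static class BatchSqlSketch
{
    // Returns (text from FROM onward, table alias), in the same order as GetBatchSql.
    // Assumes EF Core 2.x formatting: "SELECT [alias].[Col], ..." on the first line, FROM on the next.
    public static (string FromPart, string Alias) Split(string selectSql, string newLine = null)
    {
        newLine = newLine ?? Environment.NewLine;
        const string prefix = "SELECT [";   // 8 characters, hence Substring(8, ...) in GetBatchSql
        if (!selectSql.StartsWith(prefix, StringComparison.Ordinal))
            throw new ArgumentException("Expected SQL starting with \"SELECT [\".", nameof(selectSql));

        int aliasEnd = selectSql.IndexOf(']');
        string alias = selectSql.Substring(prefix.Length, aliasEnd - prefix.Length);

        int fromIndex = selectSql.IndexOf(newLine, StringComparison.Ordinal);
        if (fromIndex < 0)
            throw new ArgumentException("Expected FROM on its own line.", nameof(selectSql));

        // Keep the leading line break, then double curly brackets as GetBatchSql does.
        string fromPart = selectSql.Substring(fromIndex).Replace("{", "{{").Replace("}", "}}");
        return (fromPart, alias);
    }
}
```

For `"SELECT [x].[Id], [x].[Name]\nFROM [User] AS [x]\nWHERE [x].[Id] = 2"` the alias is `x` and the remainder starts with the line break before FROM. Like the original, the sketch breaks if EF Core ever changes how it lays out the SELECT.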
    3. From the supplied expression (Expression<Func<T, bool>>), generate the part that follows UPDATE [a] SET. For example, Expression<Func<T, bool>> expression = a => a.Id == a.Id + 1 generates a.[Id] = a.[Id] + @param_0, with param_0 = 1.

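    The NodeType of each expression node determines which SQL operator is emitted (== becomes an assignment, and the logical operators become the commas separating assignments); the SQL text and its parameters are built up recursively.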

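            using System;
            using System.Collections.Generic;
            using System.Data.SqlClient;          // SqlParameter (SQL Server provider in EF Core 2.x)
            using System.Linq.Expressions;
            using System.Text;

            // Recursively translates an assignment expression such as
            //   x => x.Name == x.Name + "a" && x.RoleId == x.RoleId + 1
            // into the text after "UPDATE [x] SET", collecting values as SQL parameters.
            public static void CreateUpdateBody(string Param, Expression expression, ref StringBuilder sb, ref List<SqlParameter> sp)
            {
                if (expression is BinaryExpression binaryExpression)
                {
                    // Left operand, then the SQL operator, then the right operand.
                    CreateUpdateBody(Param, binaryExpression.Left, ref sb, ref sp);

                    switch (binaryExpression.NodeType)
                    {
                        case ExpressionType.Add:
                            sb.Append(" +");   // also covers string concatenation
                            break;
                        case ExpressionType.Divide:
                            sb.Append(" /");
                            break;
                        case ExpressionType.Multiply:
                            sb.Append(" *");
                            break;
                        case ExpressionType.Subtract:
                            sb.Append(" -");
                            break;
                        // Logical connectors separate the individual assignments in the SET list.
                        case ExpressionType.And:
                        case ExpressionType.AndAlso:
                        case ExpressionType.Or:
                        case ExpressionType.OrElse:
                            sb.Append(" ,");
                            break;
                        // "==" is read as an assignment.
                        case ExpressionType.Equal:
                            sb.Append(" =");
                            break;
                        default:
                            // Silently skipping an operator would produce invalid SQL.
                            throw new NotSupportedException($"Operator {binaryExpression.NodeType} is not supported.");
                    }

                    CreateUpdateBody(Param, binaryExpression.Right, ref sb, ref sp);
                }
                else if (expression is UnaryExpression unaryExpression && unaryExpression.NodeType == ExpressionType.Convert)
                {
                    // Arithmetic on nullable columns wraps operands in Convert(...); translate the operand itself.
                    CreateUpdateBody(Param, unaryExpression.Operand, ref sb, ref sp);
                }
                else if (expression is ConstantExpression constantExpression)
                {
                    AddParameter(constantExpression.Value, ref sb, ref sp);
                }
                else if (expression is MemberExpression memberExpression)
                {
                    if (memberExpression.Expression is ParameterExpression)
                    {
                        // x.Column -> alias.[Column]
                        sb.Append($"{Param}.[{memberExpression.Member.Name}]");
                    }
                    else
                    {
                        // A captured local variable (closure field) is a value, not a column: evaluate it.
                        var value = Expression.Lambda(memberExpression).Compile().DynamicInvoke();
                        AddParameter(value, ref sb, ref sp);
                    }
                }
            }

            private static void AddParameter(object value, ref StringBuilder sb, ref List<SqlParameter> sp)
            {
                var parmName = $"param_{sp.Count}";
                // SqlParameter needs DBNull.Value (not null) to send SQL NULL.
                sp.Add(new SqlParameter(parmName, value ?? DBNull.Value));
                sb.Append($" @{parmName}");
            }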
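To try the recursive translation in isolation, here is a self-contained sketch of the same technique. It is not the author's code: `SetClauseSketch`, `Build`, and the `User` class are hypothetical, parameters are collected as name/value pairs instead of `SqlParameter` so no database provider is needed, only `&&`/`&` are accepted as assignment separators, and a captured local variable is evaluated to a parameter value rather than treated as a column.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Text;

// Hypothetical entity used only for this sketch (mirrors the User table in the article).
public class User
{
    public int Id { get; set; }
    public string Name { get; set; }
    public int RoleId { get; set; }
}

// Self-contained sketch of the CreateUpdateBody technique.
public static class SetClauseSketch
{
    public static string Build<T>(Expression<Func<T, bool>> assignments, string alias,
                                  List<KeyValuePair<string, object>> parameters)
    {
        var sb = new StringBuilder();
        Visit(assignments.Body, alias, sb, parameters);
        return sb.ToString();
    }

    private static void Visit(Expression node, string alias, StringBuilder sb,
                              List<KeyValuePair<string, object>> parameters)
    {
        switch (node)
        {
            case BinaryExpression b:
                // Left operand, SQL operator, right operand (no parentheses are added).
                Visit(b.Left, alias, sb, parameters);
                sb.Append(OperatorFor(b.NodeType));
                Visit(b.Right, alias, sb, parameters);
                break;
            case UnaryExpression u when u.NodeType == ExpressionType.Convert:
                // Lifted arithmetic on nullable columns wraps operands in Convert(...).
                Visit(u.Operand, alias, sb, parameters);
                break;
            case MemberExpression column when column.Expression is ParameterExpression:
                // x.Column is a column of the row being updated.
                sb.Append($"{alias}.[{column.Member.Name}]");
                break;
            case ConstantExpression c:
                AddParameter(c.Value, sb, parameters);
                break;
            case MemberExpression captured:
                // Any other member access (e.g. a captured local) is evaluated to a value.
                AddParameter(Expression.Lambda(captured).Compile().DynamicInvoke(), sb, parameters);
                break;
            default:
                throw new NotSupportedException($"Unsupported expression node: {node.NodeType}");
        }
    }

    private static string OperatorFor(ExpressionType type)
    {
        switch (type)
        {
            case ExpressionType.Add: return " +";
            case ExpressionType.Subtract: return " -";
            case ExpressionType.Multiply: return " *";
            case ExpressionType.Divide: return " /";
            case ExpressionType.Equal: return " =";   // "==" means "assign"
            case ExpressionType.AndAlso:
            case ExpressionType.And: return " ,";     // "&&" separates assignments
            default: throw new NotSupportedException($"Unsupported operator: {type}");
        }
    }

    private static void AddParameter(object value, StringBuilder sb,
                                     List<KeyValuePair<string, object>> parameters)
    {
        string name = $"param_{parameters.Count}";
        parameters.Add(new KeyValuePair<string, object>(name, value));
        sb.Append($" @{name}");
    }
}
```

With `x => x.Name == (x.Name + " Add Bob") && x.RoleId == (x.RoleId + 1)` and alias `x`, this yields `x.[Name] =x.[Name] + @param_0 ,x.[RoleId] =x.[RoleId] + @param_1`, collecting `" Add Bob"` and `1` as `param_0` and `param_1`. Because operands are emitted in order without parentheses, an expression such as `x.A == (x.A + 1) * 2` would lose its grouping in SQL; a fuller translator would wrap nested binary nodes in parentheses.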
    4. Finally, execute the generated SQL statement; see the source code on GitHub for details.

    5. Usage

            static void Main(string[] args)
            {
                using (var context = new TestContext())
                {
                    var list = context.User.Select<UserModel>().ToList();

                    var user1 = context.User.AsNoTracking().FirstOrDefault(x => x.Id == 2);

                    Console.WriteLine($"-----------Before Update --------------------");
                    Console.WriteLine($"{user1.Id}:{user1.Name}:{user1.RoleId}");

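                    // One UPDATE for the row with Id == 2, without loading it: Name = Name + " Add Bob", RoleId = RoleId + 1
                    context.User.Where(x => x.Id == 2).RestValue(x => x.Name == (x.Name + " Add Bob") && x.RoleId == (x.RoleId + 1));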

                    var user2 = context.User.AsNoTracking().FirstOrDefault(x => x.Id == 2);

                    Console.WriteLine($"-----------After Update --------------------");
                    Console.WriteLine($"{user2.Id}:{user2.Name}:{user2.RoleId}");
                }
                Console.WriteLine("------------ End --------------------");
                Console.ReadLine();
            }
    6. The generated SQL is as follows:
         UPDATE [x] SET x.[Name] =x.[Name] + @param_0 ,x.[RoleId] =x.[RoleId] + @param_1
         FROM [User] AS [x]
         WHERE [x].[Id] = 2
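The statement above is the three pieces glued together: `UPDATE [alias] SET`, the SET list, and the FROM/WHERE remainder of the original SELECT. SQL Server's UPDATE ... FROM syntax lets the alias named after UPDATE refer to the table aliased in the FROM clause, which is why reusing EF Core's FROM/WHERE text works. This is a minimal sketch; `UpdateSqlSketch.Compose` is a hypothetical name, and how RestValue actually executes the result is in the author's GitHub source (on EF Core 2.x a raw-SQL call such as `Database.ExecuteSqlCommand(sql, parameters)` is one plausible choice, consistent with the curly-bracket escaping).

```csharp
using System;

// Hypothetical helper: joins the alias, the SET list and the FROM/WHERE remainder
// into one SQL Server UPDATE ... SET ... FROM ... statement.
public static class UpdateSqlSketch
{
    public static string Compose(string alias, string setClause, string fromPart)
    {
        if (string.IsNullOrEmpty(alias))
            throw new ArgumentException("Alias is required.", nameof(alias));
        if (string.IsNullOrWhiteSpace(setClause))
            throw new ArgumentException("SET clause is empty.", nameof(setClause));

        // fromPart already starts with the line break that preceded FROM in the SELECT.
        return $"UPDATE [{alias}] SET {setClause}{fromPart}";
    }
}
```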
    
    Original article: https://www.cnblogs.com/castyuan/p/10194325.html