首页 > 代码库 > EntityFramework批量更新

EntityFramework批量更新

上一篇的批量Insert使用了EntityFramework.BulkInsert,但它的免费版本只提供批量Insert功能,批量Update需要收费,于是就自己实现了一个。

实现的思路:

1、调用BulkInsert,将数据插入到一张#temp临时表中

2、在同一个连接下,执行update ... from #temp where ...的语句

思路看起来很简单,其实实现起来也不是很复杂。

主要代码

        /// <summary>
        /// Bulk-updates existing rows. The entities are bulk-copied into a session-local temp table
        /// (#TmpTable) created with the same columns as the target table. A single
        /// <c>UPDATE target SET ... FROM #TmpTable T WHERE key-columns match</c> is then executed,
        /// and the temp table is dropped.
        /// </summary>
        /// <remarks>
        /// Key columns are <c>Options.KeyColumnsForUpdate</c> when it is non-empty; otherwise they are the
        /// mapped primary-key columns. Updated columns are <c>Options.ColumnsForUpdate</c> when it is
        /// non-empty; otherwise they are all non-key columns. Identity columns are never placed in the
        /// SET clause, because SQL Server rejects updates to identity columns.
        /// </remarks>
        /// <param name="entities">Entities carrying the new column values.</param>
        /// <param name="transaction">
        /// Transaction whose connection is used for every step. #TmpTable is only visible to that session,
        /// so the temp-table creation, the bulk copy and the UPDATE must all share it.
        /// </param>
        /// <exception cref="InvalidOperationException">
        /// No column to update, or no key column to match on, could be determined.
        /// </exception>
        public override void Run<T>(IEnumerable<T> entities, SqlTransaction transaction)
        {
            // KeepIdentity is required. SELECT ... INTO copies the IDENTITY property to #TmpTable, so
            // without this flag SQL Server would generate new identity values instead of keeping the
            // entities' keys, and the UPDATE join would match the wrong rows. It is OR-ed in so that any
            // other flags the caller set are preserved.
            Options.SqlBulkCopyOptions |= SqlBulkCopyOptions.KeepIdentity;
            var keepIdentity = (SqlBulkCopyOptions.KeepIdentity & Options.SqlBulkCopyOptions) > 0;
            using (var reader = new MappedDataReader<T>(entities, this))
            {
                using (var sqlBulkCopy = new SqlBulkCopy(transaction.Connection, Options.SqlBulkCopyOptions, transaction))
                {
                    sqlBulkCopy.BulkCopyTimeout = Options.TimeOut;
                    sqlBulkCopy.BatchSize = Options.BatchSize;
                    var realTableName = string.Format("{0}.{1}", QuoteSqlIdentifier(reader.SchemaName), QuoteSqlIdentifier(reader.TableName));
                    sqlBulkCopy.DestinationTableName = "#TmpTable";

                    var hasExplicitKeys = Options.KeyColumnsForUpdate != null && Options.KeyColumnsForUpdate.Count > 0;
                    var hasExplicitUpdateColumns = Options.ColumnsForUpdate != null && Options.ColumnsForUpdate.Any();
                    var updatePartion = new List<string>();
                    var wherePartion = new List<string>();
                    foreach (var kvp in reader.Cols)
                    {
                        var col = kvp.Value;
                        // The same "target.col = T.col" text serves both as a SET assignment and as a WHERE join predicate.
                        var columnExpr = string.Format("{0}.{1} = T.{1}", realTableName, QuoteSqlIdentifier(col.ColumnName));
                        var isKey = hasExplicitKeys ? Options.KeyColumnsForUpdate.Contains(col.ColumnName) : col.IsPk;
                        if (isKey)
                        {
                            wherePartion.Add(columnExpr);
                        }
                        else if (!col.IsIdentity && (!hasExplicitUpdateColumns || Options.ColumnsForUpdate.Contains(col.ColumnName)))
                        {
                            // Identity columns are skipped: SQL Server does not allow them in a SET clause.
                            updatePartion.Add(columnExpr);
                        }
                    }
                    if (!updatePartion.Any() || !wherePartion.Any())
                    {
                        throw new InvalidOperationException("批量Update没有找到Update后面的语句部分,或者没有指定主键");
                    }
                    // Drop the temp table in the same batch, so a later call on this session can recreate it.
                    var updateText = string.Format("UPDATE {0} SET {1} FROM #TmpTable T WHERE {2}; DROP TABLE #TmpTable; ", realTableName, string.Join(",", updatePartion), string.Join(" and ", wherePartion));

                    // A #TmpTable left over on this session (e.g. from an earlier failed call) would make SELECT ... INTO fail.
                    using (var command = new SqlCommand("IF OBJECT_ID('tempdb..#TmpTable') IS NOT NULL DROP TABLE #TmpTable; SELECT TOP 0 * INTO #TmpTable FROM " + realTableName, transaction.Connection, transaction))
                    {
                        command.ExecuteNonQuery();
#if !NET40
                        sqlBulkCopy.EnableStreaming = Options.EnableStreaming;
#endif

                        sqlBulkCopy.NotifyAfter = Options.NotifyAfter;
                        if (Options.Callback != null)
                        {
                            sqlBulkCopy.SqlRowsCopied += Options.Callback;
                        }

                        foreach (var kvp in reader.Cols)
                        {
                            // Never taken while KeepIdentity is forced above; kept as a guard.
                            if (kvp.Value.IsIdentity && !keepIdentity)
                            {
                                continue;
                            }
                            sqlBulkCopy.ColumnMappings.Add(kvp.Value.ColumnName, kvp.Value.ColumnName);
                        }

                        sqlBulkCopy.WriteToServer(reader);

                        command.CommandTimeout = 300;
                        command.CommandText = updateText;
                        command.ExecuteNonQuery();
                    }
                }
            }
        }

        /// <summary>
        /// Returns <paramref name="name"/> as a bracket-quoted SQL Server identifier. Embedded ']'
        /// characters are doubled, so names containing spaces, reserved words or brackets stay valid.
        /// A null name is treated as empty.
        /// </summary>
        private static string QuoteSqlIdentifier(string name)
        {
            return "[" + (name ?? string.Empty).Replace("]", "]]") + "]";
        }

其中Options.KeyColumnsForUpdate是调用方通过公开的批量Update方法传入的、Update时用于匹配行的键列;不指定时使用表的主键列。

Options.ColumnsForUpdate是调用方传入的要更新的列;如果不传,就更新除主键外的全部列(这里忘记排除自增列了)。

 

调用举例:

cx.BulkUpdate(list, new BulkUpdateColumn<实体类型>().AddColumn((x) => x.字段属性1).AddColumn((x) => x.字段属性2));

第二个参数之所以用表达式来指定要更新哪些字段,是因为实在不想传字符串;好在一般需要更新的字段不会太多,如果太多就不指定,直接全部更新即可。

公开的static方法类:

    /// <summary>
    /// <see cref="DbContext"/> extension methods that bulk-update entities through the provider
    /// returned by <c>ProviderFactory.GetUpdate</c>.
    /// </summary>
    public static class BulkUpdateExtension
    {
        /// <summary>Bulk-updates <paramref name="entities"/> using fully specified <paramref name="options"/>.</summary>
        /// <param name="context">Context used to resolve the update provider.</param>
        /// <param name="entities">Entities carrying the new values.</param>
        /// <param name="updateColumn">Columns to update. Null or empty means "not specified" (provider default).</param>
        /// <param name="keyColumn">Columns used to match rows. Null or empty means "not specified" (provider default).</param>
        /// <param name="options">Bulk copy options. Its ColumnsForUpdate / KeyColumnsForUpdate are overwritten.</param>
        /// <exception cref="ArgumentNullException"><paramref name="options"/> is null.</exception>
        public static void BulkUpdate<T, TEntity>(this DbContext context, IEnumerable<T> entities, BulkUpdateColumn<TEntity> updateColumn, BulkUpdateColumn<TEntity> keyColumn, BulkInsertOptions options) where TEntity : class
        {
            RunUpdate(context, entities, updateColumn, keyColumn, options, null);
        }

        /// <summary>Bulk-updates the given columns, matching rows on the table's default key, with default options.</summary>
        public static void BulkUpdate<T, TEntity>(this DbContext context, IEnumerable<T> entities, BulkUpdateColumn<TEntity> updateColumn) where TEntity : class
        {
            context.BulkUpdate(entities, updateColumn, new BulkUpdateColumn<TEntity>(), SqlBulkCopyOptions.Default);
        }

        /// <summary>Bulk-updates with default <see cref="SqlBulkCopyOptions"/> and an optional batch size.</summary>
        public static void BulkUpdate<T, TEntity>(this DbContext context, IEnumerable<T> entities, BulkUpdateColumn<TEntity> updateColumn, BulkUpdateColumn<TEntity> keyColumn, int? batchSize = null) where TEntity : class
        {
            context.BulkUpdate(entities, updateColumn, keyColumn, SqlBulkCopyOptions.Default, batchSize);
        }

        /// <summary>Bulk-updates with explicit <see cref="SqlBulkCopyOptions"/> and an optional batch size.</summary>
        public static void BulkUpdate<T, TEntity>(this DbContext context, IEnumerable<T> entities, BulkUpdateColumn<TEntity> updateColumn, BulkUpdateColumn<TEntity> keyColumn, SqlBulkCopyOptions sqlBulkCopyOptions, int? batchSize = null) where TEntity : class
        {
            RunUpdate(context, entities, updateColumn, keyColumn, CreateOptions(sqlBulkCopyOptions, batchSize), null);
        }

        /// <summary>
        /// Bulk-updates inside the caller's <paramref name="transaction"/>. If it is null, the provider's
        /// non-transactional <c>Run(entities)</c> is used instead.
        /// </summary>
        public static void BulkUpdate<T, TEntity>(this DbContext context, IEnumerable<T> entities, BulkUpdateColumn<TEntity> updateColumn, BulkUpdateColumn<TEntity> keyColumn, IDbTransaction transaction, SqlBulkCopyOptions sqlBulkCopyOptions = SqlBulkCopyOptions.Default, int? batchSize = null) where TEntity : class
        {
            RunUpdate(context, entities, updateColumn, keyColumn, CreateOptions(sqlBulkCopyOptions, batchSize), transaction);
        }

        /// <summary>Builds options from the copy flags, applying <paramref name="batchSize"/> only when given.</summary>
        private static BulkInsertOptions CreateOptions(SqlBulkCopyOptions sqlBulkCopyOptions, int? batchSize)
        {
            var options = new BulkInsertOptions { SqlBulkCopyOptions = sqlBulkCopyOptions };
            if (batchSize.HasValue)
            {
                options.BatchSize = batchSize.Value;
            }
            return options;
        }

        /// <summary>
        /// Resolves the update provider, stores the column selections on <paramref name="options"/> and runs
        /// the update. When <paramref name="transaction"/> is non-null, the update runs on that transaction.
        /// </summary>
        private static void RunUpdate<T, TEntity>(DbContext context, IEnumerable<T> entities, BulkUpdateColumn<TEntity> updateColumn, BulkUpdateColumn<TEntity> keyColumn, BulkInsertOptions options, IDbTransaction transaction) where TEntity : class
        {
            if (options == null)
            {
                throw new ArgumentNullException("options");
            }
            var updateProvider = ProviderFactory.GetUpdate(context);
            updateProvider.Options = options;
            // A null selector means "not specified", the same as an empty one.
            options.ColumnsForUpdate = updateColumn == null ? null : updateColumn.ColumnName;
            options.KeyColumnsForUpdate = keyColumn == null ? null : keyColumn.ColumnName;
            if (transaction == null)
            {
                updateProvider.Run(entities);
            }
            else
            {
                // Previously the transaction argument was ignored, so the update ran outside the caller's transaction.
                updateProvider.Run(entities, transaction);
            }
        }
    }

 

EntityFramework批量更新