|
|
@ -1,8 +1,10 @@ |
|
|
|
using Microsoft.EntityFrameworkCore; |
|
|
|
using OfficeOpenXml.FormulaParsing.Excel.Functions.Math; |
|
|
|
using Serilog; |
|
|
|
using SkiaSharp; |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Data.Common; |
|
|
|
using System.Linq; |
|
|
|
using System.Linq.Expressions; |
|
|
|
using System.Reflection; |
|
|
@ -12,7 +14,7 @@ using TaskManager.Entity; |
|
|
|
using TaskManager.EntityFramework; |
|
|
|
using Wood.Util; |
|
|
|
using Wood.Util.Filters; |
|
|
|
using Z.BulkOperations; |
|
|
|
|
|
|
|
|
|
|
|
namespace TaskManager.EntityFramework.Repository |
|
|
|
{ |
|
|
@ -69,17 +71,6 @@ namespace TaskManager.EntityFramework.Repository |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
public async Task BlukMergeAsync(List<TEntity> entities,Action<BulkOperation<TEntity>> action) |
|
|
|
{ |
|
|
|
_context.BulkMerge(entities, action); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public async Task BlukInsertAsync(List<TEntity> entities, Action<BulkOperation<TEntity>> action) |
|
|
|
{ |
|
|
|
_context.BulkInsert(entities, action); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -289,4 +280,330 @@ namespace TaskManager.EntityFramework.Repository |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public static class EfCoreBulkExtensions |
|
|
|
{ |
|
|
|
//private const int BatchSize = 1000;
|
|
|
|
|
|
|
|
|
|
|
|
private static int CalculateBatchSize(int propertyCount) |
|
|
|
{ |
|
|
|
// 预留100个参数作为缓冲
|
|
|
|
const int maxParameters = 2000; |
|
|
|
return Math.Max(1, maxParameters / propertyCount); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// EF Core 批量插入或更新扩展方法(支持复合主键)
|
|
|
|
/// </summary>
|
|
|
|
public static void BulkMerge<TEntity>(this DbContext context, |
|
|
|
IEnumerable<TEntity> entities, |
|
|
|
params Expression<Func<TEntity, object>>[] keySelectors) |
|
|
|
where TEntity : class |
|
|
|
{ |
|
|
|
if (context == null || entities == null || !entities.Any()) |
|
|
|
return; |
|
|
|
using var transaction = context.Database.BeginTransaction(); |
|
|
|
try |
|
|
|
{ |
|
|
|
Type entityType = typeof(TEntity); |
|
|
|
string tableName = context.Model.FindEntityType(entityType).GetTableName(); |
|
|
|
string providerName = GetDatabaseProviderName(context); |
|
|
|
|
|
|
|
// 获取所有主键属性
|
|
|
|
List<string> keyPropNames = keySelectors.Select(GetPropertyName).ToList(); |
|
|
|
List<PropertyInfo> keyProps = keyPropNames.Select(entityType.GetProperty).ToList(); |
|
|
|
|
|
|
|
// 获取所有非主键属性
|
|
|
|
PropertyInfo[] props = entityType.GetProperties() |
|
|
|
.Where(p => !keyPropNames.Contains(p.Name)).ToArray(); |
|
|
|
|
|
|
|
// 计算实体的总属性数量
|
|
|
|
int totalProperties = keyProps.Count + props.Length; |
|
|
|
// 动态计算批处理大小
|
|
|
|
int batchSize = CalculateBatchSize(totalProperties); |
|
|
|
|
|
|
|
// 分批处理
|
|
|
|
var batches = entities.Chunk(batchSize); |
|
|
|
|
|
|
|
foreach (var batch in batches) |
|
|
|
{ |
|
|
|
// 临时表名
|
|
|
|
string tempTableName = $"#Temp_Bulk_{Guid.NewGuid().ToString("N").Substring(0, 8)}"; |
|
|
|
|
|
|
|
try |
|
|
|
{ |
|
|
|
// 1. 创建临时表
|
|
|
|
|
|
|
|
|
|
|
|
string createSql = $"SELECT * INTO {tempTableName} FROM {tableName} WHERE 1 = 0"; |
|
|
|
context.Database.ExecuteSqlRaw(createSql); |
|
|
|
|
|
|
|
// 2. 插入数据到临时表
|
|
|
|
InsertToTempTable(context, batch, tempTableName, keyProps, props); |
|
|
|
|
|
|
|
// 3. 执行MERGE操作
|
|
|
|
ExecuteMerge(context, tableName, tempTableName, keyPropNames, props); |
|
|
|
transaction.Commit(); |
|
|
|
|
|
|
|
} |
|
|
|
finally |
|
|
|
{ |
|
|
|
// 4. 清理临时表
|
|
|
|
DropTempTable(context, tempTableName); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
catch (Exception ex) |
|
|
|
{ |
|
|
|
transaction.Rollback(); |
|
|
|
Console.WriteLine($"批量操作失败: {ex.Message}"); |
|
|
|
throw; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
|
|
|
|
/// EF Core 批量插入或更新扩展方法(支持复合主键)
|
|
|
|
/// </summary>
|
|
|
|
//public static void BulkMerge<TEntity>(this DbContext context,
|
|
|
|
// IEnumerable<TEntity> entities,
|
|
|
|
// params Expression<Func<TEntity, object>>[] keySelectors)
|
|
|
|
// where TEntity : class
|
|
|
|
//{
|
|
|
|
// if (context == null || entities == null || !entities.Any())
|
|
|
|
// return;
|
|
|
|
|
|
|
|
|
|
|
|
// try
|
|
|
|
// {
|
|
|
|
// Type entityType = typeof(TEntity);
|
|
|
|
// string tableName = context.Model.FindEntityType(entityType).GetTableName();
|
|
|
|
// string providerName = GetDatabaseProviderName(context);
|
|
|
|
|
|
|
|
// // 获取所有主键属性
|
|
|
|
// List<string> keyPropNames = keySelectors.Select(GetPropertyName).ToList();
|
|
|
|
// List<PropertyInfo> keyProps = keyPropNames.Select(entityType.GetProperty).ToList();
|
|
|
|
|
|
|
|
// // 获取所有非主键属性
|
|
|
|
// PropertyInfo[] props = entityType.GetProperties()
|
|
|
|
// .Where(p => !keyPropNames.Contains(p.Name)).ToArray();
|
|
|
|
|
|
|
|
// // 分批处理
|
|
|
|
// var batches = entities.Chunk(BatchSize);
|
|
|
|
|
|
|
|
// foreach (var batch in batches)
|
|
|
|
// {
|
|
|
|
// // 临时表名
|
|
|
|
// string tempTableName = $"#Temp_Bulk_{Guid.NewGuid().ToString("N").Substring(0, 8)}";
|
|
|
|
|
|
|
|
// try
|
|
|
|
// {
|
|
|
|
// // 1. 创建临时表
|
|
|
|
|
|
|
|
// string createSql = $"SELECT * INTO {tempTableName} FROM {tableName} WHERE 1 = 0";
|
|
|
|
|
|
|
|
// context.Database.ExecuteSqlRaw(createSql);
|
|
|
|
|
|
|
|
// //CreateTempTable(context, tempTableName, keyProps, props, providerName);
|
|
|
|
|
|
|
|
// // 2. 插入数据到临时表
|
|
|
|
// InsertToTempTable(context, batch, tempTableName, keyProps, props);
|
|
|
|
|
|
|
|
// // 3. 执行MERGE操作
|
|
|
|
// ExecuteMerge(context, tableName, tempTableName, keyPropNames, props);
|
|
|
|
// }
|
|
|
|
// finally
|
|
|
|
// {
|
|
|
|
// // 4. 清理临时表
|
|
|
|
// DropTempTable(context, tempTableName);
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
|
|
|
|
|
|
|
|
// }
|
|
|
|
// catch (Exception ex)
|
|
|
|
// {
|
|
|
|
|
|
|
|
// Console.WriteLine($"批量操作失败: {ex.Message}");
|
|
|
|
// throw;
|
|
|
|
// }
|
|
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
// 解析表达式获取属性名
|
|
|
|
private static string GetPropertyName<TEntity>(Expression<Func<TEntity, object>> expression) |
|
|
|
{ |
|
|
|
if (expression.Body is UnaryExpression unary) |
|
|
|
expression = (Expression<Func<TEntity, object>>)unary.Operand; |
|
|
|
|
|
|
|
if (expression.Body is MemberExpression member) |
|
|
|
return member.Member.Name; |
|
|
|
|
|
|
|
throw new ArgumentException("表达式必须指向实体属性", nameof(expression)); |
|
|
|
} |
|
|
|
|
|
|
|
// 创建临时表
|
|
|
|
private static void CreateTempTable(DbContext context, string tempTableName, |
|
|
|
List<PropertyInfo> keyProps, PropertyInfo[] props, string providerName) |
|
|
|
{ |
|
|
|
string keyConstraints = string.Join(" AND ", keyProps.Select(p => $"[{p.Name}]")); |
|
|
|
string sql = $"CREATE TABLE {tempTableName} (" + |
|
|
|
string.Join(", ", keyProps.Select(p => $"[{p.Name}] {GetSqlType(p.PropertyType, providerName)}")) + |
|
|
|
", " + |
|
|
|
string.Join(", ", props.Select(p => $"[{p.Name}] {GetSqlType(p.PropertyType, providerName)}")) + |
|
|
|
", PRIMARY KEY (" + keyConstraints + ")" + |
|
|
|
")"; |
|
|
|
context.Database.ExecuteSqlRaw(sql); |
|
|
|
} |
|
|
|
|
|
|
|
// 插入数据到临时表
|
|
|
|
private static void InsertToTempTable<TEntity>(DbContext context, IEnumerable<TEntity> entities, |
|
|
|
string tempTableName, List<PropertyInfo> keyProps, PropertyInfo[] props) |
|
|
|
{ |
|
|
|
// 构建参数化SQL
|
|
|
|
string paramPrefix = "@p"; |
|
|
|
int paramIndex = 0; |
|
|
|
string values = string.Join(",\n", entities.Select(e => |
|
|
|
{ |
|
|
|
string row = "("; |
|
|
|
// 添加主键参数
|
|
|
|
foreach (var keyProp in keyProps) |
|
|
|
row += $"{paramPrefix}{paramIndex++}, "; |
|
|
|
// 添加其他属性参数
|
|
|
|
foreach (var prop in props) |
|
|
|
row += $"{paramPrefix}{paramIndex++}, "; |
|
|
|
return row.TrimEnd(", ".ToCharArray()) + ")"; |
|
|
|
})); |
|
|
|
|
|
|
|
string sql = $"INSERT INTO {tempTableName} ({GetColumnList(keyProps, props)}) VALUES {values}"; |
|
|
|
DbCommand cmd = context.Database.GetDbConnection().CreateCommand(); |
|
|
|
cmd.CommandText = sql; |
|
|
|
|
|
|
|
// 添加参数
|
|
|
|
paramIndex = 0; |
|
|
|
foreach (var entity in entities) |
|
|
|
{ |
|
|
|
// 添加主键值
|
|
|
|
foreach (var keyProp in keyProps) |
|
|
|
AddParameter(cmd, $"{paramPrefix}{paramIndex++}", keyProp.GetValue(entity)); |
|
|
|
|
|
|
|
// 添加其他属性值
|
|
|
|
foreach (var prop in props) |
|
|
|
AddParameter(cmd, $"{paramPrefix}{paramIndex++}", prop.GetValue(entity)); |
|
|
|
} |
|
|
|
|
|
|
|
context.Database.OpenConnection(); |
|
|
|
try { cmd.ExecuteNonQuery(); } |
|
|
|
finally { context.Database.CloseConnection(); } |
|
|
|
} |
|
|
|
|
|
|
|
// 添加参数(防止SQL注入)
|
|
|
|
private static void AddParameter(DbCommand cmd, string paramName, object value) |
|
|
|
{ |
|
|
|
DbParameter param = cmd.CreateParameter(); |
|
|
|
param.ParameterName = paramName; |
|
|
|
param.Value = value ?? DBNull.Value; |
|
|
|
cmd.Parameters.Add(param); |
|
|
|
} |
|
|
|
|
|
|
|
// 构建列名列表
|
|
|
|
private static string GetColumnList(List<PropertyInfo> keyProps, PropertyInfo[] props) |
|
|
|
{ |
|
|
|
return string.Join(", ", keyProps.Select(p => $"[{p.Name}]")) + ", " + |
|
|
|
string.Join(", ", props.Select(p => $"[{p.Name}]")); |
|
|
|
} |
|
|
|
|
|
|
|
// 获取SQL类型
|
|
|
|
private static string GetSqlType(Type type, string providerName) |
|
|
|
{ |
|
|
|
type = Nullable.GetUnderlyingType(type) ?? type; |
|
|
|
|
|
|
|
if (providerName.Contains("SqlServer")) |
|
|
|
{ |
|
|
|
if (type == typeof(int)) return "INT"; |
|
|
|
if (type == typeof(long)) return "BIGINT"; |
|
|
|
if (type == typeof(short)) return "SMALLINT"; |
|
|
|
if (type == typeof(byte)) return "TINYINT"; |
|
|
|
if (type == typeof(bool)) return "BIT"; |
|
|
|
if (type == typeof(string)) return "NVARCHAR(MAX)"; |
|
|
|
if (type == typeof(DateTime)) return "DATETIME2"; |
|
|
|
if (type == typeof(DateTimeOffset)) return "DATETIMEOFFSET"; |
|
|
|
if (type == typeof(decimal)) return "DECIMAL(18, 2)"; |
|
|
|
if (type == typeof(float)) return "FLOAT"; |
|
|
|
if (type == typeof(double)) return "FLOAT"; |
|
|
|
if (type == typeof(Guid)) return "UNIQUEIDENTIFIER"; |
|
|
|
if (type == typeof(byte[])) return "VARBINARY(MAX)"; |
|
|
|
} |
|
|
|
else if (providerName.Contains("PostgreSQL")) |
|
|
|
{ |
|
|
|
if (type == typeof(int)) return "INTEGER"; |
|
|
|
if (type == typeof(long)) return "BIGINT"; |
|
|
|
if (type == typeof(short)) return "SMALLINT"; |
|
|
|
if (type == typeof(bool)) return "BOOLEAN"; |
|
|
|
if (type == typeof(string)) return "TEXT"; |
|
|
|
if (type == typeof(DateTime)) return "TIMESTAMP"; |
|
|
|
if (type == typeof(DateTimeOffset)) return "TIMESTAMPTZ"; |
|
|
|
if (type == typeof(decimal)) return "DECIMAL(18, 2)"; |
|
|
|
if (type == typeof(float)) return "REAL"; |
|
|
|
if (type == typeof(double)) return "DOUBLE PRECISION"; |
|
|
|
if (type == typeof(Guid)) return "UUID"; |
|
|
|
if (type == typeof(byte[])) return "BYTEA"; |
|
|
|
} |
|
|
|
|
|
|
|
throw new NotSupportedException($"类型 {type.Name} 或数据库提供者 {providerName} 不支持"); |
|
|
|
} |
|
|
|
|
|
|
|
// 执行MERGE操作
|
|
|
|
private static void ExecuteMerge(DbContext context, string tableName, string tempTableName, |
|
|
|
List<string> keyPropNames, PropertyInfo[] props) |
|
|
|
{ |
|
|
|
string onClause = string.Join(" AND ", keyPropNames.Select(k => |
|
|
|
$"target.[{k}] = source.[{k}]")); |
|
|
|
|
|
|
|
string updateFields = string.Join(", ", props.Select(p => |
|
|
|
$"target.[{p.Name}] = source.[{p.Name}]")); |
|
|
|
|
|
|
|
string sql = $"MERGE {tableName} AS target " + |
|
|
|
$"USING {tempTableName} AS source " + |
|
|
|
$"ON ({onClause}) " + |
|
|
|
$"WHEN MATCHED THEN UPDATE SET {updateFields} " + |
|
|
|
$"WHEN NOT MATCHED THEN " + |
|
|
|
$"INSERT ({GetColumnList(keyPropNames, props)}) " + |
|
|
|
$"VALUES ({string.Join(", ", keyPropNames.Concat(props.Select(p => $"source.[{p.Name}]")))});"; |
|
|
|
|
|
|
|
context.Database.ExecuteSqlRaw(sql); |
|
|
|
} |
|
|
|
|
|
|
|
// 删除临时表
|
|
|
|
private static void DropTempTable(DbContext context, string tempTableName) |
|
|
|
{ |
|
|
|
context.Database.ExecuteSqlRaw($"DROP TABLE IF EXISTS {tempTableName}"); |
|
|
|
} |
|
|
|
|
|
|
|
// 获取数据库提供者名称
|
|
|
|
private static string GetDatabaseProviderName(DbContext context) |
|
|
|
{ |
|
|
|
return context.Database.ProviderName; |
|
|
|
} |
|
|
|
|
|
|
|
// 构建列名列表(重载版本)
|
|
|
|
private static string GetColumnList(List<string> keyPropNames, PropertyInfo[] props) |
|
|
|
{ |
|
|
|
return string.Join(", ", keyPropNames.Select(k => $"[{k}]")) + ", " + |
|
|
|
string.Join(", ", props.Select(p => $"[{p.Name}]")); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|