using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
using OfficeOpenXml.FormulaParsing.Excel.Functions.Math;
using Serilog;
using SkiaSharp;
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using TaskManager.Entity;
using TaskManager.EntityFramework;
using Wood.Util;
using Wood.Util.Filters;

namespace TaskManager.EntityFramework.Repository
{
    /// <summary>
    /// Generic EF Core repository for entities deriving from <see cref="BaseEntity"/>.
    /// Every mutating method calls <c>SaveChangesAsync</c> immediately.
    /// </summary>
    /// <typeparam name="TEntity">Entity type managed by this repository.</typeparam>
    public class Repository<TEntity> : IRepository<TEntity> where TEntity : BaseEntity
    {
        // Not readonly: SetDbContext can rebind the repository to a different context.
        private JobDbContext _context;
        private DbSet<TEntity> _dbSet;

        public Repository(JobDbContext context)
        {
            _context = context;
            _dbSet = context.Set<TEntity>();
        }

        /// <summary>Rebinds this repository to <paramref name="context"/> and its DbSet.</summary>
        public void SetDbContext(JobDbContext context)
        {
            _context = context;
            _dbSet = context.Set<TEntity>();
        }

        /// <summary>Finds an entity by primary key; returns null when none matches.</summary>
        public async Task<TEntity> GetByIdAsync(long id)
        {
            return await _dbSet.FindAsync(id);
        }

        /// <summary>Loads every row of the entity's table (tracked).</summary>
        public async Task<IEnumerable<TEntity>> GetAllAsync()
        {
            return await _dbSet.ToListAsync();
        }

        /// <summary>
        /// Stamps <c>CreationTime</c> with the current UTC time, inserts the entity and saves.
        /// </summary>
        /// <returns>The same entity instance (with any store-generated values populated by EF).</returns>
        public async Task<TEntity> AddAsync(TEntity entity)
        {
            entity.CreationTime = DateTime.UtcNow;
            await _dbSet.AddAsync(entity);
            await _context.SaveChangesAsync();
            return entity;
        }

        /// <summary>Marks the entity as modified and saves.</summary>
        public async Task UpdateAsync(TEntity entity)
        {
            _dbSet.Update(entity);
            await _context.SaveChangesAsync();
        }

        /// <summary>Deletes the entity with the given key; does nothing when it does not exist.</summary>
        public async Task DeleteAsync(long id)
        {
            var entity = await _dbSet.FindAsync(id);
            if (entity != null)
            {
                _dbSet.Remove(entity);
                await _context.SaveChangesAsync();
            }
        }

        /// <summary>Returns one page of all rows (no tracking), sorted per <paramref name="pagingParams"/>.</summary>
        public async Task<PagedResult<TEntity>> GetPagedAsync(PagingParams pagingParams)
        {
            return await _dbSet.AsNoTracking().ToPagedListAsync(pagingParams);
        }

        /// <summary>
        /// Returns one page of rows (no tracking) after applying <paramref name="filter"/> and the
        /// key/value equality filters in <see cref="PagingParams.Filters"/>.
        /// </summary>
        public async Task<PagedResult<TEntity>> GetPagedAsync(
            Expression<Func<TEntity, bool>> filter = null,
            PagingParams pagingParams = null)
        {
            IQueryable<TEntity> query = _dbSet.AsNoTracking();

            // Caller-supplied predicate.
            if (filter != null)
            {
                query = query.Where(filter);
            }

            // Dynamic equality filters from the request.
            if (pagingParams?.Filters != null && pagingParams.Filters.Any())
            {
                query = query.ApplyFilters(pagingParams.Filters);
            }

            // Sorting and paging (defaults when no parameters were supplied).
            pagingParams ??= new PagingParams();
            return await query.ToPagedListAsync(pagingParams);
        }

        /// <summary>
        /// Returns one page of rows (no tracking) after applying <paramref name="filter"/> and the
        /// filters carried by <paramref name="condition"/>.
        /// </summary>
        public async Task<PagedResult<TEntity>> GetDataPagedAsync(
            Expression<Func<TEntity, bool>> filter = null,
            PagingParams pagingParams = null,
            Condition condition = null)
        {
            IQueryable<TEntity> query = _dbSet.AsNoTracking();

            // Caller-supplied predicate.
            if (filter != null)
            {
                query = query.Where(filter);
            }

            // Dynamic filters described by the Condition object.
            if (condition?.Filters != null && condition.Filters.Any())
            {
                query = query.ApplyConditionFilters(condition);
            }

            // Sorting and paging (defaults when no parameters were supplied).
            pagingParams ??= new PagingParams();
            return await query.ToPagedListAsync(pagingParams);
        }
    }

    /// <summary>One page of query results plus the total row count.</summary>
    public class PagedResult<T>
    {
        public List<T> Data { get; set; }
        public int TotalCount { get; set; }
        public int PageNumber { get; set; }
        public int PageSize { get; set; }

        /// <summary>Number of pages; 0 when <see cref="PageSize"/> is not positive (avoids dividing by zero).</summary>
        public int TotalPages => PageSize > 0 ? (int)Math.Ceiling(TotalCount / (double)PageSize) : 0;

        public PagedResult()
        {
            Data = new List<T>();
        }
    }

    /// <summary>Paging, sorting and simple equality-filter parameters for list queries.</summary>
    public class PagingParams
    {
        private const int MaxPageSize = 10000;
        private int _pageSize = 50;

        /// <summary>1-based page number; values below 1 are treated as 1 when paging.</summary>
        public int PageNumber { get; set; } = 1;

        /// <summary>Rows per page, clamped into [1, 10000].</summary>
        public int PageSize
        {
            get => _pageSize;
            // A zero or negative size would make Take() and TotalPages meaningless.
            set => _pageSize = Math.Clamp(value, 1, MaxPageSize);
        }

        /// <summary>Name of the property to sort by; no sorting when null or empty.</summary>
        public string SortBy { get; set; }

        public bool IsAscending { get; set; } = false;

        /// <summary>Equality filters: property name (case-insensitive) to required value.</summary>
        public Dictionary<string, string> Filters { get; set; } = new();
    }

    /// <summary>Dynamic sorting, filtering and paging helpers for <see cref="IQueryable{T}"/>.</summary>
    public static class QueryableExtensions
    {
        /// <summary>
        /// Orders the query by the property named <paramref name="sortBy"/>; returns the query unchanged
        /// when <paramref name="sortBy"/> is null or empty.
        /// </summary>
        /// <exception cref="ArgumentException">No property with that name exists on <typeparamref name="T"/>.</exception>
        public static IQueryable<T> ApplySort<T>(this IQueryable<T> query, string sortBy, bool isAscending)
        {
            if (string.IsNullOrEmpty(sortBy))
                return query;

            var parameter = Expression.Parameter(typeof(T), "x");
            var property = Expression.Property(parameter, sortBy);
            var lambda = Expression.Lambda(property, parameter);

            // Built as an expression call so the key type (property.Type) need not be known statically.
            var methodName = isAscending ? "OrderBy" : "OrderByDescending";
            var resultExpression = Expression.Call(
                typeof(Queryable),
                methodName,
                new[] { typeof(T), property.Type },
                query.Expression,
                Expression.Quote(lambda));

            return query.Provider.CreateQuery<T>(resultExpression);
        }

        /// <summary>
        /// Adds one <c>x.Prop == value</c> predicate per entry in <paramref name="filters"/>.
        /// Keys are matched to public instance properties case-insensitively; unknown keys are ignored.
        /// </summary>
        /// <exception cref="ArgumentException">A null value was given for a non-nullable value-type property.</exception>
        public static IQueryable<T> ApplyFilters<T>(this IQueryable<T> query, Dictionary<string, string> filters)
        {
            if (filters == null || filters.Count == 0)
                return query;

            foreach (var filter in filters)
            {
                var property = typeof(T).GetProperty(filter.Key, BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance);
                if (property == null)
                    continue;

                var parameter = Expression.Parameter(typeof(T), "x");
                var propertyAccess = Expression.Property(parameter, property);

                // Typing the constant as the property's own type lets Equal compare Nullable<T> properties too.
                object converted = ConvertFilterValue(filter.Value, property.PropertyType);
                var constant = Expression.Constant(converted, property.PropertyType);

                var equalExpression = Expression.Equal(propertyAccess, constant);
                var lambda = Expression.Lambda<Func<T, bool>>(equalExpression, parameter);
                query = query.Where(lambda);
            }

            return query;
        }

        /// <summary>
        /// Converts a raw filter value to <paramref name="targetType"/> (or its underlying type for
        /// <see cref="Nullable{T}"/>), using the invariant culture for machine-supplied text.
        /// </summary>
        private static object ConvertFilterValue(object value, Type targetType)
        {
            Type underlying = Nullable.GetUnderlyingType(targetType);
            bool acceptsNull = !targetType.IsValueType || underlying != null;
            underlying ??= targetType;

            if (value == null)
            {
                if (!acceptsNull)
                    throw new ArgumentException($"A null filter value is not valid for non-nullable type {targetType.Name}.", nameof(value));
                return null;
            }

            if (underlying.IsInstanceOfType(value))
                return value;

            // Convert.ChangeType handles neither enums, Guid nor DateTimeOffset.
            if (underlying.IsEnum)
                return value is string enumText
                    ? Enum.Parse(underlying, enumText, ignoreCase: true)
                    : Enum.ToObject(underlying, value);

            if (underlying == typeof(Guid))
                return Guid.Parse(Convert.ToString(value, CultureInfo.InvariantCulture));

            if (underlying == typeof(DateTimeOffset))
                return DateTimeOffset.Parse(Convert.ToString(value, CultureInfo.InvariantCulture), CultureInfo.InvariantCulture);

            return Convert.ChangeType(value, underlying, CultureInfo.InvariantCulture);
        }

        /// <summary>Applies the predicate built from <paramref name="condition"/>'s filters; no-op when it has none.</summary>
        public static IQueryable<T> ApplyConditionFilters<T>(this IQueryable<T> query, Condition condition)
        {
            if (condition?.Filters == null || !condition.Filters.Any())
                return query;

            return query.Where(condition.Filters.ToLambda<T>());
        }

        /// <summary>
        /// Adds <c>x.Prop.Contains(value)</c> when <paramref name="propertyName"/> names a public string
        /// property (case-insensitive) and <paramref name="value"/> is non-empty; otherwise returns the query unchanged.
        /// </summary>
        public static IQueryable<T> ApplyStringFilter<T>(this IQueryable<T> query, string propertyName, string value)
        {
            var property = typeof(T).GetProperty(propertyName, BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance);
            if (property == null || property.PropertyType != typeof(string) || string.IsNullOrEmpty(value))
                return query;

            var parameter = Expression.Parameter(typeof(T), "x");
            var propertyAccess = Expression.Property(parameter, property);
            var containsMethod = typeof(string).GetMethod("Contains", new[] { typeof(string) });
            var containsExpression = Expression.Call(propertyAccess, containsMethod, Expression.Constant(value));
            return query.Where(Expression.Lambda<Func<T, bool>>(containsExpression, parameter));
        }

        /// <summary>
        /// Counts the full query, then returns the requested page after optional sorting.
        /// A null <paramref name="pagingParams"/> means default paging; page numbers below 1 are treated as 1.
        /// </summary>
        public static async Task<PagedResult<T>> ToPagedListAsync<T>(this IQueryable<T> source, PagingParams pagingParams)
        {
            pagingParams ??= new PagingParams();

            int pageNumber = Math.Max(1, pagingParams.PageNumber);
            int pageSize = pagingParams.PageSize;
            // Computed in 64-bit so a huge page number cannot overflow into a negative OFFSET.
            int skip = (int)Math.Min(int.MaxValue, (pageNumber - 1L) * pageSize);

            var count = await source.CountAsync();
            var items = await source
                .ApplySort(pagingParams.SortBy, pagingParams.IsAscending)
                .Skip(skip)
                .Take(pageSize)
                .ToListAsync();

            return new PagedResult<T>
            {
                Data = items,
                TotalCount = count,
                PageNumber = pageNumber,
                PageSize = pageSize
            };
        }
    }

    /// <summary>SQL Server bulk upsert (temp table + MERGE) helpers for EF Core contexts.</summary>
    public static class EfCoreBulkExtensions
    {
        /// <summary>
        /// Rows per batch so that one multi-row INSERT stays within SQL Server's limits:
        /// 2100 parameters per command (2000 used, ~100 kept in reserve) and 1000 row constructors per VALUES list.
        /// </summary>
        private static int CalculateBatchSize(int propertyCount)
        {
            const int maxParameters = 2000;
            const int maxRowsPerInsert = 1000;
            return Math.Clamp(maxParameters / Math.Max(1, propertyCount), 1, maxRowsPerInsert);
        }

        /// <summary>
        /// EF Core bulk insert-or-update (supports composite keys). SQL Server only.
        /// Entities are staged batch by batch in a session temp table and MERGEd into the entity's table on the given keys.
        /// Runs in the caller's open transaction if there is one; otherwise in a single transaction of its own
        /// that is committed only after all batches succeed.
        /// </summary>
        /// <param name="context">Context whose model maps <typeparamref name="TEntity"/>; no-op when null.</param>
        /// <param name="entities">Rows to upsert; no-op when null or empty.</param>
        /// <param name="keySelectors">Direct property accessors (e.g. <c>x =&gt; x.Id</c>) forming the match key.</param>
        /// <exception cref="ArgumentException">No key selectors, or a selector is not a direct property access.</exception>
        /// <exception cref="NotSupportedException">The provider is not SQL Server.</exception>
        /// <exception cref="InvalidOperationException">The entity type is not mapped to a table in the model.</exception>
        /// <remarks>
        /// NOTE(review): columns are addressed by CLR property name, so properties renamed with HasColumnName
        /// are not supported. Inserting new rows into an IDENTITY key column also still fails (IDENTITY_INSERT is OFF).
        /// </remarks>
        public static void BulkMerge<TEntity>(this DbContext context,
            IEnumerable<TEntity> entities,
            params Expression<Func<TEntity, object>>[] keySelectors)
            where TEntity : class
        {
            if (context == null || entities == null)
                return;

            // Materialize once so the source is not enumerated twice (emptiness check + chunking).
            List<TEntity> items = entities.ToList();
            if (items.Count == 0)
                return;

            if (keySelectors == null || keySelectors.Length == 0)
                throw new ArgumentException("At least one key selector is required to build the MERGE ON clause.", nameof(keySelectors));

            // #temp tables, [] quoting and MERGE are SQL Server syntax.
            string providerName = GetDatabaseProviderName(context);
            if (providerName == null || !providerName.Contains("SqlServer"))
                throw new NotSupportedException($"BulkMerge only supports SQL Server; current provider: {providerName}");

            Type entityType = typeof(TEntity);
            var efEntityType = context.Model.FindEntityType(entityType)
                ?? throw new InvalidOperationException($"Entity type {entityType.Name} is not part of the DbContext model.");
            string rawTableName = efEntityType.GetTableName()
                ?? throw new InvalidOperationException($"Entity type {entityType.Name} is not mapped to a table.");
            string schema = efEntityType.GetSchema();
            string tableName = schema == null
                ? QuoteIdentifier(rawTableName)
                : $"{QuoteIdentifier(schema)}.{QuoteIdentifier(rawTableName)}";

            // Key properties, in selector order.
            List<string> keyPropNames = keySelectors.Select(selector => GetPropertyName(selector)).ToList();
            List<PropertyInfo> keyProps = keyPropNames.Select(name => entityType.GetProperty(name)).ToList();

            // Non-key columns: only CLR properties EF maps as scalar properties (navigations and [NotMapped] have no column).
            var mappedNames = new HashSet<string>(efEntityType.GetProperties().Select(p => p.Name));
            PropertyInfo[] props = entityType.GetProperties()
                .Where(p => mappedNames.Contains(p.Name) && !keyPropNames.Contains(p.Name))
                .ToArray();

            int batchSize = CalculateBatchSize(keyProps.Count + props.Length);

            // Join the caller's transaction if one is open; otherwise own one covering every batch.
            // An owned transaction that is not committed is rolled back when this using declaration disposes it.
            using var ownedTransaction = context.Database.CurrentTransaction == null
                ? context.Database.BeginTransaction()
                : null;

            try
            {
                foreach (TEntity[] batch in items.Chunk(batchSize))
                {
                    MergeBatch(context, batch, tableName, keyProps, keyPropNames, props);
                }

                // Commit exactly once, after all batches.
                ownedTransaction?.Commit();
            }
            catch (Exception ex)
            {
                Console.WriteLine($"批量操作失败: {ex.Message}");
                throw;
            }
        }

        /// <summary>Stages one batch in a fresh temp table, merges it into the target table, then drops the temp table.</summary>
        private static void MergeBatch<TEntity>(DbContext context, IEnumerable<TEntity> batch, string tableName,
            List<PropertyInfo> keyProps, List<string> keyPropNames, PropertyInfo[] props)
        {
            string tempTableName = $"#Temp_Bulk_{Guid.NewGuid().ToString("N").Substring(0, 8)}";
            try
            {
                // 1. Empty copy of the target's columns. The UNION ALL form keeps SQL Server from copying the
                //    IDENTITY property onto the temp table, so explicit key values can be inserted into it.
                string createSql = $"SELECT * INTO {tempTableName} FROM {tableName} WHERE 1 = 0 " +
                                   $"UNION ALL SELECT * FROM {tableName} WHERE 1 = 0";
                context.Database.ExecuteSqlRaw(createSql);

                // 2. Stage the batch.
                InsertToTempTable(context, batch, tempTableName, keyProps, props);

                // 3. Upsert into the target table.
                ExecuteMerge(context, tableName, tempTableName, keyPropNames, props);
            }
            finally
            {
                // 4. Clean up.
                DropTempTable(context, tempTableName);
            }
        }

        /// <summary>
        /// Extracts the property name from a key selector such as <c>x =&gt; x.Id</c>.
        /// </summary>
        /// <exception cref="ArgumentException">The selector is not a direct property access on its parameter.</exception>
        private static string GetPropertyName(LambdaExpression expression)
        {
            // Selectors return object, so value-type keys arrive wrapped in a Convert node; unwrap to the member access.
            Expression body = expression.Body is UnaryExpression unary ? unary.Operand : expression.Body;

            if (body is MemberExpression { Member: PropertyInfo property } member && member.Expression is ParameterExpression)
                return property.Name;

            throw new ArgumentException("表达式必须指向实体属性", nameof(expression));
        }

        /// <summary>
        /// Creates a temp table with explicit column types and a primary key over <paramref name="keyProps"/>.
        /// Currently unused: BulkMerge copies the target's layout with SELECT INTO instead.
        /// NOTE(review): always uses [] quoting, which is SQL Server syntax, even when PostgreSQL types are chosen.
        /// </summary>
        private static void CreateTempTable(DbContext context, string tempTableName, List<PropertyInfo> keyProps,
            PropertyInfo[] props, string providerName)
        {
            IEnumerable<string> columnDefinitions = keyProps.Concat(props)
                .Select(p => $"{QuoteIdentifier(p.Name)} {GetSqlType(p.PropertyType, providerName)}");
            // Composite key columns are comma-separated inside PRIMARY KEY (...).
            string keyColumns = string.Join(", ", keyProps.Select(p => QuoteIdentifier(p.Name)));

            string sql = $"CREATE TABLE {tempTableName} ({string.Join(", ", columnDefinitions)}, PRIMARY KEY ({keyColumns}))";
            context.Database.ExecuteSqlRaw(sql);
        }

        /// <summary>
        /// Inserts <paramref name="entities"/> into the temp table with one parameterized multi-row INSERT:
        /// key columns first, then the remaining columns.
        /// </summary>
        private static void InsertToTempTable<TEntity>(DbContext context, IEnumerable<TEntity> entities,
            string tempTableName, List<PropertyInfo> keyProps, PropertyInfo[] props)
        {
            PropertyInfo[] columns = keyProps.Concat(props).ToArray();

            using DbCommand cmd = context.Database.GetDbConnection().CreateCommand();
            // A command created straight from the connection is not enlisted in EF's transaction, and SQL Server
            // refuses to run it while a local transaction is pending unless Transaction is set.
            cmd.Transaction = context.Database.CurrentTransaction?.GetDbTransaction();
            int? timeout = context.Database.GetCommandTimeout();
            if (timeout.HasValue)
                cmd.CommandTimeout = timeout.Value;

            // Build the VALUES rows and their parameters in a single pass.
            var rows = new List<string>();
            int paramIndex = 0;
            foreach (TEntity entity in entities)
            {
                var placeholders = new string[columns.Length];
                for (int i = 0; i < columns.Length; i++)
                {
                    string paramName = $"@p{paramIndex++}";
                    AddParameter(cmd, paramName, columns[i].GetValue(entity));
                    placeholders[i] = paramName;
                }
                rows.Add($"({string.Join(", ", placeholders)})");
            }

            if (rows.Count == 0)
                return;

            cmd.CommandText = $"INSERT INTO {tempTableName} ({GetColumnList(keyProps, props)}) VALUES {string.Join(",\n", rows)}";

            context.Database.OpenConnection();
            try
            {
                cmd.ExecuteNonQuery();
            }
            finally
            {
                context.Database.CloseConnection();
            }
        }

        /// <summary>Adds a parameter (null becomes DBNull) so values are never concatenated into SQL.</summary>
        private static void AddParameter(DbCommand cmd, string paramName, object value)
        {
            DbParameter param = cmd.CreateParameter();
            param.ParameterName = paramName;
            param.Value = value ?? DBNull.Value;
            cmd.Parameters.Add(param);
        }

        /// <summary>Bracket-quotes a SQL Server identifier, escaping embedded closing brackets.</summary>
        private static string QuoteIdentifier(string name)
        {
            return "[" + name.Replace("]", "]]") + "]";
        }

        /// <summary>Comma-separated quoted column list: key columns, then the rest (no trailing comma when props is empty).</summary>
        private static string GetColumnList(List<PropertyInfo> keyProps, PropertyInfo[] props)
        {
            return GetColumnList(keyProps.Select(p => p.Name).ToList(), props);
        }

        /// <summary>Comma-separated quoted column list from key names, then the rest.</summary>
        private static string GetColumnList(List<string> keyPropNames, PropertyInfo[] props)
        {
            return string.Join(", ", keyPropNames.Concat(props.Select(p => p.Name)).Select(QuoteIdentifier));
        }

        /// <summary>
        /// Maps a CLR type (nullable unwrapped) to a column type for SQL Server or PostgreSQL providers.
        /// </summary>
        /// <exception cref="NotSupportedException">Unknown type or provider.</exception>
        private static string GetSqlType(Type type, string providerName)
        {
            type = Nullable.GetUnderlyingType(type) ?? type;

            if (providerName.Contains("SqlServer"))
            {
                if (type == typeof(int)) return "INT";
                if (type == typeof(long)) return "BIGINT";
                if (type == typeof(short)) return "SMALLINT";
                if (type == typeof(byte)) return "TINYINT";
                if (type == typeof(bool)) return "BIT";
                if (type == typeof(string)) return "NVARCHAR(MAX)";
                if (type == typeof(DateTime)) return "DATETIME2";
                if (type == typeof(DateTimeOffset)) return "DATETIMEOFFSET";
                if (type == typeof(decimal)) return "DECIMAL(18, 2)";
                if (type == typeof(float)) return "FLOAT";
                if (type == typeof(double)) return "FLOAT";
                if (type == typeof(Guid)) return "UNIQUEIDENTIFIER";
                if (type == typeof(byte[])) return "VARBINARY(MAX)";
            }
            else if (providerName.Contains("PostgreSQL"))
            {
                if (type == typeof(int)) return "INTEGER";
                if (type == typeof(long)) return "BIGINT";
                if (type == typeof(short)) return "SMALLINT";
                if (type == typeof(bool)) return "BOOLEAN";
                if (type == typeof(string)) return "TEXT";
                if (type == typeof(DateTime)) return "TIMESTAMP";
                if (type == typeof(DateTimeOffset)) return "TIMESTAMPTZ";
                if (type == typeof(decimal)) return "DECIMAL(18, 2)";
                if (type == typeof(float)) return "REAL";
                if (type == typeof(double)) return "DOUBLE PRECISION";
                if (type == typeof(Guid)) return "UUID";
                if (type == typeof(byte[])) return "BYTEA";
            }

            throw new NotSupportedException($"类型 {type.Name} 或数据库提供者 {providerName} 不支持");
        }

        /// <summary>
        /// MERGEs the temp table into the target on the key columns: matched rows get every non-key column
        /// updated (skipped when there are none), unmatched rows are inserted from the source.
        /// </summary>
        private static void ExecuteMerge(DbContext context, string tableName, string tempTableName,
            List<string> keyPropNames, PropertyInfo[] props)
        {
            string onClause = string.Join(" AND ",
                keyPropNames.Select(k => $"target.{QuoteIdentifier(k)} = source.{QuoteIdentifier(k)}"));

            var sql = new StringBuilder()
                .Append($"MERGE {tableName} AS target ")
                .Append($"USING {tempTableName} AS source ")
                .Append($"ON ({onClause}) ");

            // An empty SET list is a syntax error, so key-only entities get no WHEN MATCHED branch.
            if (props.Length > 0)
            {
                string updateFields = string.Join(", ",
                    props.Select(p => $"target.{QuoteIdentifier(p.Name)} = source.{QuoteIdentifier(p.Name)}"));
                sql.Append($"WHEN MATCHED THEN UPDATE SET {updateFields} ");
            }

            // Every inserted value (keys included) must come from the source row.
            string insertValues = string.Join(", ",
                keyPropNames.Concat(props.Select(p => p.Name)).Select(n => $"source.{QuoteIdentifier(n)}"));

            sql.Append("WHEN NOT MATCHED THEN ")
               .Append($"INSERT ({GetColumnList(keyPropNames, props)}) ")
               .Append($"VALUES ({insertValues});");

            context.Database.ExecuteSqlRaw(sql.ToString());
        }

        /// <summary>
        /// Drops the temp table. Best-effort: local temp tables disappear with the session anyway, and throwing here
        /// from a finally block would hide the error that caused the unwind.
        /// </summary>
        private static void DropTempTable(DbContext context, string tempTableName)
        {
            string sql = $"DROP TABLE IF EXISTS {tempTableName}";
            try
            {
                context.Database.ExecuteSqlRaw(sql);
            }
            catch (DbException ex)
            {
                Console.WriteLine($"Failed to drop temp table {tempTableName}: {ex.Message}");
            }
        }

        /// <summary>Returns the EF Core provider name of the context (e.g. the SQL Server provider assembly name).</summary>
        private static string GetDatabaseProviderName(DbContext context)
        {
            return context.Database.ProviderName;
        }
    }
}