You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

609 lines
22 KiB

using Microsoft.EntityFrameworkCore;
using OfficeOpenXml.FormulaParsing.Excel.Functions.Math;
using Serilog;
using SkiaSharp;
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;
using System.Threading.Tasks;
using TaskManager.Entity;
using TaskManager.EntityFramework;
using Wood.Util;
using Wood.Util.Filters;
namespace TaskManager.EntityFramework.Repository
{
/// <summary>
/// Generic EF Core repository for <typeparamref name="TEntity"/> backed by a <see cref="JobDbContext"/>.
/// Write operations call SaveChangesAsync immediately; the paging queries run with AsNoTracking.
/// </summary>
/// <typeparam name="TEntity">Entity type; must derive from <see cref="BaseEntity"/>.</typeparam>
public class Repository<TEntity> : IRepository<TEntity>
    where TEntity : BaseEntity
{
    private JobDbContext _context;
    private DbSet<TEntity> _dbSet;

    /// <summary>Creates a repository bound to <paramref name="context"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public Repository(JobDbContext context)
    {
        SetDbContext(context);
    }

    /// <summary>Rebinds the repository (and its DbSet) to another context.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public void SetDbContext(JobDbContext context)
    {
        ArgumentNullException.ThrowIfNull(context);
        _context = context;
        _dbSet = context.Set<TEntity>();
    }

    /// <summary>Looks an entity up by primary key; returns null when it does not exist.</summary>
    public async Task<TEntity> GetByIdAsync(long id)
    {
        return await _dbSet.FindAsync(id);
    }

    /// <summary>Loads every row of the entity set.</summary>
    public async Task<IEnumerable<TEntity>> GetAllAsync()
    {
        return await _dbSet.ToListAsync();
    }

    /// <summary>
    /// Stamps <c>CreationTime</c> with the current UTC time, inserts the entity and saves immediately.
    /// </summary>
    /// <returns>The same instance that was passed in.</returns>
    public async Task<TEntity> AddAsync(TEntity entity)
    {
        entity.CreationTime = DateTime.UtcNow;
        await _dbSet.AddAsync(entity);
        await _context.SaveChangesAsync();
        return entity;
    }

    /// <summary>Marks the entity as modified and saves immediately.</summary>
    public async Task UpdateAsync(TEntity entity)
    {
        _dbSet.Update(entity);
        await _context.SaveChangesAsync();
    }

    /// <summary>Deletes the entity with the given key; does nothing when no such entity exists.</summary>
    public async Task DeleteAsync(long id)
    {
        var entity = await _dbSet.FindAsync(id);
        if (entity != null)
        {
            _dbSet.Remove(entity);
            await _context.SaveChangesAsync();
        }
    }

    /// <summary>Returns one page of all entities; a null <paramref name="pagingParams"/> means default paging.</summary>
    public async Task<PagedResult<TEntity>> GetPagedAsync(PagingParams pagingParams)
    {
        // Fall back to defaults like the other paging overloads instead of failing on null.
        return await _dbSet.AsNoTracking().ToPagedListAsync(pagingParams ?? new PagingParams());
    }

    /// <summary>
    /// Returns one page of entities matching <paramref name="filter"/> and the key/value
    /// equality filters in <see cref="PagingParams.Filters"/>.
    /// </summary>
    /// <param name="filter">Optional static predicate.</param>
    /// <param name="pagingParams">Paging, sorting and dynamic filters; null means defaults.</param>
    public async Task<PagedResult<TEntity>> GetPagedAsync(
        Expression<Func<TEntity, bool>> filter = null,
        PagingParams pagingParams = null)
    {
        IQueryable<TEntity> query = BuildQuery(filter);

        // Apply dynamic key/value equality filters.
        if (pagingParams?.Filters != null && pagingParams.Filters.Any())
        {
            query = query.ApplyFilters(pagingParams.Filters);
        }

        // Apply paging and sorting.
        return await query.ToPagedListAsync(pagingParams ?? new PagingParams());
    }

    /// <summary>
    /// Returns one page of entities matching <paramref name="filter"/> and the dynamic
    /// <paramref name="condition"/> filters.
    /// </summary>
    /// <param name="filter">Optional static predicate.</param>
    /// <param name="pagingParams">Paging and sorting; null means defaults. Its Filters are not used here.</param>
    /// <param name="condition">Optional dynamic filter condition.</param>
    public async Task<PagedResult<TEntity>> GetDataPagedAsync(
        Expression<Func<TEntity, bool>> filter = null,
        PagingParams pagingParams = null, Condition condition = null)
    {
        IQueryable<TEntity> query = BuildQuery(filter);

        // Apply dynamic condition filters.
        if (condition?.Filters != null && condition.Filters.Any())
        {
            query = query.ApplyConditionFilters(condition);
        }

        // Apply paging and sorting.
        return await query.ToPagedListAsync(pagingParams ?? new PagingParams());
    }

    // Starts a no-tracking query with the optional static predicate applied.
    private IQueryable<TEntity> BuildQuery(Expression<Func<TEntity, bool>> filter)
    {
        IQueryable<TEntity> query = _dbSet.AsNoTracking();
        return filter == null ? query : query.Where(filter);
    }
}
/// <summary>
/// One page of query results together with its paging metadata.
/// </summary>
/// <typeparam name="T">Element type of the page.</typeparam>
public class PagedResult<T>
{
    /// <summary>Items on this page.</summary>
    public List<T> Data { get; set; }

    /// <summary>Total number of matching items across all pages.</summary>
    public int TotalCount { get; set; }

    /// <summary>1-based number of this page.</summary>
    public int PageNumber { get; set; }

    /// <summary>Maximum number of items per page.</summary>
    public int PageSize { get; set; }

    /// <summary>
    /// Number of pages needed to hold <see cref="TotalCount"/> items. Returns 0 when
    /// <see cref="PageSize"/> is not positive: dividing by it would give Infinity or NaN,
    /// and casting that to int yields an undefined value.
    /// </summary>
    public int TotalPages => PageSize > 0
        ? (int)Math.Ceiling(TotalCount / (double)PageSize)
        : 0;

    /// <summary>Creates an empty result with an empty, non-null <see cref="Data"/> list.</summary>
    public PagedResult()
    {
        Data = new List<T>();
    }
}
/// <summary>
/// Paging, sorting and simple equality-filter options for a paged query.
/// </summary>
public class PagingParams
{
    private const int MaxPageSize = 10000;
    private int _pageSize = 50;
    private int _pageNumber = 1;

    /// <summary>
    /// 1-based page number. Values below 1 are stored as 1, because a zero or negative page
    /// would produce a negative Skip offset.
    /// </summary>
    public int PageNumber
    {
        get => _pageNumber;
        set => _pageNumber = value < 1 ? 1 : value;
    }

    /// <summary>
    /// Items per page, clamped to the range 1..10000 (default 50). A zero or negative size
    /// would produce an invalid Take count.
    /// </summary>
    public int PageSize
    {
        get => _pageSize;
        set => _pageSize = Math.Clamp(value, 1, MaxPageSize);
    }

    /// <summary>Name of the property to sort by; null or empty leaves the order unspecified.</summary>
    public string SortBy { get; set; }

    /// <summary>Sort direction; defaults to descending (false).</summary>
    public bool IsAscending { get; set; } = false;

    /// <summary>Equality filters as property-name/value pairs.</summary>
    public Dictionary<string, string> Filters { get; set; } = new();
}
/// <summary>
/// Dynamic sorting, filtering and paging helpers for <see cref="IQueryable{T}"/>.
/// Property names are resolved against readable public instance properties. An exact-case
/// match is preferred, with a case-insensitive match as fallback.
/// </summary>
public static class QueryableExtensions
{
    private const BindingFlags PropertyFlags = BindingFlags.Public | BindingFlags.Instance;

    /// <summary>
    /// Orders <paramref name="query"/> by the property named <paramref name="sortBy"/>.
    /// Returns the query unchanged when <paramref name="sortBy"/> is empty or names no
    /// readable public property. Unknown filter keys in <see cref="ApplyFilters{T}"/> are
    /// treated the same way, so a bad client-supplied value is ignored instead of throwing.
    /// </summary>
    public static IQueryable<T> ApplySort<T>(this IQueryable<T> query, string sortBy, bool isAscending)
    {
        if (string.IsNullOrEmpty(sortBy)) return query;

        PropertyInfo property = FindProperty(typeof(T), sortBy);
        if (property == null) return query;

        ParameterExpression parameter = Expression.Parameter(typeof(T), "x");
        LambdaExpression keySelector = Expression.Lambda(Expression.Property(parameter, property), parameter);
        string methodName = isAscending ? nameof(Queryable.OrderBy) : nameof(Queryable.OrderByDescending);

        // Build Queryable.OrderBy[Descending]<T, TKey>(query, x => x.Property) for the runtime TKey.
        MethodCallExpression call = Expression.Call(
            typeof(Queryable),
            methodName,
            new[] { typeof(T), property.PropertyType },
            query.Expression,
            Expression.Quote(keySelector));
        return query.Provider.CreateQuery<T>(call);
    }

    /// <summary>
    /// Adds one <c>x.Property == value</c> predicate per entry of <paramref name="filters"/>.
    /// Keys that name no readable public property are skipped. Values are converted to the
    /// property type with invariant-culture rules. Nullable, enum, Guid, DateTimeOffset and
    /// TimeSpan properties are supported. An empty value on a Nullable&lt;T&gt; property
    /// compares with null.
    /// </summary>
    public static IQueryable<T> ApplyFilters<T>(this IQueryable<T> query, Dictionary<string, string> filters)
    {
        if (filters == null || filters.Count == 0) return query;

        foreach (var (key, rawValue) in filters)
        {
            PropertyInfo property = FindProperty(typeof(T), key);
            if (property == null) continue;

            ParameterExpression parameter = Expression.Parameter(typeof(T), "x");
            MemberExpression propertyAccess = Expression.Property(parameter, property);
            // Type the constant as the property type. Otherwise an int? property compared with
            // an int constant has no Equal operator.
            ConstantExpression constant = Expression.Constant(
                ConvertFilterValue(rawValue, property.PropertyType), property.PropertyType);
            var lambda = Expression.Lambda<Func<T, bool>>(Expression.Equal(propertyAccess, constant), parameter);
            query = query.Where(lambda);
        }
        return query;
    }

    /// <summary>
    /// Applies the predicate built from <paramref name="condition"/>'s filters. Returns the
    /// query unchanged when the condition is null or has no filters.
    /// </summary>
    public static IQueryable<T> ApplyConditionFilters<T>(this IQueryable<T> query, Condition condition)
    {
        if (condition?.Filters == null || !condition.Filters.Any()) return query;
        return query.Where(condition.Filters.ToLambda<T>());
    }

    /// <summary>
    /// Adds a <c>x.Property.Contains(value)</c> predicate when <paramref name="propertyName"/>
    /// names a readable public string property and <paramref name="value"/> is non-empty.
    /// Otherwise returns the query unchanged.
    /// </summary>
    public static IQueryable<T> ApplyStringFilter<T>(this IQueryable<T> query,
        string propertyName, string value)
    {
        PropertyInfo property = FindProperty(typeof(T), propertyName);
        if (property != null && property.PropertyType == typeof(string) && !string.IsNullOrEmpty(value))
        {
            var parameter = Expression.Parameter(typeof(T), "x");
            var propertyAccess = Expression.Property(parameter, property);
            var containsMethod = typeof(string).GetMethod("Contains", new[] { typeof(string) });
            var constant = Expression.Constant(value);
            var containsExpression = Expression.Call(propertyAccess, containsMethod, constant);
            var lambda = Expression.Lambda<Func<T, bool>>(containsExpression, parameter);
            query = query.Where(lambda);
        }
        return query;
    }

    /// <summary>
    /// Counts <paramref name="source"/>, then loads one sorted page of it.
    /// A null <paramref name="pagingParams"/> means default paging. Page number and size below 1
    /// are treated as 1, and the returned metadata reflects the values actually used.
    /// </summary>
    public static async Task<PagedResult<T>> ToPagedListAsync<T>(this IQueryable<T> source, PagingParams pagingParams)
    {
        pagingParams ??= new PagingParams();
        int pageNumber = Math.Max(1, pagingParams.PageNumber);
        int pageSize = Math.Max(1, pagingParams.PageSize);

        int count = await source.CountAsync();
        var result = new PagedResult<T>
        {
            TotalCount = count,
            PageNumber = pageNumber,
            PageSize = pageSize
        };
        // Nothing to fetch; skip the second round trip.
        if (count == 0) return result;

        // Compute the offset in 64 bits so a huge page number cannot wrap to a negative Skip.
        long offset = (long)(pageNumber - 1) * pageSize;
        int skip = offset > int.MaxValue ? int.MaxValue : (int)offset;

        result.Data = await source
            .ApplySort(pagingParams.SortBy, pagingParams.IsAscending)
            .Skip(skip)
            .Take(pageSize)
            .ToListAsync();
        return result;
    }

    // Resolves a readable, non-indexer public instance property: exact case first, then case-insensitive.
    private static PropertyInfo FindProperty(Type type, string name)
    {
        if (string.IsNullOrEmpty(name)) return null;
        PropertyInfo property = type.GetProperty(name, PropertyFlags)
            ?? type.GetProperty(name, PropertyFlags | BindingFlags.IgnoreCase);
        return property != null && property.CanRead && property.GetIndexParameters().Length == 0
            ? property
            : null;
    }

    // Converts a raw filter string to targetType using invariant-culture rules; invalid input
    // throws from the underlying Parse/ChangeType call.
    private static object ConvertFilterValue(string value, Type targetType)
    {
        Type nullableOf = Nullable.GetUnderlyingType(targetType);
        if (nullableOf != null && string.IsNullOrEmpty(value))
            return null;

        Type valueType = nullableOf ?? targetType;
        var invariant = System.Globalization.CultureInfo.InvariantCulture;

        // These types are not valid Convert.ChangeType targets for a string.
        if (valueType.IsEnum)
            return Enum.Parse(valueType, value, ignoreCase: true);
        if (valueType == typeof(Guid))
            return Guid.Parse(value);
        if (valueType == typeof(DateTimeOffset))
            return DateTimeOffset.Parse(value, invariant);
        if (valueType == typeof(TimeSpan))
            return TimeSpan.Parse(value, invariant);

        return Convert.ChangeType(value, valueType, invariant);
    }
}
/// <summary>
/// Bulk upsert helpers for EF Core contexts. The generated SQL is SQL Server syntax:
/// MERGE, #temp tables, [bracketed] identifiers and SELECT ... INTO.
/// </summary>
public static class EfCoreBulkExtensions
{
    // SQL Server accepts at most 2100 parameters per command; keep about 100 in reserve.
    private const int MaxParametersPerCommand = 2000;

    // SQL Server limits an INSERT ... VALUES row constructor to 1000 rows.
    private const int MaxRowsPerInsert = 1000;

    /// <summary>
    /// Entities per batch, chosen so one staging INSERT stays within both limits above.
    /// </summary>
    /// <param name="propertyCount">Columns (and therefore parameters) written per entity.</param>
    private static int CalculateBatchSize(int propertyCount)
    {
        if (propertyCount <= 0)
            return MaxRowsPerInsert;
        return Math.Clamp(MaxParametersPerCommand / propertyCount, 1, MaxRowsPerInsert);
    }

    /// <summary>
    /// Inserts or updates <paramref name="entities"/> in the entity's table. Rows are matched on
    /// the key (composite keys supported) given by <paramref name="keySelectors"/>.
    /// </summary>
    /// <remarks>
    /// Each batch is staged in a session temp table cloned from the target table, then merged
    /// into the target with one MERGE statement. All batches commit together in one transaction.
    /// If the caller already has a transaction open on <paramref name="context"/>, the work joins
    /// that transaction and the caller remains responsible for committing or rolling it back.
    /// Columns are addressed by CLR property name, and only properties present in EF's model
    /// for the entity are written.
    /// NOTE(review): properties mapped to a differently named column are not translated.
    /// </remarks>
    /// <param name="context">Context whose model maps <typeparamref name="TEntity"/>; null is a no-op.</param>
    /// <param name="entities">Rows to upsert; null or empty is a no-op.</param>
    /// <param name="keySelectors">Member selectors forming the match key, e.g. <c>x =&gt; x.Id</c>.</param>
    /// <exception cref="ArgumentException">No key selector was given, or one does not name a public property.</exception>
    /// <exception cref="NotSupportedException">The context does not use the SQL Server provider.</exception>
    /// <exception cref="InvalidOperationException"><typeparamref name="TEntity"/> is not mapped to a table.</exception>
    public static void BulkMerge<TEntity>(this DbContext context,
        IEnumerable<TEntity> entities,
        params Expression<Func<TEntity, object>>[] keySelectors)
        where TEntity : class
    {
        if (context == null || entities == null)
            return;

        // Materialize once: the sequence is checked for emptiness and then chunked.
        IList<TEntity> items = entities as IList<TEntity> ?? entities.ToList();
        if (items.Count == 0)
            return;

        if (keySelectors == null || keySelectors.Length == 0)
            throw new ArgumentException("至少需要指定一个主键选择器", nameof(keySelectors));

        string providerName = GetDatabaseProviderName(context);
        if (providerName == null || !providerName.Contains("SqlServer"))
            throw new NotSupportedException($"BulkMerge 仅支持 SQL Server,当前数据库提供者: {providerName}");

        Type entityType = typeof(TEntity);
        var efEntityType = context.Model.FindEntityType(entityType)
            ?? throw new InvalidOperationException($"实体类型 {entityType.Name} 未在当前 DbContext 模型中注册");
        string tableName = efEntityType.GetTableName()
            ?? throw new InvalidOperationException($"实体类型 {entityType.Name} 未映射到数据表");
        string schema = efEntityType.GetSchema();
        string qualifiedTableName = string.IsNullOrEmpty(schema)
            ? QuoteIdentifier(tableName)
            : $"{QuoteIdentifier(schema)}.{QuoteIdentifier(tableName)}";

        // Key columns, in selector order.
        List<string> keyPropNames = keySelectors.Select(GetPropertyName).ToList();
        List<PropertyInfo> keyProps = keyPropNames
            .Select(name => entityType.GetProperty(name)
                ?? throw new ArgumentException($"实体类型 {entityType.Name} 不存在公共属性 {name}", nameof(keySelectors)))
            .ToList();

        // Non-key columns: readable, non-indexer properties that EF maps on this entity.
        // Navigations and [NotMapped] members have no column to write to.
        PropertyInfo[] props = entityType.GetProperties()
            .Where(p => !keyPropNames.Contains(p.Name)
                && p.CanRead
                && p.GetIndexParameters().Length == 0
                && efEntityType.FindProperty(p.Name) != null)
            .ToArray();

        int batchSize = CalculateBatchSize(keyProps.Count + props.Length);

        // Join the caller's transaction when one is active, because BeginTransaction would throw.
        // Otherwise own a transaction so all batches commit or roll back together.
        using var transaction = context.Database.CurrentTransaction == null
            ? context.Database.BeginTransaction()
            : null;
        try
        {
            foreach (TEntity[] batch in items.Chunk(batchSize))
            {
                string tempTableName = $"#Temp_Bulk_{Guid.NewGuid().ToString("N").Substring(0, 8)}";
                try
                {
                    // 1. Clone the target's column layout into an empty temp table.
                    string createSql = $"SELECT * INTO {tempTableName} FROM {qualifiedTableName} WHERE 1 = 0";
                    context.Database.ExecuteSqlRaw(createSql);
                    // 2. Stage this batch.
                    InsertToTempTable(context, batch, tempTableName, keyProps, props);
                    // 3. Upsert the staged rows into the target.
                    ExecuteMerge(context, qualifiedTableName, tempTableName, keyPropNames, props);
                }
                finally
                {
                    // 4. Best-effort cleanup of the temp table.
                    DropTempTable(context, tempTableName);
                }
            }

            // Commit exactly once, after every batch. Committing per batch leaves later batches
            // (and the rollback below) with an already-completed transaction.
            transaction?.Commit();
        }
        catch (Exception ex)
        {
            if (transaction != null)
            {
                try
                {
                    transaction.Rollback();
                }
                catch (Exception rollbackEx)
                {
                    // Do not let a rollback failure replace the original exception.
                    Console.WriteLine($"事务回滚失败: {rollbackEx.Message}");
                }
            }
            Console.WriteLine($"批量操作失败: {ex.Message}");
            throw;
        }
    }

    /// <summary>
    /// Extracts the member name from a selector such as <c>x =&gt; x.Id</c>. Value-type members
    /// arrive wrapped in a Convert-to-object node, which is unwrapped first.
    /// </summary>
    /// <exception cref="ArgumentException">The selector body is not a member access.</exception>
    private static string GetPropertyName<TEntity>(Expression<Func<TEntity, object>> expression)
    {
        Expression body = expression.Body;
        // The Convert operand is the member access itself, not a lambda, so unwrap it directly.
        if (body is UnaryExpression { NodeType: ExpressionType.Convert or ExpressionType.ConvertChecked } unary)
            body = unary.Operand;
        if (body is MemberExpression member)
            return member.Member.Name;
        throw new ArgumentException("表达式必须指向实体属性", nameof(expression));
    }

    /// <summary>
    /// Creates <paramref name="tempTableName"/> with explicit column types and a primary key
    /// over <paramref name="keyProps"/>. BulkMerge does not call this; it clones the target
    /// table with SELECT ... INTO instead.
    /// </summary>
    private static void CreateTempTable(DbContext context, string tempTableName,
        List<PropertyInfo> keyProps, PropertyInfo[] props, string providerName)
    {
        string columnDefinitions = string.Join(", ", keyProps.Concat(props)
            .Select(p => $"{QuoteIdentifier(p.Name)} {GetSqlType(p.PropertyType, providerName)}"));
        // PRIMARY KEY takes a comma-separated column list.
        string primaryKey = string.Join(", ", keyProps.Select(p => QuoteIdentifier(p.Name)));
        string sql = $"CREATE TABLE {tempTableName} ({columnDefinitions}, PRIMARY KEY ({primaryKey}))";
        context.Database.ExecuteSqlRaw(sql);
    }

    /// <summary>
    /// Inserts <paramref name="entities"/> into <paramref name="tempTableName"/> with one
    /// parameterized multi-row INSERT. Key columns come first, then <paramref name="props"/>.
    /// </summary>
    private static void InsertToTempTable<TEntity>(DbContext context, IEnumerable<TEntity> entities,
        string tempTableName, List<PropertyInfo> keyProps, PropertyInfo[] props)
    {
        // Same column order as GetColumnList: keys first, then the remaining properties.
        PropertyInfo[] columns = keyProps.Concat(props).ToArray();

        using DbCommand cmd = CreateCommand(context);
        var rows = new List<string>();
        int paramIndex = 0;
        foreach (TEntity entity in entities)
        {
            var placeholders = new string[columns.Length];
            for (int i = 0; i < columns.Length; i++)
            {
                string paramName = $"@p{paramIndex++}";
                placeholders[i] = paramName;
                AddParameter(cmd, paramName, columns[i].GetValue(entity));
            }
            rows.Add($"({string.Join(", ", placeholders)})");
        }

        if (rows.Count == 0)
            return;

        cmd.CommandText =
            $"INSERT INTO {tempTableName} ({GetColumnList(keyProps, props)}) VALUES {string.Join(",\n", rows)}";

        context.Database.OpenConnection();
        try
        {
            cmd.ExecuteNonQuery();
        }
        finally
        {
            context.Database.CloseConnection();
        }
    }

    /// <summary>
    /// Creates a command on the context's connection. The command is enlisted in the context's
    /// current transaction (if any) and uses the context's configured command timeout (if any).
    /// </summary>
    private static DbCommand CreateCommand(DbContext context)
    {
        DbCommand cmd = context.Database.GetDbConnection().CreateCommand();

        // A command created directly on the connection is not enlisted automatically, and
        // SqlClient refuses to execute it while a local transaction is pending.
        var currentTransaction = context.Database.CurrentTransaction;
        if (currentTransaction != null)
        {
            cmd.Transaction = global::Microsoft.EntityFrameworkCore.Storage.DbContextTransactionExtensions
                .GetDbTransaction(currentTransaction);
        }

        if (context.Database.GetCommandTimeout() is int timeout)
            cmd.CommandTimeout = timeout;

        return cmd;
    }

    /// <summary>Adds a parameter to <paramref name="cmd"/>; null becomes DBNull.</summary>
    private static void AddParameter(DbCommand cmd, string paramName, object value)
    {
        DbParameter param = cmd.CreateParameter();
        param.ParameterName = paramName;
        param.Value = value ?? DBNull.Value;
        // SqlClient infers DATETIME for DateTime values. That type rounds to ~3 ms and rejects
        // dates before 1753; DATETIME2 keeps the full value.
        if (value is DateTime)
            param.DbType = System.Data.DbType.DateTime2;
        cmd.Parameters.Add(param);
    }

    /// <summary>Bracketed column list: key columns first, then <paramref name="props"/>.</summary>
    private static string GetColumnList(List<PropertyInfo> keyProps, PropertyInfo[] props)
        => GetColumnList(keyProps.Select(p => p.Name).ToList(), props);

    /// <summary>Bracketed column list: key columns first, then <paramref name="props"/>.</summary>
    private static string GetColumnList(List<string> keyPropNames, PropertyInfo[] props)
        => string.Join(", ", keyPropNames.Concat(props.Select(p => p.Name)).Select(QuoteIdentifier));

    /// <summary>
    /// Maps a CLR type (nullable types use their underlying type) to a column type for the
    /// SQL Server or PostgreSQL provider.
    /// </summary>
    /// <exception cref="NotSupportedException">The type or provider has no mapping here.</exception>
    private static string GetSqlType(Type type, string providerName)
    {
        type = Nullable.GetUnderlyingType(type) ?? type;
        string provider = providerName ?? string.Empty;
        if (provider.Contains("SqlServer"))
        {
            if (type == typeof(int)) return "INT";
            if (type == typeof(long)) return "BIGINT";
            if (type == typeof(short)) return "SMALLINT";
            if (type == typeof(byte)) return "TINYINT";
            if (type == typeof(bool)) return "BIT";
            if (type == typeof(string)) return "NVARCHAR(MAX)";
            if (type == typeof(DateTime)) return "DATETIME2";
            if (type == typeof(DateTimeOffset)) return "DATETIMEOFFSET";
            if (type == typeof(decimal)) return "DECIMAL(18, 2)";
            if (type == typeof(float)) return "FLOAT";
            if (type == typeof(double)) return "FLOAT";
            if (type == typeof(Guid)) return "UNIQUEIDENTIFIER";
            if (type == typeof(byte[])) return "VARBINARY(MAX)";
        }
        else if (provider.Contains("PostgreSQL"))
        {
            if (type == typeof(int)) return "INTEGER";
            if (type == typeof(long)) return "BIGINT";
            if (type == typeof(short)) return "SMALLINT";
            if (type == typeof(bool)) return "BOOLEAN";
            if (type == typeof(string)) return "TEXT";
            if (type == typeof(DateTime)) return "TIMESTAMP";
            if (type == typeof(DateTimeOffset)) return "TIMESTAMPTZ";
            if (type == typeof(decimal)) return "DECIMAL(18, 2)";
            if (type == typeof(float)) return "REAL";
            if (type == typeof(double)) return "DOUBLE PRECISION";
            if (type == typeof(Guid)) return "UUID";
            if (type == typeof(byte[])) return "BYTEA";
        }
        throw new NotSupportedException($"类型 {type.Name} 或数据库提供者 {providerName} 不支持");
    }

    /// <summary>
    /// MERGEs the rows of <paramref name="tempTableName"/> into <paramref name="tableName"/>.
    /// Matched rows get their non-key columns updated; unmatched rows are inserted.
    /// </summary>
    /// <param name="tableName">Target table, already quoted/qualified by the caller.</param>
    private static void ExecuteMerge(DbContext context, string tableName, string tempTableName,
        List<string> keyPropNames, PropertyInfo[] props)
    {
        string onClause = string.Join(" AND ", keyPropNames.Select(k =>
            $"target.{QuoteIdentifier(k)} = source.{QuoteIdentifier(k)}"));
        // Every inserted value comes from the source row, key columns included.
        string insertValues = string.Join(", ", keyPropNames
            .Concat(props.Select(p => p.Name))
            .Select(c => $"source.{QuoteIdentifier(c)}"));

        var sql = new StringBuilder()
            .Append($"MERGE {tableName} AS target ")
            .Append($"USING {tempTableName} AS source ")
            .Append($"ON ({onClause}) ");

        // With no non-key columns there is nothing to update, and an empty SET list is invalid SQL.
        if (props.Length > 0)
        {
            string updateFields = string.Join(", ", props.Select(p =>
                $"target.{QuoteIdentifier(p.Name)} = source.{QuoteIdentifier(p.Name)}"));
            sql.Append($"WHEN MATCHED THEN UPDATE SET {updateFields} ");
        }

        sql.Append("WHEN NOT MATCHED THEN ")
            .Append($"INSERT ({GetColumnList(keyPropNames, props)}) ")
            .Append($"VALUES ({insertValues});");

        context.Database.ExecuteSqlRaw(sql.ToString());
    }

    /// <summary>
    /// Drops <paramref name="tempTableName"/> if it exists. This is best-effort: a failure is
    /// logged, not thrown, so cleanup never masks an exception already propagating from the
    /// batch. A leftover #temp table is dropped by SQL Server when the session ends.
    /// </summary>
    private static void DropTempTable(DbContext context, string tempTableName)
    {
        try
        {
            context.Database.ExecuteSqlRaw($"DROP TABLE IF EXISTS {tempTableName}");
        }
        catch (Exception ex)
        {
            Console.WriteLine($"清理临时表 {tempTableName} 失败: {ex.Message}");
        }
    }

    /// <summary>Returns the EF Core provider name of the context's database.</summary>
    private static string GetDatabaseProviderName(DbContext context)
    {
        return context.Database.ProviderName;
    }

    /// <summary>Wraps an identifier in SQL Server brackets, escaping any closing bracket.</summary>
    private static string QuoteIdentifier(string name) => "[" + name.Replace("]", "]]") + "]";
}
}