You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
609 lines
22 KiB
609 lines
22 KiB
using Microsoft.EntityFrameworkCore;
|
|
using OfficeOpenXml.FormulaParsing.Excel.Functions.Math;
|
|
using Serilog;
|
|
using SkiaSharp;
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.Data.Common;
|
|
using System.Linq;
|
|
using System.Linq.Expressions;
|
|
using System.Reflection;
|
|
using System.Text;
|
|
using System.Threading.Tasks;
|
|
using TaskManager.Entity;
|
|
using TaskManager.EntityFramework;
|
|
using Wood.Util;
|
|
using Wood.Util.Filters;
|
|
|
|
|
|
namespace TaskManager.EntityFramework.Repository
|
|
{
|
|
|
|
/// <summary>
/// Generic Entity Framework Core repository offering CRUD and paged-query
/// operations for entities derived from <see cref="BaseEntity"/>.
/// </summary>
/// <typeparam name="TEntity">Entity type stored in the underlying <see cref="DbSet{TEntity}"/>.</typeparam>
public class Repository<TEntity> : IRepository<TEntity>
    where TEntity : BaseEntity
{
    // Not readonly: SetDbContext may swap the context after construction.
    private JobDbContext _context;
    private DbSet<TEntity> _dbSet;

    /// <summary>Creates a repository bound to <paramref name="context"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public Repository(JobDbContext context)
    {
        _context = context ?? throw new ArgumentNullException(nameof(context));
        _dbSet = context.Set<TEntity>();
    }

    /// <summary>Rebinds this repository (context and entity set) to another context.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    public void SetDbContext(JobDbContext context)
    {
        _context = context ?? throw new ArgumentNullException(nameof(context));
        _dbSet = context.Set<TEntity>();
    }

    /// <summary>Looks an entity up by primary key via <c>FindAsync</c>.</summary>
    /// <returns>The entity, or null when no entity with that key exists.</returns>
    public async Task<TEntity> GetByIdAsync(long id)
    {
        return await _dbSet.FindAsync(id);
    }

    /// <summary>Loads every entity of this type into a list.</summary>
    public async Task<IEnumerable<TEntity>> GetAllAsync()
    {
        return await _dbSet.ToListAsync();
    }

    /// <summary>
    /// Stamps <c>CreationTime</c> with the current UTC time, adds the entity and saves.
    /// </summary>
    /// <returns>The same <paramref name="entity"/> instance after saving.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public async Task<TEntity> AddAsync(TEntity entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        entity.CreationTime = DateTime.UtcNow;
        await _dbSet.AddAsync(entity);
        await _context.SaveChangesAsync();
        return entity;
    }

    /// <summary>Marks the entity as modified and saves the change.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
    public async Task UpdateAsync(TEntity entity)
    {
        if (entity == null)
        {
            throw new ArgumentNullException(nameof(entity));
        }

        _dbSet.Update(entity);
        await _context.SaveChangesAsync();
    }

    /// <summary>Deletes the entity with the given key; does nothing when it is not found.</summary>
    public async Task DeleteAsync(long id)
    {
        TEntity entity = await _dbSet.FindAsync(id);
        if (entity == null)
        {
            return;
        }

        _dbSet.Remove(entity);
        await _context.SaveChangesAsync();
    }

    /// <summary>Returns one untracked page of entities.</summary>
    /// <param name="pagingParams">Paging/sorting options; null falls back to the defaults, like the other overloads.</param>
    public async Task<PagedResult<TEntity>> GetPagedAsync(PagingParams pagingParams)
    {
        return await _dbSet.AsNoTracking().ToPagedListAsync(pagingParams ?? new PagingParams());
    }

    /// <summary>
    /// Returns one untracked page of entities restricted by an optional predicate and by the
    /// key/value filters carried in <paramref name="pagingParams"/>.
    /// </summary>
    /// <param name="filter">Optional predicate applied first.</param>
    /// <param name="pagingParams">Paging, sorting and key/value filters; null uses the defaults.</param>
    public async Task<PagedResult<TEntity>> GetPagedAsync(
        Expression<Func<TEntity, bool>> filter = null,
        PagingParams pagingParams = null)
    {
        IQueryable<TEntity> query = CreateQuery(filter);

        // Dynamic key/value filters supplied together with the paging options.
        if (pagingParams?.Filters != null && pagingParams.Filters.Any())
        {
            query = query.ApplyFilters(pagingParams.Filters);
        }

        return await query.ToPagedListAsync(pagingParams ?? new PagingParams());
    }

    /// <summary>
    /// Returns one untracked page of entities restricted by an optional predicate and by the
    /// filters of <paramref name="condition"/>.
    /// </summary>
    /// <param name="filter">Optional predicate applied first.</param>
    /// <param name="pagingParams">Paging and sorting options; null uses the defaults.</param>
    /// <param name="condition">Optional condition whose filters are applied via <c>ApplyConditionFilters</c>.</param>
    public async Task<PagedResult<TEntity>> GetDataPagedAsync(
        Expression<Func<TEntity, bool>> filter = null,
        PagingParams pagingParams = null, Condition condition = null)
    {
        IQueryable<TEntity> query = CreateQuery(filter);

        // Dynamic filters described by the condition object.
        if (condition?.Filters != null && condition.Filters.Any())
        {
            query = query.ApplyConditionFilters(condition);
        }

        return await query.ToPagedListAsync(pagingParams ?? new PagingParams());
    }

    // Builds the untracked base query shared by the paged overloads and applies the optional predicate.
    private IQueryable<TEntity> CreateQuery(Expression<Func<TEntity, bool>> filter)
    {
        IQueryable<TEntity> query = _dbSet.AsNoTracking();
        return filter == null ? query : query.Where(filter);
    }
}
|
|
/// <summary>
/// One page of query results together with the paging metadata used to produce it.
/// </summary>
/// <typeparam name="T">Type of the items on the page.</typeparam>
public class PagedResult<T>
{
    /// <summary>Items on the current page; initialised to an empty list.</summary>
    public List<T> Data { get; set; }

    /// <summary>Total number of items across all pages.</summary>
    public int TotalCount { get; set; }

    /// <summary>Page number this result represents.</summary>
    public int PageNumber { get; set; }

    /// <summary>Maximum number of items per page.</summary>
    public int PageSize { get; set; }

    /// <summary>
    /// Number of pages needed for <see cref="TotalCount"/> items, rounded up.
    /// Returns 0 when <see cref="PageSize"/> is not positive, because dividing by it would
    /// yield NaN/Infinity and converting those to int is unspecified.
    /// </summary>
    public int TotalPages => PageSize > 0
        ? (int)Math.Ceiling(TotalCount / (double)PageSize)
        : 0;

    /// <summary>Creates an empty result.</summary>
    public PagedResult()
    {
        Data = new List<T>();
    }
}
|
|
|
|
|
|
|
|
/// <summary>
/// Paging, sorting and simple key/value filter options for a paged query.
/// </summary>
public class PagingParams
{
    // Upper bound applied to PageSize.
    private const int MaxPageSize = 10000;

    private int _pageNumber = 1;
    private int _pageSize = 10;

    /// <summary>
    /// 1-based page number (default 1). Values below 1 are raised to 1 so the computed
    /// skip offset can never become negative.
    /// </summary>
    public int PageNumber
    {
        get => _pageNumber;
        set => _pageNumber = value < 1 ? 1 : value;
    }

    /// <summary>
    /// Items per page (default 10), kept within 1..10000. Non-positive values are raised
    /// to 1 so consumers never divide by zero or take a negative number of rows.
    /// </summary>
    public int PageSize
    {
        get => _pageSize;
        set => _pageSize = Math.Clamp(value, 1, MaxPageSize);
    }

    /// <summary>Name of the property to sort by; null or empty means no explicit sort.</summary>
    public string SortBy { get; set; }

    /// <summary>True (default) for ascending order, false for descending.</summary>
    public bool IsAscending { get; set; } = true;

    /// <summary>Filter conditions as property-name/value pairs.</summary>
    public Dictionary<string, string> Filters { get; set; } = new();
}
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// <see cref="IQueryable{T}"/> helpers for dynamic sorting, filtering and paging.
/// </summary>
public static class QueryableExtensions
{
    /// <summary>
    /// Orders the query by the property named <paramref name="sortBy"/>.
    /// </summary>
    /// <param name="query">Query to order.</param>
    /// <param name="sortBy">Property name; null, empty or whitespace returns the query unchanged.</param>
    /// <param name="isAscending">True for OrderBy, false for OrderByDescending.</param>
    /// <exception cref="ArgumentException">No property with that name exists on <typeparamref name="T"/>.</exception>
    public static IQueryable<T> ApplySort<T>(this IQueryable<T> query, string sortBy, bool isAscending)
    {
        if (string.IsNullOrWhiteSpace(sortBy))
        {
            return query;
        }

        // Build x => x.<sortBy> dynamically; the key type is only known at run time.
        ParameterExpression parameter = Expression.Parameter(typeof(T), "x");
        MemberExpression property = Expression.Property(parameter, sortBy.Trim());
        LambdaExpression keySelector = Expression.Lambda(property, parameter);

        string methodName = isAscending ? "OrderBy" : "OrderByDescending";
        MethodCallExpression orderCall = Expression.Call(
            typeof(Queryable),
            methodName,
            new[] { typeof(T), property.Type },
            query.Expression,
            Expression.Quote(keySelector));

        return query.Provider.CreateQuery<T>(orderCall);
    }

    /// <summary>
    /// Adds one equality predicate per entry of <paramref name="filters"/>. Keys are matched
    /// case-insensitively against public instance properties of <typeparamref name="T"/>;
    /// keys that match no property are ignored.
    /// </summary>
    /// <remarks>
    /// Values are parsed with the invariant culture. Nullable, enum, Guid and DateTimeOffset
    /// properties are supported; a null value compares against null.
    /// </remarks>
    public static IQueryable<T> ApplyFilters<T>(this IQueryable<T> query, Dictionary<string, string> filters)
    {
        if (filters == null || filters.Count == 0)
        {
            return query;
        }

        foreach (KeyValuePair<string, string> filter in filters)
        {
            PropertyInfo property = typeof(T).GetProperty(
                filter.Key,
                BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance);

            if (property == null)
            {
                continue;
            }

            ParameterExpression parameter = Expression.Parameter(typeof(T), "x");
            MemberExpression propertyAccess = Expression.Property(parameter, property);

            // Type the constant as the property type so that Nullable<T> properties and
            // null values still produce a valid Equal node.
            ConstantExpression constant = Expression.Constant(
                ConvertFilterValue(filter.Value, property.PropertyType),
                property.PropertyType);

            BinaryExpression equalExpression = Expression.Equal(propertyAccess, constant);
            Expression<Func<T, bool>> lambda = Expression.Lambda<Func<T, bool>>(equalExpression, parameter);

            query = query.Where(lambda);
        }

        return query;
    }

    // Converts a raw filter string to the (underlying, non-nullable) type of the target property.
    private static object ConvertFilterValue(string rawValue, Type propertyType)
    {
        if (rawValue == null)
        {
            return null;
        }

        // Convert.ChangeType cannot target Nullable<T>, enums or Guid, so unwrap and special-case.
        Type targetType = Nullable.GetUnderlyingType(propertyType) ?? propertyType;

        if (targetType == typeof(string))
        {
            return rawValue;
        }

        if (targetType.IsEnum)
        {
            return Enum.Parse(targetType, rawValue, ignoreCase: true);
        }

        if (targetType == typeof(Guid))
        {
            return Guid.Parse(rawValue);
        }

        if (targetType == typeof(DateTimeOffset))
        {
            return DateTimeOffset.Parse(rawValue, System.Globalization.CultureInfo.InvariantCulture);
        }

        return Convert.ChangeType(rawValue, targetType, System.Globalization.CultureInfo.InvariantCulture);
    }

    /// <summary>
    /// Applies the filters of <paramref name="condition"/> (converted with <c>ToLambda</c>)
    /// as a single predicate. A null condition or empty filter set returns the query unchanged.
    /// </summary>
    public static IQueryable<T> ApplyConditionFilters<T>(this IQueryable<T> query, Condition condition)
    {
        if (condition?.Filters == null || !condition.Filters.Any())
        {
            return query;
        }

        return query.Where(condition.Filters.ToLambda<T>());
    }

    /// <summary>
    /// Adds a <c>Contains</c> predicate on the string property named <paramref name="propertyName"/>.
    /// The query is returned unchanged when the property does not exist, is not a string,
    /// or <paramref name="value"/> is null or empty.
    /// </summary>
    public static IQueryable<T> ApplyStringFilter<T>(this IQueryable<T> query,
        string propertyName, string value)
    {
        PropertyInfo property = typeof(T).GetProperty(
            propertyName,
            BindingFlags.IgnoreCase | BindingFlags.Public | BindingFlags.Instance);

        if (property == null || property.PropertyType != typeof(string) || string.IsNullOrEmpty(value))
        {
            return query;
        }

        ParameterExpression parameter = Expression.Parameter(typeof(T), "x");
        MemberExpression propertyAccess = Expression.Property(parameter, property);
        MethodInfo containsMethod = typeof(string).GetMethod("Contains", new[] { typeof(string) });
        ConstantExpression constant = Expression.Constant(value);

        MethodCallExpression containsExpression = Expression.Call(propertyAccess, containsMethod, constant);
        Expression<Func<T, bool>> lambda = Expression.Lambda<Func<T, bool>>(containsExpression, parameter);

        return query.Where(lambda);
    }

    /// <summary>
    /// Counts the source, then loads the requested page after applying the optional sort.
    /// </summary>
    /// <param name="source">Query to page.</param>
    /// <param name="pagingParams">Paging/sorting options; null uses the defaults.</param>
    /// <returns>The page items plus total count and the effective page number and size.</returns>
    public static async Task<PagedResult<T>> ToPagedListAsync<T>(this IQueryable<T> source, PagingParams pagingParams)
    {
        pagingParams ??= new PagingParams();

        // Guard against non-positive values: a negative Skip/Take is rejected by the provider.
        int pageNumber = Math.Max(1, pagingParams.PageNumber);
        int pageSize = Math.Max(1, pagingParams.PageSize);

        int count = await source.CountAsync();

        // NOTE(review): when SortBy is empty no ordering is applied before Skip/Take, so the
        // row order of a page is whatever the provider returns — confirm this is acceptable.
        List<T> items = await source
            .ApplySort(pagingParams.SortBy, pagingParams.IsAscending)
            .Skip((pageNumber - 1) * pageSize)
            .Take(pageSize)
            .ToListAsync();

        return new PagedResult<T>
        {
            Data = items,
            TotalCount = count,
            PageNumber = pageNumber,
            PageSize = pageSize
        };
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Bulk "insert or update" (MERGE) support for EF Core. Each batch is staged in a temporary
/// table and merged into the target table; all batches run in one transaction.
/// </summary>
/// <remarks>
/// The generated SQL (<c>SELECT ... INTO #temp</c>, <c>MERGE</c>, bracket-quoted names) is
/// SQL Server syntax. Column names are taken from the CLR property names.
/// NOTE(review): properties mapped to differently named columns, or identity columns copied
/// into the temp table by <c>SELECT ... INTO</c>, are not handled — confirm against the model.
/// </remarks>
public static class EfCoreBulkExtensions
{
    // SQL Server allows at most 2100 parameters per command; stay 100 below as a buffer.
    private const int MaxParametersPerCommand = 2000;

    // A single INSERT ... VALUES statement accepts at most 1000 row value expressions.
    private const int MaxRowsPerInsert = 1000;

    /// <summary>
    /// Number of entities per batch so that neither the parameter limit nor the
    /// row limit of one INSERT statement is exceeded. Always at least 1.
    /// </summary>
    private static int CalculateBatchSize(int propertyCount)
    {
        if (propertyCount <= 0)
        {
            return MaxRowsPerInsert;
        }

        int rowsByParameterLimit = Math.Max(1, MaxParametersPerCommand / propertyCount);
        return Math.Min(rowsByParameterLimit, MaxRowsPerInsert);
    }

    /// <summary>
    /// Inserts or updates <paramref name="entities"/> in bulk, matching existing rows on the
    /// properties selected by <paramref name="keySelectors"/> (composite keys supported).
    /// </summary>
    /// <param name="context">Context whose connection and model are used; null is a no-op.</param>
    /// <param name="entities">Entities to merge; null or empty is a no-op. Enumerated once.</param>
    /// <param name="keySelectors">One property selector per key column, e.g. <c>x => x.Id</c>.</param>
    /// <exception cref="ArgumentException">No key selector was given, or a selector is not a property of the entity.</exception>
    /// <exception cref="InvalidOperationException"><typeparamref name="TEntity"/> is not part of the context model.</exception>
    /// <remarks>
    /// All batches are committed together after the last one succeeds; any failure rolls the
    /// whole operation back and rethrows.
    /// </remarks>
    public static void BulkMerge<TEntity>(this DbContext context,
        IEnumerable<TEntity> entities,
        params Expression<Func<TEntity, object>>[] keySelectors)
        where TEntity : class
    {
        if (context == null || entities == null)
        {
            return;
        }

        // Materialise once: the sequence may be lazy and is needed for both the check and the batches.
        List<TEntity> items = entities.ToList();
        if (items.Count == 0)
        {
            return;
        }

        if (keySelectors == null || keySelectors.Length == 0)
        {
            throw new ArgumentException("必须至少指定一个主键属性", nameof(keySelectors));
        }

        using var transaction = context.Database.BeginTransaction();
        try
        {
            Type entityType = typeof(TEntity);
            var modelEntityType = context.Model.FindEntityType(entityType);
            if (modelEntityType == null)
            {
                throw new InvalidOperationException($"实体类型 {entityType.Name} 未在 DbContext 模型中注册");
            }

            string tableName = BuildTableName(modelEntityType.GetSchema(), modelEntityType.GetTableName());

            // Key properties, in the order the selectors were given.
            List<string> keyPropNames = keySelectors.Select(selector => GetPropertyName(selector)).ToList();
            List<PropertyInfo> keyProps = keyPropNames.Select(name => entityType.GetProperty(name)).ToList();
            if (keyProps.Any(p => p == null))
            {
                throw new ArgumentException("表达式必须指向实体属性", nameof(keySelectors));
            }

            // Non-key properties. Only properties known to the EF model are used, so that
            // navigation and unmapped properties do not end up in the column list.
            var keyNameSet = new HashSet<string>(keyPropNames);
            PropertyInfo[] props = entityType.GetProperties()
                .Where(p => !keyNameSet.Contains(p.Name) && modelEntityType.FindProperty(p.Name) != null)
                .ToArray();

            // Every property becomes one SQL parameter per row, which bounds the batch size.
            int batchSize = CalculateBatchSize(keyProps.Count + props.Length);

            foreach (TEntity[] batch in items.Chunk(batchSize))
            {
                string tempTableName = $"#Temp_Bulk_{Guid.NewGuid().ToString("N").Substring(0, 8)}";

                try
                {
                    // 1. Create an empty temp table with the same columns as the target.
                    string createSql = $"SELECT * INTO {tempTableName} FROM {tableName} WHERE 1 = 0";
                    context.Database.ExecuteSqlRaw(createSql);

                    // 2. Stage the batch in the temp table.
                    InsertToTempTable(context, batch, tempTableName, keyProps, props);

                    // 3. Merge the staged rows into the target table.
                    ExecuteMerge(context, tableName, tempTableName, keyPropNames, props);
                }
                finally
                {
                    // 4. Remove the temp table.
                    DropTempTable(context, tempTableName);
                }
            }

            // Commit once, after every batch succeeded; committing inside the loop would
            // complete the transaction after the first batch and fail on the second.
            transaction.Commit();
        }
        catch (Exception ex)
        {
            transaction.Rollback();
            Console.WriteLine($"批量操作失败: {ex.Message}");
            throw;
        }
    }

    // Builds the bracket-quoted, optionally schema-qualified table name.
    private static string BuildTableName(string schema, string table)
    {
        string quotedTable = QuoteIdentifier(table);
        return string.IsNullOrEmpty(schema) ? quotedTable : $"{QuoteIdentifier(schema)}.{quotedTable}";
    }

    // Wraps an identifier in brackets, doubling any closing bracket inside it.
    private static string QuoteIdentifier(string identifier)
    {
        return "[" + identifier.Replace("]", "]]") + "]";
    }

    /// <summary>
    /// Extracts the property name from a selector such as <c>x => x.Id</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="expression"/> is null.</exception>
    /// <exception cref="ArgumentException">The selector body is not a property access.</exception>
    private static string GetPropertyName<TEntity>(Expression<Func<TEntity, object>> expression)
    {
        if (expression == null)
        {
            throw new ArgumentNullException(nameof(expression));
        }

        // Value-type properties are boxed to object, which wraps the member access in Convert nodes.
        Expression body = expression.Body;
        while (body is UnaryExpression unary
               && (unary.NodeType == ExpressionType.Convert || unary.NodeType == ExpressionType.ConvertChecked))
        {
            body = unary.Operand;
        }

        if (body is MemberExpression member && member.Member is PropertyInfo)
        {
            return member.Member.Name;
        }

        throw new ArgumentException("表达式必须指向实体属性", nameof(expression));
    }

    /// <summary>
    /// Creates a temp table with explicitly typed columns and a primary key over the key properties.
    /// Currently not called by <c>BulkMerge</c>, which clones the target table instead.
    /// </summary>
    private static void CreateTempTable(DbContext context, string tempTableName,
        List<PropertyInfo> keyProps, PropertyInfo[] props, string providerName)
    {
        IEnumerable<string> columnDefinitions = keyProps.Concat(props)
            .Select(p => $"[{p.Name}] {GetSqlType(p.PropertyType, providerName)}");

        // Primary-key columns are a comma-separated list.
        string keyColumns = string.Join(", ", keyProps.Select(p => $"[{p.Name}]"));

        string sql = $"CREATE TABLE {tempTableName} (" +
                     string.Join(", ", columnDefinitions) +
                     ", PRIMARY KEY (" + keyColumns + ")" +
                     ")";
        context.Database.ExecuteSqlRaw(sql);
    }

    /// <summary>
    /// Inserts the entities into the temp table with one parameterised multi-row INSERT.
    /// Key properties come first, then the remaining properties, matching the column list.
    /// </summary>
    private static void InsertToTempTable<TEntity>(DbContext context, IEnumerable<TEntity> entities,
        string tempTableName, List<PropertyInfo> keyProps, PropertyInfo[] props)
    {
        const string paramPrefix = "@p";
        List<PropertyInfo> columns = keyProps.Concat(props).ToList();

        using DbCommand cmd = context.Database.GetDbConnection().CreateCommand();

        // Build the VALUES rows and the matching parameters in a single pass.
        var rows = new List<string>();
        int paramIndex = 0;
        foreach (TEntity entity in entities)
        {
            var placeholders = new List<string>(columns.Count);
            foreach (PropertyInfo column in columns)
            {
                string paramName = $"{paramPrefix}{paramIndex++}";
                placeholders.Add(paramName);
                AddParameter(cmd, paramName, column.GetValue(entity));
            }

            rows.Add("(" + string.Join(", ", placeholders) + ")");
        }

        if (rows.Count == 0)
        {
            return;
        }

        cmd.CommandText =
            $"INSERT INTO {tempTableName} ({GetColumnList(keyProps, props)}) VALUES {string.Join(",\n", rows)}";

        // A raw command is not enlisted automatically; without this the provider rejects the
        // command while the context's transaction is pending on the connection.
        var currentTransaction = context.Database.CurrentTransaction;
        if (currentTransaction != null)
        {
            cmd.Transaction = global::Microsoft.EntityFrameworkCore.Storage.DbContextTransactionExtensions
                .GetDbTransaction(currentTransaction);
        }

        context.Database.OpenConnection();
        try
        {
            cmd.ExecuteNonQuery();
        }
        finally
        {
            context.Database.CloseConnection();
        }
    }

    /// <summary>Adds a named parameter to the command; null values are sent as <see cref="DBNull"/>.</summary>
    private static void AddParameter(DbCommand cmd, string paramName, object value)
    {
        DbParameter param = cmd.CreateParameter();
        param.ParameterName = paramName;
        param.Value = value ?? DBNull.Value;
        cmd.Parameters.Add(param);
    }

    /// <summary>Bracket-quoted, comma-separated column list: key properties first, then the others.</summary>
    private static string GetColumnList(List<PropertyInfo> keyProps, PropertyInfo[] props)
    {
        // Join over the combined sequence so an empty side never leaves a dangling comma.
        return string.Join(", ", keyProps.Concat(props).Select(p => $"[{p.Name}]"));
    }

    /// <summary>
    /// Maps a CLR type (nullable wrappers are unwrapped) to a column type for SQL Server or PostgreSQL.
    /// </summary>
    /// <exception cref="NotSupportedException">The type or provider has no mapping.</exception>
    private static string GetSqlType(Type type, string providerName)
    {
        type = Nullable.GetUnderlyingType(type) ?? type;

        if (providerName.Contains("SqlServer"))
        {
            if (type == typeof(int)) return "INT";
            if (type == typeof(long)) return "BIGINT";
            if (type == typeof(short)) return "SMALLINT";
            if (type == typeof(byte)) return "TINYINT";
            if (type == typeof(bool)) return "BIT";
            if (type == typeof(string)) return "NVARCHAR(MAX)";
            if (type == typeof(DateTime)) return "DATETIME2";
            if (type == typeof(DateTimeOffset)) return "DATETIMEOFFSET";
            if (type == typeof(decimal)) return "DECIMAL(18, 2)";
            if (type == typeof(float)) return "FLOAT";
            if (type == typeof(double)) return "FLOAT";
            if (type == typeof(Guid)) return "UNIQUEIDENTIFIER";
            if (type == typeof(byte[])) return "VARBINARY(MAX)";
        }
        else if (providerName.Contains("PostgreSQL"))
        {
            if (type == typeof(int)) return "INTEGER";
            if (type == typeof(long)) return "BIGINT";
            if (type == typeof(short)) return "SMALLINT";
            if (type == typeof(bool)) return "BOOLEAN";
            if (type == typeof(string)) return "TEXT";
            if (type == typeof(DateTime)) return "TIMESTAMP";
            if (type == typeof(DateTimeOffset)) return "TIMESTAMPTZ";
            if (type == typeof(decimal)) return "DECIMAL(18, 2)";
            if (type == typeof(float)) return "REAL";
            if (type == typeof(double)) return "DOUBLE PRECISION";
            if (type == typeof(Guid)) return "UUID";
            if (type == typeof(byte[])) return "BYTEA";
        }

        throw new NotSupportedException($"类型 {type.Name} 或数据库提供者 {providerName} 不支持");
    }

    /// <summary>
    /// Merges the temp table into the target: rows matching on every key column are updated,
    /// the others are inserted.
    /// </summary>
    private static void ExecuteMerge(DbContext context, string tableName, string tempTableName,
        List<string> keyPropNames, PropertyInfo[] props)
    {
        string onClause = string.Join(" AND ", keyPropNames.Select(k =>
            $"target.[{k}] = source.[{k}]"));

        string updateFields = string.Join(", ", props.Select(p =>
            $"target.[{p.Name}] = source.[{p.Name}]"));

        // Qualify every inserted value, keys included, with the source alias.
        string insertValues = string.Join(", ",
            keyPropNames.Select(k => $"source.[{k}]")
                .Concat(props.Select(p => $"source.[{p.Name}]")));

        // With no non-key columns there is nothing to update, and an empty SET list is invalid SQL.
        string matchedClause = props.Length == 0
            ? string.Empty
            : $"WHEN MATCHED THEN UPDATE SET {updateFields} ";

        string sql = $"MERGE {tableName} AS target " +
                     $"USING {tempTableName} AS source " +
                     $"ON ({onClause}) " +
                     matchedClause +
                     $"WHEN NOT MATCHED THEN " +
                     $"INSERT ({GetColumnList(keyPropNames, props)}) " +
                     $"VALUES ({insertValues});";

        context.Database.ExecuteSqlRaw(sql);
    }

    /// <summary>Drops the temp table if it exists.</summary>
    private static void DropTempTable(DbContext context, string tempTableName)
    {
        context.Database.ExecuteSqlRaw($"DROP TABLE IF EXISTS {tempTableName}");
    }

    /// <summary>Returns the provider name reported by the context's database facade.</summary>
    private static string GetDatabaseProviderName(DbContext context)
    {
        return context.Database.ProviderName;
    }

    /// <summary>Column list overload taking the key columns as names.</summary>
    private static string GetColumnList(List<string> keyPropNames, PropertyInfo[] props)
    {
        return string.Join(", ",
            keyPropNames.Select(k => $"[{k}]")
                .Concat(props.Select(p => $"[{p.Name}]")));
    }
}
|
|
|
|
}
|
|
|