1:不同表的主键类型可能不同(例如 string、int、long 等),
因此需要定义一个使用泛型主键的基类,每个数据库映射表都需要继承该类。
/// <summary>
/// Base class for every database-mapped entity. Tables may use different primary-key
/// types (for example <see cref="string"/>, <see cref="int"/> or <see cref="long"/>),
/// so the key type is a generic parameter.
/// </summary>
/// <typeparam name="TPrimaryKey">Primary key type.</typeparam>
public abstract class Entity<TPrimaryKey> : DomainEntity, IEntity<TPrimaryKey>
{
    /// <summary>
    /// Primary key of the entity.
    /// </summary>
    [Key]
    public virtual TPrimaryKey Id { get; set; }

    /// <summary>
    /// Checks whether this entity is transient, i.e. it does not have a usable Id yet.
    /// </summary>
    /// <returns>
    /// <c>true</c> when <see cref="Id"/> equals <c>default(TPrimaryKey)</c>, when an
    /// <see cref="int"/>/<see cref="long"/> key is zero or negative, or when a
    /// <see cref="string"/> key is empty; otherwise <c>false</c>.
    /// </returns>
    public virtual bool IsTransient()
    {
        bool hasDefaultKey = EqualityComparer<TPrimaryKey>.Default.Equals(Id, default(TPrimaryKey));
        if (hasDefaultKey)
        {
            return true;
        }

        Type keyType = typeof(TPrimaryKey);

        // Workaround for EF Core, which sets int/long keys to a MinValue-based temporary
        // value when attaching to a DbContext: every non-positive integer key is transient.
        // Widening an int to long does not change its sign, so one check covers both types.
        if (keyType == typeof(int) || keyType == typeof(long))
        {
            return Convert.ToInt64(Id) <= 0;
        }

        // Id is known to be non-null here (the default/null case returned above).
        return keyType == typeof(string) && string.IsNullOrEmpty(Id.ToString());
    }
}
2:仓储接口:
/// <summary>
/// Generic repository abstraction over entities of type <typeparamref name="TEntity"/>
/// whose primary key is of type <typeparamref name="TPrimaryKey"/>.
/// </summary>
/// <typeparam name="TEntity">Entity type.</typeparam>
/// <typeparam name="TPrimaryKey">Primary key type (for example string, int or long).</typeparam>
public interface IRepository<TEntity, TPrimaryKey>
    where TEntity : class, IEntity<TPrimaryKey>
{
    #region DbContextTransaction

    /// <summary>
    /// Saves all pending changes.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the save.</param>
    /// <returns>The number of saved entries, as reported by the implementation.</returns>
    Task<int> SaveAsync(CancellationToken cancellationToken = default);

    /// <summary>
    /// Rolls back the current transaction.
    /// </summary>
    /// <returns>A task that completes when the rollback has finished.</returns>
    Task RollbackTransactionAsync();

    /// <summary>
    /// Starts a transaction (unit of work). Dispose the returned object to end it.
    /// </summary>
    /// <returns>The disposable transaction scope.</returns>
    IDisposable Begin();

    #endregion DbContextTransaction

    #region Select/Get/Query

    /// <summary>
    /// Gets an <see cref="IQueryable{T}"/> that retrieves entities from the entire table.
    /// </summary>
    /// <returns>IQueryable used to select entities from the database.</returns>
    IQueryable<TEntity> GetAll();

    /// <summary>
    /// Gets an <see cref="IQueryable{T}"/> that retrieves entities from the entire table,
    /// eagerly loading one or more related properties.
    /// </summary>
    /// <param name="propertySelectors">A list of include expressions.</param>
    /// <returns>IQueryable used to select entities from the database.</returns>
    IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors);

    /// <summary>
    /// Gets all entities.
    /// </summary>
    /// <returns>List of all entities.</returns>
    Task<List<TEntity>> GetAllListAsync();

    /// <summary>
    /// Gets all entities matching the given <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">A condition to filter entities.</param>
    /// <returns>List of all matching entities.</returns>
    Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Runs <paramref name="queryMethod"/> over all entities (see <see cref="GetAll"/>) and
    /// returns its result. If <paramref name="queryMethod"/> materializes the query (e.g. with
    /// ToList or FirstOrDefault), the result does not depend on deferred execution.
    /// </summary>
    /// <typeparam name="T">Type of the value returned by this method.</typeparam>
    /// <param name="queryMethod">Method used to query over the entities.</param>
    /// <returns>Query result.</returns>
    T Query<T>(Func<IQueryable<TEntity>, T> queryMethod);

    /// <summary>
    /// Gets the entity with the given primary key. Whether a missing entity yields
    /// <c>null</c> or an exception is implementation-defined.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>Entity.</returns>
    TEntity Get(TPrimaryKey id);

    /// <summary>
    /// Gets the entity with the given primary key. Whether a missing entity yields
    /// <c>null</c> or an exception is implementation-defined.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>Entity.</returns>
    Task<TEntity> GetAsync(TPrimaryKey id);

    /// <summary>
    /// Gets the entity with the given primary key, or <c>null</c> if it is not found.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>Entity or <c>null</c>.</returns>
    TEntity FirstOrDefault(TPrimaryKey id);

    /// <summary>
    /// Gets the entity with the given primary key, or <c>null</c> if it is not found.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>Entity or <c>null</c>.</returns>
    Task<TEntity> FirstOrDefaultAsync(TPrimaryKey id);

    /// <summary>
    /// Gets the first entity matching the given predicate, or <c>null</c> if none is found.
    /// </summary>
    /// <param name="predicate">Predicate to filter entities.</param>
    TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Gets the first entity matching the given predicate, or <c>null</c> if none is found.
    /// </summary>
    /// <param name="predicate">Predicate to filter entities.</param>
    Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion Select/Get/Query

    #region Insert

    /// <summary>
    /// Inserts a new entity.
    /// </summary>
    /// <param name="entity">Entity to insert.</param>
    TEntity Insert(TEntity entity);

    /// <summary>
    /// Inserts a new entity.
    /// </summary>
    /// <param name="entity">Entity to insert.</param>
    Task<TEntity> InsertAsync(TEntity entity);

    /// <summary>
    /// Inserts a new entity and gets its Id.
    /// Retrieving the Id may require saving the current unit of work.
    /// </summary>
    /// <param name="entity">Entity.</param>
    /// <returns>Id of the entity.</returns>
    TPrimaryKey InsertAndGetId(TEntity entity);

    /// <summary>
    /// Inserts a new entity and gets its Id.
    /// Retrieving the Id may require saving the current unit of work.
    /// </summary>
    /// <param name="entity">Entity.</param>
    /// <returns>Id of the entity.</returns>
    Task<TPrimaryKey> InsertAndGetIdAsync(TEntity entity);

    #endregion Insert

    #region Update

    /// <summary>
    /// Updates an existing entity.
    /// </summary>
    /// <param name="entity">Entity.</param>
    TEntity Update(TEntity entity);

    /// <summary>
    /// Updates an existing entity.
    /// </summary>
    /// <param name="entity">Entity.</param>
    Task<TEntity> UpdateAsync(TEntity entity);

    /// <summary>
    /// Updates an existing entity.
    /// </summary>
    /// <param name="id">Id of the entity.</param>
    /// <param name="updateAction">Action used to change values of the entity.</param>
    /// <returns>Updated entity.</returns>
    TEntity Update(TPrimaryKey id, Action<TEntity> updateAction);

    /// <summary>
    /// Updates an existing entity.
    /// </summary>
    /// <param name="id">Id of the entity.</param>
    /// <param name="updateAction">Asynchronous action used to change values of the entity.</param>
    /// <returns>Updated entity.</returns>
    Task<TEntity> UpdateAsync(TPrimaryKey id, Func<TEntity, Task> updateAction);

    #endregion Update

    #region Delete

    /// <summary>
    /// Deletes an entity.
    /// </summary>
    /// <param name="entity">Entity to be deleted.</param>
    void Delete(TEntity entity);

    /// <summary>
    /// Deletes an entity.
    /// </summary>
    /// <param name="entity">Entity to be deleted.</param>
    Task DeleteAsync(TEntity entity);

    /// <summary>
    /// Deletes an entity by primary key.
    /// </summary>
    /// <param name="id">Primary key of the entity.</param>
    void Delete(TPrimaryKey id);

    /// <summary>
    /// Deletes an entity by primary key.
    /// </summary>
    /// <param name="id">Primary key of the entity.</param>
    Task DeleteAsync(TPrimaryKey id);

    /// <summary>
    /// Deletes all entities matching <paramref name="predicate"/>.
    /// Note: all matching entities are retrieved first and then deleted one by one,
    /// which may cause major performance problems when many entities match.
    /// </summary>
    /// <param name="predicate">A condition to filter entities.</param>
    void Delete(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Deletes all entities matching <paramref name="predicate"/>.
    /// Note: all matching entities are retrieved first and then deleted one by one,
    /// which may cause major performance problems when many entities match.
    /// </summary>
    /// <param name="predicate">A condition to filter entities.</param>
    Task DeleteAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion Delete

    #region Aggregates

    /// <summary>
    /// Gets the count of all entities in this repository.
    /// </summary>
    /// <returns>Count of entities.</returns>
    Task<int> CountAsync();

    /// <summary>
    /// Gets the count of entities in this repository matching <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">A condition to filter the count.</param>
    /// <returns>Count of entities.</returns>
    Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Gets the count of all entities in this repository
    /// (use this when the expected value may be greater than <see cref="int.MaxValue"/>).
    /// </summary>
    /// <returns>Count of entities.</returns>
    Task<long> LongCountAsync();

    /// <summary>
    /// Gets the count of entities in this repository matching <paramref name="predicate"/>
    /// (use this overload when the expected value may be greater than <see cref="int.MaxValue"/>).
    /// </summary>
    /// <param name="predicate">A condition to filter the count.</param>
    /// <returns>Count of entities.</returns>
    Task<long> LongCountAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion Aggregates
}
3:仓储接口实现:
/// <summary>
/// Generic EF Core repository for <typeparamref name="TEntity"/> keyed by <typeparamref name="TPrimaryKey"/>.
/// The <see cref="DbContext"/> is resolved from the <see cref="IDbContextProvider"/> on every access.
/// Insert/Update/Delete only stage changes in the change tracker. They are written by
/// <see cref="SaveAsync"/>, and also by InsertAndGetId/InsertAndGetIdAsync when a key still
/// has to be generated.
/// </summary>
/// <typeparam name="TEntity">Entity type.</typeparam>
/// <typeparam name="TPrimaryKey">Primary key type.</typeparam>
public class Repository<TEntity, TPrimaryKey> : IUnitOfWork, IRepository<TEntity, TPrimaryKey>
    where TEntity : class, IEntity<TPrimaryKey>
{
    private readonly IDbContextProvider _dbContextProvider;
    private readonly IMediator mediator;

    // Resolved from the provider on every access; this class never caches the context.
    private DbContext dbContext => _dbContextProvider.GetDbContext();

    private DbSet<TEntity> Table => dbContext.Set<TEntity>();

    /// <summary>
    /// Creates the repository.
    /// </summary>
    /// <param name="dbContextProvider">Supplies the DbContext used by every operation.</param>
    /// <param name="_mediator">Mediator used to dispatch domain events and to start units of work.</param>
    public Repository(IDbContextProvider dbContextProvider, IMediator _mediator)
    {
        _dbContextProvider = dbContextProvider;
        mediator = _mediator;
    }

    #region Select

    /// <summary>
    /// Synchronously loads all entities matching <paramref name="predicate"/>.
    /// </summary>
    public virtual List<TEntity> GetAllList(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).ToList();
    }

    /// <inheritdoc/>
    public IQueryable<TEntity> GetAll()
    {
        return GetAllIncluding();
    }

    /// <summary>
    /// Gets the entity with the given key, or <c>null</c> when it does not exist
    /// (same as <see cref="FirstOrDefault(TPrimaryKey)"/>).
    /// </summary>
    public TEntity Get(TPrimaryKey id)
    {
        return FirstOrDefault(id);
    }

    /// <summary>
    /// Gets the entity with the given key, or <c>null</c> when it does not exist
    /// (same as <see cref="FirstOrDefaultAsync(TPrimaryKey)"/>).
    /// </summary>
    public async Task<TEntity> GetAsync(TPrimaryKey id)
    {
        return await FirstOrDefaultAsync(id);
    }

    /// <inheritdoc/>
    public IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors)
    {
        var query = GetQueryable();
        if (propertySelectors != null && propertySelectors.Length > 0)
        {
            foreach (var propertySelector in propertySelectors)
            {
                query = query.Include(propertySelector);
            }
        }
        return query;
    }

    /// <inheritdoc/>
    public async Task<List<TEntity>> GetAllListAsync()
    {
        return await GetAll().ToListAsync();
    }

    /// <inheritdoc/>
    public async Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).ToListAsync();
    }

    /// <summary>
    /// Gets the single entity matching <paramref name="predicate"/>; the underlying
    /// SingleAsync throws when zero or more than one entity matches.
    /// </summary>
    public async Task<TEntity> SingleAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().SingleAsync(predicate);
    }

    /// <inheritdoc/>
    public TEntity FirstOrDefault(TPrimaryKey id)
    {
        return GetAll().FirstOrDefault(CreateEqualityExpressionForId(id));
    }

    /// <inheritdoc/>
    public TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().FirstOrDefault(predicate);
    }

    /// <inheritdoc/>
    public async Task<TEntity> FirstOrDefaultAsync(TPrimaryKey id)
    {
        return await GetAll().FirstOrDefaultAsync(CreateEqualityExpressionForId(id));
    }

    /// <inheritdoc/>
    public async Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().FirstOrDefaultAsync(predicate);
    }

    /// <inheritdoc/>
    public T Query<T>(Func<IQueryable<TEntity>, T> queryMethod)
    {
        return queryMethod(GetAll());
    }

    #endregion Select

    #region Insert

    /// <summary>
    /// Adds <paramref name="entity"/> to the DbSet (nothing is saved yet).
    /// </summary>
    public TEntity Insert(TEntity entity)
    {
        return Table.Add(entity).Entity;
    }

    /// <summary>
    /// Adds <paramref name="entity"/> to the DbSet (nothing is saved yet); completes synchronously.
    /// </summary>
    public Task<TEntity> InsertAsync(TEntity entity)
    {
        return Task.FromResult(Insert(entity));
    }

    /// <summary>
    /// Adds <paramref name="entity"/> and returns its Id, saving first when the key may
    /// still be a temporary/unset value.
    /// </summary>
    public TPrimaryKey InsertAndGetId(TEntity entity)
    {
        entity = Insert(entity);
        if (MayHaveTemporaryKey(entity) || entity.IsTransient())
        {
            // The real key is only known after the database has generated it.
            dbContext.SaveChanges();
        }
        return entity.Id;
    }

    /// <summary>
    /// Adds <paramref name="entity"/> and returns its Id, saving first when the key may
    /// still be a temporary/unset value.
    /// </summary>
    public async Task<TPrimaryKey> InsertAndGetIdAsync(TEntity entity)
    {
        entity = await InsertAsync(entity);
        if (MayHaveTemporaryKey(entity) || entity.IsTransient())
        {
            // The real key is only known after the database has generated it.
            await dbContext.SaveChangesAsync();
        }
        return entity.Id;
    }

    #endregion Insert

    #region Update

    /// <summary>
    /// Attaches <paramref name="entity"/> if needed and marks it as Modified.
    /// </summary>
    public TEntity Update(TEntity entity)
    {
        AttachIfNot(entity);
        dbContext.Entry(entity).State = EntityState.Modified;
        return entity;
    }

    /// <summary>
    /// Same as <see cref="Update(TEntity)"/>; completes synchronously.
    /// </summary>
    public Task<TEntity> UpdateAsync(TEntity entity)
    {
        entity = Update(entity);
        return Task.FromResult(entity);
    }

    /// <summary>
    /// Loads the entity by key and applies <paramref name="updateAction"/> to it.
    /// NOTE: when no entity exists, <paramref name="updateAction"/> receives <c>null</c>.
    /// </summary>
    public TEntity Update(TPrimaryKey id, Action<TEntity> updateAction)
    {
        var entity = Get(id);
        updateAction(entity);
        return entity;
    }

    /// <summary>
    /// Loads the entity by key and applies <paramref name="updateAction"/> to it.
    /// NOTE: when no entity exists, <paramref name="updateAction"/> receives <c>null</c>.
    /// </summary>
    public async Task<TEntity> UpdateAsync(TPrimaryKey id, Func<TEntity, Task> updateAction)
    {
        var entity = await GetAsync(id);
        await updateAction(entity);
        return entity;
    }

    #endregion Update

    #region Delete

    /// <summary>
    /// Attaches <paramref name="entity"/> if needed and marks it for deletion.
    /// </summary>
    public void Delete(TEntity entity)
    {
        AttachIfNot(entity);
        Table.Remove(entity);
    }

    /// <summary>
    /// Same as <see cref="Delete(TEntity)"/>; completes synchronously.
    /// </summary>
    public Task DeleteAsync(TEntity entity)
    {
        Delete(entity);
        return Task.CompletedTask;
    }

    /// <summary>
    /// Deletes the entity with the given key. A tracked instance is used when available;
    /// otherwise the entity is loaded by key. Does nothing when it cannot be found.
    /// </summary>
    public void Delete(TPrimaryKey id)
    {
        // The previous implementation filtered the DbSet with EqualityComparer<T>.Default.Equals,
        // which EF Core cannot translate to SQL (throws on EF Core 3+, loads the whole table on 2.x).
        var entity = GetFromChangeTrackerOrNull(id) ?? FirstOrDefault(id);
        if (entity != null)
        {
            Delete(entity);
        }
        // Could not find the entity: do nothing.
    }

    /// <summary>
    /// Asynchronously deletes the entity with the given key. A tracked instance is used when
    /// available; otherwise the entity is loaded by key. Does nothing when it cannot be found.
    /// </summary>
    public async Task DeleteAsync(TPrimaryKey id)
    {
        var entity = GetFromChangeTrackerOrNull(id) ?? await FirstOrDefaultAsync(id);
        if (entity != null)
        {
            Delete(entity);
        }
        // Could not find the entity: do nothing.
    }

    /// <inheritdoc/>
    public virtual void Delete(Expression<Func<TEntity, bool>> predicate)
    {
        foreach (var entity in GetAllList(predicate))
        {
            Delete(entity);
        }
    }

    /// <inheritdoc/>
    public virtual async Task DeleteAsync(Expression<Func<TEntity, bool>> predicate)
    {
        var entities = await GetAllListAsync(predicate);
        foreach (var entity in entities)
        {
            await DeleteAsync(entity);
        }
    }

    #endregion Delete

    #region Other

    /// <inheritdoc/>
    public async Task<int> CountAsync()
    {
        return await GetAll().CountAsync();
    }

    /// <inheritdoc/>
    public async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).CountAsync();
    }

    /// <inheritdoc/>
    public async Task<long> LongCountAsync()
    {
        return await GetAll().LongCountAsync();
    }

    /// <inheritdoc/>
    public async Task<long> LongCountAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).LongCountAsync();
    }

    #endregion Other

    #region Helper

    private IQueryable<TEntity> GetQueryable()
    {
        return Table.AsQueryable();
    }

    /// <summary>
    /// Builds <c>e =&gt; e.Id == id</c>. It is built by hand because <c>==</c> cannot be applied to an
    /// unconstrained generic key. The key is read from a closure so that EF can send it as a
    /// query parameter instead of inlining a constant.
    /// </summary>
    private Expression<Func<TEntity, bool>> CreateEqualityExpressionForId(TPrimaryKey id)
    {
        var lambdaParam = Expression.Parameter(typeof(TEntity));
        var leftExpression = Expression.PropertyOrField(lambdaParam, "Id");
        // Plain boxing. The former Convert.ChangeType(id, typeof(TPrimaryKey)) was an identity
        // conversion for IConvertible keys but threw InvalidCastException for nullable key types.
        object idValue = id;
        Expression<Func<object>> closure = () => idValue;
        var rightExpression = Expression.Convert(closure.Body, leftExpression.Type);
        var lambdaBody = Expression.Equal(leftExpression, rightExpression);
        return Expression.Lambda<Func<TEntity, bool>>(lambdaBody, lambdaParam);
    }

    /// <summary>
    /// Attaches <paramref name="entity"/> unless this exact instance is already tracked.
    /// </summary>
    private void AttachIfNot(TEntity entity)
    {
        var entry = dbContext.ChangeTracker.Entries().FirstOrDefault(ent => ent.Entity == entity);
        if (entry != null)
        {
            return;
        }
        Table.Attach(entity);
    }

    /// <summary>
    /// Returns an already-tracked <typeparamref name="TEntity"/> with the given key, or <c>null</c>.
    /// </summary>
    private TEntity GetFromChangeTrackerOrNull(TPrimaryKey id)
    {
        var entry = dbContext.ChangeTracker.Entries<TEntity>()
            .FirstOrDefault(ent => TPrimaryKeyEquals(id, ent.Entity));
        return entry?.Entity;
    }

    /// <summary>
    /// Compares <paramref name="id"/> with the entity's key. The key is read through
    /// <c>IEntity&lt;TPrimaryKey&gt;.Id</c> rather than reflection, which previously failed
    /// for reference-type keys other than string.
    /// </summary>
    private static bool TPrimaryKeyEquals(TPrimaryKey id, TEntity entity)
    {
        return EqualityComparer<TPrimaryKey>.Default.Equals(id, entity.Id);
    }

    /// <summary>
    /// Uses reflection to collect the names and values of an instance's public readable
    /// properties whose type is a value type or whose type name starts with "String".
    /// Example: <c>var dict = GetProperties(model);</c>
    /// </summary>
    /// <typeparam name="T">Entity type.</typeparam>
    /// <param name="t">Instance to inspect.</param>
    /// <returns>
    /// Property name → value, or <c>null</c> when <paramref name="t"/> is <c>null</c> or its
    /// type has no public instance properties.
    /// </returns>
    public static Dictionary<object, object> GetProperties<T>(T t)
    {
        if (t == null)
        {
            return null;
        }
        PropertyInfo[] properties = t.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public);
        if (properties.Length <= 0)
        {
            return null;
        }
        var ret = new Dictionary<object, object>();
        foreach (PropertyInfo item in properties)
        {
            // Indexers and write-only properties cannot be read with GetValue(t, null).
            if (!item.CanRead || item.GetIndexParameters().Length > 0)
            {
                continue;
            }
            bool isIncludedType = item.PropertyType.IsValueType
                || item.PropertyType.Name.StartsWith("String", StringComparison.Ordinal);
            if (!isIncludedType)
            {
                continue;
            }
            string name = item.Name;
            // A property hidden with 'new' can be reported more than once; keep the first
            // occurrence instead of throwing ArgumentException on a duplicate key.
            if (ret.ContainsKey(name))
            {
                continue;
            }
            ret.Add(name, item.GetValue(t, null));
        }
        return ret;
    }

    /// <summary>
    /// Returns <c>true</c> when the entity's key may still be a temporary/unset value, so that
    /// a save is needed before the real key can be read. Byte keys always report <c>true</c>.
    /// </summary>
    private static bool MayHaveTemporaryKey(TEntity entity)
    {
        if (typeof(TPrimaryKey) == typeof(byte))
        {
            return true;
        }
        if (typeof(TPrimaryKey) == typeof(string))
        {
            // 'as string' avoids the NullReferenceException entity.Id.ToString() threw for null keys.
            return string.IsNullOrEmpty(entity.Id as string);
        }
        if (typeof(TPrimaryKey) == typeof(int))
        {
            return Convert.ToInt32(entity.Id) <= 0;
        }
        if (typeof(TPrimaryKey) == typeof(long))
        {
            return Convert.ToInt64(entity.Id) <= 0;
        }
        return false;
    }

    #endregion Helper

    #region DbContextTransaction

    /// <summary>
    /// Starts a unit of work by delegating to <c>dbContext.Run(mediator)</c>; dispose the result to end it.
    /// </summary>
    public IDisposable Begin()
    {
        return dbContext.Run(mediator);
    }

    /// <summary>
    /// Dispatches pending domain events through the mediator, then saves all changes.
    /// </summary>
    /// <param name="cancellationToken">Forwarded to SaveChangesAsync (it was previously ignored).</param>
    /// <returns>The value returned by SaveChangesAsync.</returns>
    public async Task<int> SaveAsync(CancellationToken cancellationToken = default)
    {
        // Resolve once so that events are dispatched from, and changes saved on, the same context.
        var context = dbContext;
        await mediator.DispatchDomainEventsAsync(context);
        return await context.SaveChangesAsync(cancellationToken);
    }

    /// <summary>
    /// Rolls back the context's current transaction (if one is open) and then disposes the context.
    /// </summary>
    public Task RollbackTransactionAsync()
    {
        var context = dbContext;
        try
        {
            // DatabaseFacade.RollbackTransaction throws when no transaction is in progress.
            if (context.Database.CurrentTransaction != null)
            {
                context.Database.RollbackTransaction();
            }
        }
        finally
        {
            context.Dispose();
        }
        return Task.CompletedTask;
    }

    #endregion DbContextTransaction
}
4:代码中加入了 MediatR 的 INotification(领域事件)和事务支持:
执行事务的工作单元(UnitOfWork):
/// <summary>
/// Unit of work over <typeparamref name="TDbContext"/>: a database transaction is started on
/// construction. Disposing the instance dispatches pending domain events through MediatR,
/// saves changes, commits the transaction, and finally disposes the transaction and the
/// wrapped context.
/// NOTE: Dispose always saves and commits, even when a <c>using</c> block is left because of
/// an exception; this type cannot tell the two cases apart.
/// NOTE(review): DbContext also implements IAsyncDisposable on EF Core 3.0+; <c>await using</c>
/// would call the base DisposeAsync and bypass the save/commit logic here. Confirm callers use a plain <c>using</c>.
/// </summary>
/// <typeparam name="TDbContext">The wrapped context type.</typeparam>
public class UnitOfWork<TDbContext> : DbContext, IDisposable
    where TDbContext : DbContext
{
    private IDbContextTransaction currentTransaction;
    private readonly IMediator mediator;
    private readonly TDbContext dbContext;
    private bool isDisposed;

    /// <summary>
    /// Wraps <paramref name="_dbContext"/> and immediately begins a transaction on it.
    /// </summary>
    /// <param name="_dbContext">Context whose changes are saved and committed on dispose.</param>
    /// <param name="_mediator">Mediator used to dispatch domain events before saving.</param>
    public UnitOfWork(
        TDbContext _dbContext,
        IMediator _mediator
        )
    {
        dbContext = _dbContext;
        currentTransaction = DbContextTransactionStart;
        mediator = _mediator;
    }

    /// <summary>
    /// The transaction owned by this unit of work. If none is active, a new one is begun and
    /// remembered. Previously it was returned untracked, so it was never committed or disposed.
    /// </summary>
    public IDbContextTransaction DbContextTransactionStart =>
        currentTransaction ?? (currentTransaction = dbContext.Database.BeginTransaction());

    /// <summary>
    /// Commits the transaction (synchronously, despite the name). If the commit throws, the
    /// transaction is rolled back and the original exception is rethrown.
    /// </summary>
    public void CommitTransactionAsync()
    {
        try
        {
            currentTransaction?.Commit();
        }
        catch
        {
            RollbackTransactionAsync();
            throw;
        }
    }

    /// <summary>
    /// Rolls back the transaction (synchronously, despite the name).
    /// </summary>
    public void RollbackTransactionAsync()
    {
        // The former catch { throw ex; } only reset the exception's stack trace; let it propagate untouched.
        currentTransaction?.Rollback();
    }

    /// <summary>
    /// When <paramref name="disposing"/> is <c>true</c>, dispatches domain events, saves and commits.
    /// The transaction and wrapped context are always disposed, even if that work throws.
    /// Calling this more than once has no further effect.
    /// </summary>
    protected virtual void Dispose(bool disposing)
    {
        if (isDisposed)
        {
            return;
        }
        isDisposed = true;

        try
        {
            if (disposing)
            {
                // Run on the thread pool and block, so no captured synchronization context can deadlock.
                Task.Run(async () =>
                {
                    await mediator.DispatchDomainEventsAsync(dbContext);
                    await dbContext.SaveChangesAsync();
                    CommitTransactionAsync();
                }).Wait();
            }
        }
        finally
        {
            // Previously skipped when saving/committing threw, leaking the transaction and context.
            // Disposing an uncommitted transaction rolls it back.
            currentTransaction?.Dispose();
            currentTransaction = null;
            dbContext.Dispose();
        }
    }

    /// <summary>
    /// Saves, commits and releases the unit of work; see <see cref="Dispose(bool)"/>.
    /// </summary>
    public override void Dispose()
    {
        Dispose(disposing: true);
    }
}
依赖注入方式
// Repositories resolved through dependency injection.
private readonly IRepository<Organization, int> organizationRepository;
private readonly IRepository<User, string> userRepository;

/// <summary>
/// Creates the handler with the repositories it works on.
/// </summary>
/// <param name="_userRepository">Repository for <c>User</c> entities keyed by <see cref="string"/>.</param>
/// <param name="_organizationRepository">Repository for <c>Organization</c> entities keyed by <see cref="int"/>.</param>
public UserCommandHandler(
    IRepository<User, string> _userRepository,
    IRepository<Organization, int> _organizationRepository)
{
    this.organizationRepository = _organizationRepository;
    this.userRepository = _userRepository;
}
事务调用方式
// Transaction: the object returned by Begin() is disposed at the end of the using block,
// even if an exception is thrown inside it.
// NOTE(review): Begin() delegates to dbContext.Run(mediator), which is not shown; presumably it
// returns a UnitOfWork whose Dispose dispatches domain events, saves changes and commits. Confirm.
using (userRepository.Begin())
{
// Domain event: presumably SetOrg records an event that is dispatched when changes are saved. Confirm.
user.SetOrg();
// Save the user. In Repository<,>, InsertAsync only adds the entity to the DbSet;
// the row is written when changes are saved.
await userRepository.InsertAsync(user);
}