1:不同表的主键类型可能不同(如 string、int、long 等)
因此需要定义一个以主键类型为泛型参数的基类，每个数据库映射表的实体类都需要继承该类
/// <summary>
/// Base class for every database-mapped entity. The primary key type is a generic parameter
/// because different tables use different key types (string, int, long, ...).
/// </summary>
/// <typeparam name="TPrimaryKey">Primary key type.</typeparam>
public abstract class Entity<TPrimaryKey> : DomainEntity, IEntity<TPrimaryKey>
{
    /// <summary>
    /// Primary key of the entity.
    /// </summary>
    [Key]
    public virtual TPrimaryKey Id { get; set; }

    /// <summary>
    /// Determines whether this entity is transient, i.e. has not been assigned a usable Id yet.
    /// </summary>
    /// <returns>
    /// <c>true</c> if <see cref="Id"/> is the default value of <typeparamref name="TPrimaryKey"/>,
    /// an <see cref="int"/> or <see cref="long"/> key that is zero or negative, or an empty
    /// <see cref="string"/>; otherwise <c>false</c>.
    /// </returns>
    public virtual bool IsTransient()
    {
        // Covers null references and default value-type keys (0, Guid.Empty, ...).
        if (EqualityComparer<TPrimaryKey>.Default.Equals(Id, default(TPrimaryKey)))
        {
            return true;
        }

        // Past this point Id is non-null, so each cast below matches the key type it is guarded by.
        object key = Id;
        Type keyType = typeof(TPrimaryKey);

        // Workaround for EF Core, which sets int/long keys to their minimum value when an entity
        // is attached to the DbContext: treat every non-positive numeric key as "not assigned".
        if (keyType == typeof(int))
        {
            return (int)key <= 0;
        }

        if (keyType == typeof(long))
        {
            return (long)key <= 0;
        }

        return keyType == typeof(string) && ((string)key).Length == 0;
    }
}
2:仓储接口:
/// <summary>
/// Generic repository for entities of type <typeparamref name="TEntity"/> keyed by
/// <typeparamref name="TPrimaryKey"/>, including basic transaction control.
/// </summary>
/// <typeparam name="TEntity">Entity type.</typeparam>
/// <typeparam name="TPrimaryKey">Primary key type of the entity.</typeparam>
public interface IRepository<TEntity, TPrimaryKey> where TEntity : class, IEntity<TPrimaryKey>
{
    #region DbContextTransaction

    /// <summary>
    /// Saves pending changes.
    /// </summary>
    /// <param name="cancellationToken">Token used to cancel the save.</param>
    /// <returns>The count returned by the underlying save operation.</returns>
    Task<int> SaveAsync(CancellationToken cancellationToken = default);

    /// <summary>
    /// Rolls back the current transaction.
    /// </summary>
    /// <returns>A task that completes when the rollback has finished.</returns>
    Task RollbackTransactionAsync();

    /// <summary>
    /// Begins a transaction.
    /// </summary>
    /// <returns>A handle for the transaction scope; dispose it to end the scope.</returns>
    IDisposable Begin();

    #endregion DbContextTransaction

    #region Select/Get/Query

    /// <summary>
    /// Gets an <see cref="IQueryable{T}"/> that retrieves entities from the entire table.
    /// </summary>
    /// <returns>A query over all entities; it is not executed until enumerated.</returns>
    IQueryable<TEntity> GetAll();

    /// <summary>
    /// Gets an <see cref="IQueryable{T}"/> that retrieves entities from the entire table,
    /// eagerly loading one or more related properties.
    /// </summary>
    /// <param name="propertySelectors">A list of include expressions.</param>
    /// <returns>A query over all entities with the requested includes.</returns>
    IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors);

    /// <summary>
    /// Gets all entities.
    /// </summary>
    /// <returns>List of all entities.</returns>
    Task<List<TEntity>> GetAllListAsync();

    /// <summary>
    /// Gets all entities that satisfy <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">A condition to filter entities.</param>
    /// <returns>List of matching entities.</returns>
    Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Runs <paramref name="queryMethod"/> against a query over all entities and returns its result.
    /// If <paramref name="queryMethod"/> ends the query with ToList, FirstOrDefault, etc., the result
    /// is materialized before this method returns (unlike <see cref="GetAll"/>, which returns a deferred query).
    /// </summary>
    /// <typeparam name="T">Type of the value returned by <paramref name="queryMethod"/>.</typeparam>
    /// <param name="queryMethod">Function that builds and (optionally) executes the query.</param>
    /// <returns>The value returned by <paramref name="queryMethod"/>.</returns>
    T Query<T>(Func<IQueryable<TEntity>, T> queryMethod);

    /// <summary>
    /// Gets an entity with the given primary key.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>The entity.</returns>
    TEntity Get(TPrimaryKey id);

    /// <summary>
    /// Gets an entity with the given primary key.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>The entity.</returns>
    Task<TEntity> GetAsync(TPrimaryKey id);

    /// <summary>
    /// Gets an entity with the given primary key, or <c>null</c> if not found.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>The entity or <c>null</c>.</returns>
    TEntity FirstOrDefault(TPrimaryKey id);

    /// <summary>
    /// Gets an entity with the given primary key, or <c>null</c> if not found.
    /// </summary>
    /// <param name="id">Primary key of the entity to get.</param>
    /// <returns>The entity or <c>null</c>.</returns>
    Task<TEntity> FirstOrDefaultAsync(TPrimaryKey id);

    /// <summary>
    /// Gets the first entity that satisfies <paramref name="predicate"/>, or <c>null</c> if none does.
    /// </summary>
    /// <param name="predicate">Predicate to filter entities.</param>
    /// <returns>The entity or <c>null</c>.</returns>
    TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Gets the first entity that satisfies <paramref name="predicate"/>, or <c>null</c> if none does.
    /// </summary>
    /// <param name="predicate">Predicate to filter entities.</param>
    /// <returns>The entity or <c>null</c>.</returns>
    Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion Select/Get/Query

    #region Insert

    /// <summary>
    /// Inserts a new entity.
    /// </summary>
    /// <param name="entity">Entity to insert.</param>
    /// <returns>The inserted entity.</returns>
    TEntity Insert(TEntity entity);

    /// <summary>
    /// Inserts a new entity.
    /// </summary>
    /// <param name="entity">Entity to insert.</param>
    /// <returns>The inserted entity.</returns>
    Task<TEntity> InsertAsync(TEntity entity);

    /// <summary>
    /// Inserts a new entity and gets its Id.
    /// Retrieving the Id may require saving the current unit of work.
    /// </summary>
    /// <param name="entity">Entity to insert.</param>
    /// <returns>Id of the entity.</returns>
    TPrimaryKey InsertAndGetId(TEntity entity);

    /// <summary>
    /// Inserts a new entity and gets its Id.
    /// Retrieving the Id may require saving the current unit of work.
    /// </summary>
    /// <param name="entity">Entity to insert.</param>
    /// <returns>Id of the entity.</returns>
    Task<TPrimaryKey> InsertAndGetIdAsync(TEntity entity);

    #endregion Insert

    #region Update

    /// <summary>
    /// Updates an existing entity.
    /// </summary>
    /// <param name="entity">Entity to update.</param>
    /// <returns>The updated entity.</returns>
    TEntity Update(TEntity entity);

    /// <summary>
    /// Updates an existing entity.
    /// </summary>
    /// <param name="entity">Entity to update.</param>
    /// <returns>The updated entity.</returns>
    Task<TEntity> UpdateAsync(TEntity entity);

    /// <summary>
    /// Loads the entity with the given Id and applies <paramref name="updateAction"/> to it.
    /// </summary>
    /// <param name="id">Id of the entity.</param>
    /// <param name="updateAction">Action that changes values of the entity.</param>
    /// <returns>The updated entity.</returns>
    TEntity Update(TPrimaryKey id, Action<TEntity> updateAction);

    /// <summary>
    /// Loads the entity with the given Id and applies <paramref name="updateAction"/> to it.
    /// </summary>
    /// <param name="id">Id of the entity.</param>
    /// <param name="updateAction">Asynchronous function that changes values of the entity.</param>
    /// <returns>The updated entity.</returns>
    Task<TEntity> UpdateAsync(TPrimaryKey id, Func<TEntity, Task> updateAction);

    #endregion Update

    #region Delete

    /// <summary>
    /// Deletes an entity.
    /// </summary>
    /// <param name="entity">Entity to be deleted.</param>
    void Delete(TEntity entity);

    /// <summary>
    /// Deletes an entity.
    /// </summary>
    /// <param name="entity">Entity to be deleted.</param>
    Task DeleteAsync(TEntity entity);

    /// <summary>
    /// Deletes an entity by primary key.
    /// </summary>
    /// <param name="id">Primary key of the entity.</param>
    void Delete(TPrimaryKey id);

    /// <summary>
    /// Deletes an entity by primary key.
    /// </summary>
    /// <param name="id">Primary key of the entity.</param>
    Task DeleteAsync(TPrimaryKey id);

    /// <summary>
    /// Deletes every entity that satisfies <paramref name="predicate"/>.
    /// Note: all matching entities are retrieved and then deleted one by one, which may cause
    /// major performance problems when many entities match.
    /// </summary>
    /// <param name="predicate">A condition to filter entities.</param>
    void Delete(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Deletes every entity that satisfies <paramref name="predicate"/>.
    /// Note: all matching entities are retrieved and then deleted one by one, which may cause
    /// major performance problems when many entities match.
    /// </summary>
    /// <param name="predicate">A condition to filter entities.</param>
    Task DeleteAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion Delete

    #region Aggregates

    /// <summary>
    /// Gets the count of all entities in this repository.
    /// </summary>
    /// <returns>Count of entities.</returns>
    Task<int> CountAsync();

    /// <summary>
    /// Gets the count of entities in this repository that satisfy <paramref name="predicate"/>.
    /// </summary>
    /// <param name="predicate">A condition to filter the count.</param>
    /// <returns>Count of entities.</returns>
    Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate);

    /// <summary>
    /// Gets the count of all entities in this repository
    /// (use this when the expected value is greater than <see cref="int.MaxValue"/>).
    /// </summary>
    /// <returns>Count of entities.</returns>
    Task<long> LongCountAsync();

    /// <summary>
    /// Gets the count of entities in this repository that satisfy <paramref name="predicate"/>
    /// (use this overload when the expected value is greater than <see cref="int.MaxValue"/>).
    /// </summary>
    /// <param name="predicate">A condition to filter the count.</param>
    /// <returns>Count of entities.</returns>
    Task<long> LongCountAsync(Expression<Func<TEntity, bool>> predicate);

    #endregion Aggregates
}
3:仓储接口实现:
/// <summary>
/// Entity Framework Core implementation of <see cref="IRepository{TEntity, TPrimaryKey}"/>.
/// </summary>
/// <remarks>
/// The <see cref="DbContext"/> is re-resolved from <see cref="IDbContextProvider"/> on every access,
/// and <see cref="SaveAsync"/> dispatches domain events through MediatR before saving.
/// </remarks>
/// <typeparam name="TEntity">Entity type managed by this repository.</typeparam>
/// <typeparam name="TPrimaryKey">Primary key type of <typeparamref name="TEntity"/>.</typeparam>
public class Repository<TEntity, TPrimaryKey> : IUnitOfWork, IRepository<TEntity, TPrimaryKey>
    where TEntity : class, IEntity<TPrimaryKey>
{
    private readonly IDbContextProvider _dbContextProvider;
    private readonly IMediator mediator;

    public Repository(IDbContextProvider dbContextProvider, IMediator _mediator)
    {
        _dbContextProvider = dbContextProvider;
        mediator = _mediator;
    }

    // Re-resolved on every access; read it into a local when one operation needs a single instance.
    private DbContext dbContext => _dbContextProvider.GetDbContext();

    private DbSet<TEntity> Table => dbContext.Set<TEntity>();

    #region Select

    /// <summary>
    /// Gets all entities that satisfy <paramref name="predicate"/> (synchronous).
    /// </summary>
    /// <param name="predicate">Filter applied to the query.</param>
    /// <returns>The matching entities.</returns>
    public virtual List<TEntity> GetAllList(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().Where(predicate).ToList();
    }

    /// <inheritdoc/>
    public IQueryable<TEntity> GetAll()
    {
        return GetAllIncluding();
    }

    /// <inheritdoc/>
    /// <remarks>Returns <c>null</c> when no entity has the given key.</remarks>
    public TEntity Get(TPrimaryKey id)
    {
        return FirstOrDefault(id);
    }

    /// <inheritdoc/>
    /// <remarks>Returns <c>null</c> when no entity has the given key.</remarks>
    public async Task<TEntity> GetAsync(TPrimaryKey id)
    {
        return await FirstOrDefaultAsync(id);
    }

    /// <inheritdoc/>
    public IQueryable<TEntity> GetAllIncluding(params Expression<Func<TEntity, object>>[] propertySelectors)
    {
        var query = GetQueryable();
        if (propertySelectors != null)
        {
            foreach (var propertySelector in propertySelectors)
            {
                query = query.Include(propertySelector);
            }
        }

        return query;
    }

    /// <inheritdoc/>
    public async Task<List<TEntity>> GetAllListAsync()
    {
        return await GetAll().ToListAsync();
    }

    /// <inheritdoc/>
    public async Task<List<TEntity>> GetAllListAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).ToListAsync();
    }

    /// <summary>
    /// Gets the single entity that satisfies <paramref name="predicate"/>.
    /// Throws when there is not exactly one match.
    /// </summary>
    /// <param name="predicate">Filter applied to the query.</param>
    /// <returns>The only matching entity.</returns>
    public async Task<TEntity> SingleAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().SingleAsync(predicate);
    }

    /// <inheritdoc/>
    public TEntity FirstOrDefault(TPrimaryKey id)
    {
        return GetAll().FirstOrDefault(CreateEqualityExpressionForId(id));
    }

    /// <inheritdoc/>
    public TEntity FirstOrDefault(Expression<Func<TEntity, bool>> predicate)
    {
        return GetAll().FirstOrDefault(predicate);
    }

    /// <inheritdoc/>
    public async Task<TEntity> FirstOrDefaultAsync(TPrimaryKey id)
    {
        return await GetAll().FirstOrDefaultAsync(CreateEqualityExpressionForId(id));
    }

    /// <inheritdoc/>
    public async Task<TEntity> FirstOrDefaultAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().FirstOrDefaultAsync(predicate);
    }

    /// <inheritdoc/>
    public T Query<T>(Func<IQueryable<TEntity>, T> queryMethod)
    {
        return queryMethod(GetAll());
    }

    #endregion Select

    #region Insert

    /// <inheritdoc/>
    /// <remarks>Only adds the entity to the context; nothing is written until changes are saved.</remarks>
    public TEntity Insert(TEntity entity)
    {
        return Table.Add(entity).Entity;
    }

    /// <inheritdoc/>
    public Task<TEntity> InsertAsync(TEntity entity)
    {
        return Task.FromResult(Insert(entity));
    }

    /// <inheritdoc/>
    /// <remarks>Saves changes immediately when the key may still be a temporary or unassigned value.</remarks>
    public TPrimaryKey InsertAndGetId(TEntity entity)
    {
        entity = Insert(entity);

        if (MayHaveTemporaryKey(entity) || entity.IsTransient())
        {
            dbContext.SaveChanges();
        }

        return entity.Id;
    }

    /// <inheritdoc/>
    /// <remarks>Saves changes immediately when the key may still be a temporary or unassigned value.</remarks>
    public async Task<TPrimaryKey> InsertAndGetIdAsync(TEntity entity)
    {
        entity = await InsertAsync(entity);

        if (MayHaveTemporaryKey(entity) || entity.IsTransient())
        {
            await dbContext.SaveChangesAsync();
        }

        return entity.Id;
    }

    #endregion Insert

    #region Update

    /// <inheritdoc/>
    /// <remarks>Attaches the entity if it is not tracked and marks it as modified.</remarks>
    public TEntity Update(TEntity entity)
    {
        AttachIfNot(entity);
        dbContext.Entry(entity).State = EntityState.Modified;
        return entity;
    }

    /// <inheritdoc/>
    public Task<TEntity> UpdateAsync(TEntity entity)
    {
        entity = Update(entity);
        return Task.FromResult(entity);
    }

    /// <inheritdoc/>
    /// <remarks>If no entity has the key, <paramref name="updateAction"/> receives <c>null</c>.</remarks>
    public TEntity Update(TPrimaryKey id, Action<TEntity> updateAction)
    {
        var entity = Get(id);
        updateAction(entity);
        return entity;
    }

    /// <inheritdoc/>
    /// <remarks>If no entity has the key, <paramref name="updateAction"/> receives <c>null</c>.</remarks>
    public async Task<TEntity> UpdateAsync(TPrimaryKey id, Func<TEntity, Task> updateAction)
    {
        var entity = await GetAsync(id);
        await updateAction(entity);
        return entity;
    }

    #endregion Update

    #region Delete

    /// <inheritdoc/>
    public void Delete(TEntity entity)
    {
        AttachIfNot(entity);
        Table.Remove(entity);
    }

    /// <inheritdoc/>
    public Task DeleteAsync(TEntity entity)
    {
        Delete(entity);
        return Task.CompletedTask;
    }

    /// <inheritdoc/>
    /// <remarks>Does nothing when no entity has the given key.</remarks>
    public void Delete(TPrimaryKey id)
    {
        // Prefer an instance the context already tracks (no database round trip, and no second
        // instance with the same key gets attached). The key comparison must not be part of a
        // LINQ-to-Entities query: EqualityComparer<T> calls cannot be translated to SQL.
        var entity = GetFromChangeTrackerOrNull(id) ?? FirstOrDefault(id);
        if (entity != null)
        {
            Delete(entity);
        }
    }

    /// <inheritdoc/>
    public Task DeleteAsync(TPrimaryKey id)
    {
        Delete(id);
        return Task.CompletedTask;
    }

    /// <inheritdoc/>
    public virtual void Delete(Expression<Func<TEntity, bool>> predicate)
    {
        foreach (var entity in GetAllList(predicate))
        {
            Delete(entity);
        }
    }

    /// <inheritdoc/>
    public virtual async Task DeleteAsync(Expression<Func<TEntity, bool>> predicate)
    {
        var entities = await GetAllListAsync(predicate);
        foreach (var entity in entities)
        {
            await DeleteAsync(entity);
        }
    }

    #endregion Delete

    #region Other

    /// <inheritdoc/>
    public async Task<int> CountAsync()
    {
        return await GetAll().CountAsync();
    }

    /// <inheritdoc/>
    public async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).CountAsync();
    }

    /// <inheritdoc/>
    public async Task<long> LongCountAsync()
    {
        return await GetAll().LongCountAsync();
    }

    /// <inheritdoc/>
    public async Task<long> LongCountAsync(Expression<Func<TEntity, bool>> predicate)
    {
        return await GetAll().Where(predicate).LongCountAsync();
    }

    #endregion Other

    #region Helper

    // Base query over the entity set, before any includes are applied.
    private IQueryable<TEntity> GetQueryable()
    {
        return Table.AsQueryable();
    }

    // Builds "entity => entity.Id == id" for use in a LINQ-to-Entities query.
    private Expression<Func<TEntity, bool>> CreateEqualityExpressionForId(TPrimaryKey id)
    {
        var lambdaParam = Expression.Parameter(typeof(TEntity));
        var leftExpression = Expression.PropertyOrField(lambdaParam, "Id");

        // Read the id through a closure instead of Expression.Constant so the query provider can
        // treat it as a parameter. The id is already a TPrimaryKey, so no Convert.ChangeType is
        // needed (ChangeType threw InvalidCastException for nullable key types).
        object idValue = id;
        Expression<Func<object>> closure = () => idValue;
        var rightExpression = Expression.Convert(closure.Body, leftExpression.Type);

        var lambdaBody = Expression.Equal(leftExpression, rightExpression);
        return Expression.Lambda<Func<TEntity, bool>>(lambdaBody, lambdaParam);
    }

    // Attaches the entity unless this exact instance is already tracked by the context.
    private void AttachIfNot(TEntity entity)
    {
        var entry = dbContext.ChangeTracker.Entries().FirstOrDefault(ent => ent.Entity == entity);
        if (entry != null)
        {
            return;
        }

        Table.Attach(entity);
    }

    // Returns a tracked TEntity with the given key, or null when none is tracked.
    private TEntity GetFromChangeTrackerOrNull(TPrimaryKey id)
    {
        var entry = dbContext.ChangeTracker.Entries<TEntity>()
            .FirstOrDefault(ent => TPrimaryKeyEquals(id, ent.Entity));
        return entry?.Entity;
    }

    // TEntity is constrained to IEntity<TPrimaryKey>, so the key is read directly instead of via reflection.
    private bool TPrimaryKeyEquals(TPrimaryKey id, TEntity entity)
    {
        return EqualityComparer<TPrimaryKey>.Default.Equals(id, entity.Id);
    }

    /// <summary>
    /// Uses reflection to collect the names and values of an object's public instance properties
    /// whose type is a value type or whose type name starts with "String".
    /// Usage: <c>var dict = GetProperties(model);</c>
    /// </summary>
    /// <typeparam name="T">Type of the object.</typeparam>
    /// <param name="t">The object to inspect.</param>
    /// <returns>
    /// A property-name to value map, or <c>null</c> when <paramref name="t"/> is <c>null</c>
    /// or its type has no public instance properties.
    /// </returns>
    public static Dictionary<object, object> GetProperties<T>(T t)
    {
        if (t == null)
        {
            return null;
        }

        PropertyInfo[] properties = t.GetType().GetProperties(BindingFlags.Instance | BindingFlags.Public);
        if (properties.Length <= 0)
        {
            return null;
        }

        var ret = new Dictionary<object, object>();
        foreach (PropertyInfo item in properties)
        {
            // Indexers need index arguments; GetValue(t, null) would throw for them.
            if (item.GetIndexParameters().Length > 0)
            {
                continue;
            }

            // Only read getters of properties that are kept, so skipped (e.g. navigation)
            // properties are never evaluated.
            if (item.PropertyType.IsValueType
                || item.PropertyType.Name.StartsWith("String", StringComparison.Ordinal))
            {
                ret.Add(item.Name, item.GetValue(t, null));
            }
        }

        return ret;
    }

    // True when the entity's key may not have been assigned yet, so saving is needed to obtain it.
    private static bool MayHaveTemporaryKey(TEntity entity)
    {
        if (typeof(TPrimaryKey) == typeof(byte))
        {
            return true;
        }

        if (typeof(TPrimaryKey) == typeof(string))
        {
            // 'as' instead of entity.Id.ToString(): the latter threw NullReferenceException for a null key.
            return string.IsNullOrEmpty((object)entity.Id as string);
        }

        if (typeof(TPrimaryKey) == typeof(int))
        {
            return Convert.ToInt32(entity.Id) <= 0;
        }

        if (typeof(TPrimaryKey) == typeof(long))
        {
            return Convert.ToInt64(entity.Id) <= 0;
        }

        return false;
    }

    #endregion Helper

    #region DbContextTransaction

    /// <inheritdoc/>
    public IDisposable Begin()
    {
        return dbContext.Run(mediator);
    }

    /// <inheritdoc/>
    /// <remarks>Dispatches pending domain events before saving, on the same context instance.</remarks>
    public async Task<int> SaveAsync(CancellationToken cancellationToken = default)
    {
        // Resolve once so events are dispatched and changes saved on the same context.
        var context = dbContext;
        await mediator.DispatchDomainEventsAsync(context);
        return await context.SaveChangesAsync(cancellationToken);
    }

    /// <inheritdoc/>
    /// <remarks>
    /// Disposes the context afterwards, even if the rollback throws.
    /// NOTE(review): EF Core's RollbackTransaction is expected to throw when no transaction is active — confirm callers expect that.
    /// </remarks>
    public Task RollbackTransactionAsync()
    {
        var context = dbContext;
        try
        {
            context.Database.RollbackTransaction();
        }
        finally
        {
            context?.Dispose();
        }

        return Task.CompletedTask;
    }

    #endregion DbContextTransaction
}
4:代码中加入了 MediatR.INotification 事件和事务支持:
执行事务
/// <summary>
/// Unit of work over a <typeparamref name="TDbContext"/>: a database transaction is begun when the
/// instance is constructed, and <see cref="Dispose()"/> dispatches domain events, saves changes and commits.
/// </summary>
/// <remarks>
/// NOTE(review): Dispose commits even when the surrounding <c>using</c> block exited with an exception.
/// DisposeAsync is not overridden here, so an <c>await using</c> presumably bypasses the save/commit — confirm both are intended.
/// </remarks>
/// <typeparam name="TDbContext">Type of the wrapped context.</typeparam>
public class UnitOfWork<TDbContext> : DbContext, IDisposable
    where TDbContext : DbContext
{
    private IDbContextTransaction currentTransaction;
    private readonly IMediator mediator;
    private readonly TDbContext dbContext;

    // Makes repeated Dispose calls no-ops instead of re-running save/commit on a disposed context.
    private bool disposed;

    /// <summary>
    /// Wraps <paramref name="_dbContext"/> and immediately begins a transaction on it.
    /// </summary>
    /// <param name="_dbContext">Context whose changes this unit of work saves and commits.</param>
    /// <param name="_mediator">Mediator used to dispatch domain events before saving.</param>
    public UnitOfWork(TDbContext _dbContext, IMediator _mediator)
    {
        dbContext = _dbContext;
        // currentTransaction is still null here, so this begins a new transaction.
        currentTransaction = DbContextTransactionStart;
        mediator = _mediator;
    }

    /// <summary>
    /// Returns the current transaction, or begins a new one on the wrapped context when there is none.
    /// A transaction begun by this getter is not stored as the current transaction.
    /// </summary>
    public IDbContextTransaction DbContextTransactionStart => currentTransaction ?? dbContext.Database.BeginTransaction();

    /// <summary>
    /// Commits the current transaction (synchronously, despite the name).
    /// If the commit fails, the transaction is rolled back and the commit exception is rethrown
    /// (unless the rollback itself throws).
    /// </summary>
    public void CommitTransactionAsync()
    {
        try
        {
            currentTransaction?.Commit();
        }
        catch
        {
            RollbackTransactionAsync();
            throw;
        }
    }

    /// <summary>
    /// Rolls back the current transaction, if any (synchronously, despite the name).
    /// </summary>
    public void RollbackTransactionAsync()
    {
        // Exceptions propagate unchanged; the former catch + "throw ex;" only reset their stack trace.
        currentTransaction?.Rollback();
    }

    /// <summary>
    /// When <paramref name="disposing"/> is true, dispatches domain events, saves changes and commits.
    /// The transaction and the wrapped context are always disposed afterwards, even if any of those steps fail.
    /// </summary>
    /// <param name="disposing"><c>true</c> when called from <see cref="Dispose()"/>.</param>
    protected virtual void Dispose(bool disposing)
    {
        if (disposed)
        {
            return;
        }

        disposed = true;

        try
        {
            if (disposing)
            {
                // Blocks until the async work completes; Task.Run keeps it off any captured synchronization context.
                Task.Run(async () =>
                {
                    await mediator.DispatchDomainEventsAsync(dbContext);
                    await dbContext.SaveChangesAsync();
                    CommitTransactionAsync();
                }).Wait();
            }
        }
        finally
        {
            // Released even when dispatch/save/commit threw; previously those failures leaked both.
            // NOTE(review): an uncommitted transaction is presumably rolled back by its Dispose (provider-dependent) — confirm.
            currentTransaction?.Dispose();
            dbContext.Dispose();
            currentTransaction = null;
        }
    }

    /// <summary>
    /// Saves, commits and releases the unit of work. See <see cref="Dispose(bool)"/>.
    /// </summary>
    public override void Dispose()
    {
        Dispose(disposing: true);
    }
}
依赖注入方式
private readonly IRepository<User, string> userRepository;
private readonly IRepository<Organization, int> organizationRepository;

/// <summary>
/// Receives the user and organization repositories through constructor injection.
/// </summary>
/// <param name="_userRepository">Repository for <c>User</c> entities keyed by <see cref="string"/>.</param>
/// <param name="_organizationRepository">Repository for <c>Organization</c> entities keyed by <see cref="int"/>.</param>
public UserCommandHandler(
    IRepository<User, string> _userRepository,
    IRepository<Organization, int> _organizationRepository)
    => (userRepository, organizationRepository) = (_userRepository, _organizationRepository);
事务调用方式
// Transaction scope: the handle returned by Begin() is disposed when the block exits.
// NOTE(review): Begin() returns dbContext.Run(mediator); presumably that is a UnitOfWork<TDbContext>,
// whose Dispose dispatches events, saves and commits — confirm against Run.
using (IDisposable transaction = userRepository.Begin())
{
    // Domain event.
    user.SetOrg();

    // Save the user.
    await userRepository.InsertAsync(user);
}