  Building a Multi-Layer Website Architecture with ASP.NET Core [4 - Unit of Work and Repository Design]

    2020/01/28, ASP.NET Core 3.1, VS2019, Microsoft.EntityFrameworkCore.Relational 3.1.1

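    Summary: building a multi-layer back-end architecture with ASP.NET Core 3.1 WebApi [4 - Unit of Work and Repository Design]
    The generic Repository and UnitOfWork patterns wrap the basic create, read, update and delete (CRUD) methods of the data access layer.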


    Project code for this branch

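    About the unit of work pattern in this chapter:
    The generic repository wraps the common CRUD methods. The unit of work manages all repositories centrally so that every repository works against the same database context (DbContext).
    Repositories are always obtained from the unit of work. After the database has been changed through a repository, the unit of work commits the changes.
    The code follows the design of Arch/UnitOfWork, and most of it is adapted from that project. Explanatory comments were added and the support for distributed multi-database setups was removed.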
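    The division of labor described above can be illustrated with a minimal in-memory sketch. It does not use EF Core. InMemoryUnitOfWork, ISimpleRepository and SaveChanges are hypothetical stand-ins, not the MS.UnitOfWork API. The shared pending-change list plays the role of the DbContext. Repositories are obtained only from the unit of work and only record changes. Nothing becomes visible until the unit of work commits.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical minimal repository contract used only for this sketch.
public interface ISimpleRepository<T>
{
    void Insert(T entity);
    IReadOnlyList<T> GetAll();
}

// Hypothetical in-memory unit of work: it hands out repositories and owns the commit.
public sealed class InMemoryUnitOfWork
{
    // Committed "rows", keyed by entity type (stands in for the database).
    private readonly Dictionary<Type, List<object>> _store = new Dictionary<Type, List<object>>();

    // Pending changes shared by all repositories from this unit of work
    // (stands in for the DbContext change tracker).
    private readonly List<(Type Type, object Entity)> _pending = new List<(Type Type, object Entity)>();

    // One cached repository per entity type.
    private readonly Dictionary<Type, object> _repositories = new Dictionary<Type, object>();

    public ISimpleRepository<T> GetRepository<T>()
    {
        if (!_repositories.TryGetValue(typeof(T), out object repository))
        {
            repository = new Repository<T>(this);
            _repositories[typeof(T)] = repository;
        }
        return (ISimpleRepository<T>)repository;
    }

    // Commits all pending changes at once and returns how many were written.
    public int SaveChanges()
    {
        foreach (var (type, entity) in _pending)
        {
            if (!_store.TryGetValue(type, out List<object> rows))
            {
                rows = new List<object>();
                _store[type] = rows;
            }
            rows.Add(entity);
        }

        int written = _pending.Count;
        _pending.Clear();
        return written;
    }

    private sealed class Repository<T> : ISimpleRepository<T>
    {
        private readonly InMemoryUnitOfWork _unitOfWork;

        public Repository(InMemoryUnitOfWork unitOfWork) => _unitOfWork = unitOfWork;

        // Only records the change; nothing is committed until SaveChanges is called.
        public void Insert(T entity) => _unitOfWork._pending.Add((typeof(T), entity));

        // Returns committed rows only.
        public IReadOnlyList<T> GetAll() =>
            _unitOfWork._store.TryGetValue(typeof(T), out List<object> rows)
                ? rows.Cast<T>().ToList()
                : new List<T>();
    }
}
```

    In an EF Core-based unit of work, the shared state would presumably be the DbContext and the commit would be DbContext.SaveChanges. Every repository handed out by one unit of work would then take part in the same commit.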

    Adding the package reference

    Add a reference to the Microsoft.EntityFrameworkCore.Relational package in the MS.UnitOfWork project:

    <ItemGroup>
      <PackageReference Include="Microsoft.EntityFrameworkCore.Relational" Version="3.1.1" />
    </ItemGroup>
    

    Pagination helpers

    In the MS.UnitOfWork project, add a Collections folder and create the classes IPagedList.cs, PagedList.cs, IEnumerablePagedListExtensions.cs and IQueryablePageListExtensions.cs in it.

    IPagedList.cs

    using System.Collections.Generic;
    
    namespace MS.UnitOfWork.Collections
    {
        /// <summary>
        /// Paging interface for data of any type
        /// </summary>
        /// <typeparam name="T">Type of the data being paged</typeparam>
        public interface IPagedList<T>
        {
            /// <summary>
            /// Index of the first page (e.g. 0 or 1)
            /// </summary>
            int IndexFrom { get; }
            /// <summary>
            /// Current page index
            /// </summary>
            int PageIndex { get; }
            /// <summary>
            /// Page size
            /// </summary>
            int PageSize { get; }
            /// <summary>
            /// Total number of items
            /// </summary>
            int TotalCount { get; }
            /// <summary>
            /// Total number of pages
            /// </summary>
            int TotalPages { get; }
            /// <summary>
            /// Items on the current page
            /// </summary>
            IList<T> Items { get; }
            /// <summary>
            /// Whether there is a previous page
            /// </summary>
            bool HasPreviousPage { get; }
            /// <summary>
            /// Whether there is a next page
            /// </summary>
            bool HasNextPage { get; }
        }
    }
    
    

    PagedList.cs

    using System;
    using System.Collections.Generic;
    using System.Linq;
    
    namespace MS.UnitOfWork.Collections
    {
        /// <summary>
        /// Pages a data source; the default implementation of <see cref="IPagedList{T}"/>
        /// </summary>
        /// <typeparam name="T">Type of the data being paged</typeparam>
        public class PagedList<T> : IPagedList<T>
        {
            /// <summary>
            /// Current page index
            /// </summary>
            public int PageIndex { get; set; }
            /// <summary>
            /// Page size
            /// </summary>
            public int PageSize { get; set; }
            /// <summary>
            /// Total number of items
            /// </summary>
            public int TotalCount { get; set; }
            /// <summary>
            /// Total number of pages
            /// </summary>
            public int TotalPages { get; set; }
            /// <summary>
            /// Index of the first page (e.g. 0 or 1)
            /// </summary>
            public int IndexFrom { get; set; }
            /// <summary>
            /// Items on the current page
            /// </summary>
            public IList<T> Items { get; set; }
            /// <summary>
            /// Whether there is a previous page
            /// </summary>
            public bool HasPreviousPage => PageIndex - IndexFrom > 0;
            /// <summary>
            /// Whether there is a next page
            /// </summary>
            public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;
    
            /// <summary>
            /// Initializes a new instance by paging the given source
            /// </summary>
            /// <param name="source">The data source to page.</param>
            /// <param name="pageIndex">The index of the page to return.</param>
            /// <param name="pageSize">The number of items per page.</param>
            /// <param name="indexFrom">The index of the first page (e.g. 0 or 1).</param>
            internal PagedList(IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom)
            {
                if (indexFrom > pageIndex)
                {
                    throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");
                }
    
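                // IQueryable source (e.g. an EF Core query): Queryable.Count/Skip/Take are translated
                // by the query provider, so counting and paging run in the database, not in memory
                if (source is IQueryable<T> queryable)
                {
                    PageIndex = pageIndex;
                    PageSize = pageSize;
                    IndexFrom = indexFrom;
                    TotalCount = queryable.Count();
                    TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
    
                    Items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
                }
                else
                {
                    // Plain IEnumerable source: Enumerable.Count/Skip/Take run in memory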
                    PageIndex = pageIndex;
                    PageSize = pageSize;
                    IndexFrom = indexFrom;
                    TotalCount = source.Count();
                    TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
    
                    Items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToList();
                }
            }
    
            /// <summary>
            /// Initializes a new instance of the <see cref="PagedList{T}" /> class.
            /// </summary>
            internal PagedList() => Items = new T[0];
        }
    
    
        /// <summary>
        /// Pages a data source and converts the page items to another type
        /// </summary>
        /// <typeparam name="TSource">Type of the source data</typeparam>
        /// <typeparam name="TResult">Type of the output data</typeparam>
        internal class PagedList<TSource, TResult> : IPagedList<TResult>
        {
            /// <summary>
            /// Current page index
            /// </summary>
            public int PageIndex { get; set; }
            /// <summary>
            /// Page size
            /// </summary>
            public int PageSize { get; set; }
            /// <summary>
            /// Total number of items
            /// </summary>
            public int TotalCount { get; set; }
            /// <summary>
            /// Total number of pages
            /// </summary>
            public int TotalPages { get; set; }
            /// <summary>
            /// Index of the first page (e.g. 0 or 1)
            /// </summary>
            public int IndexFrom { get; set; }
            /// <summary>
            /// Items on the current page
            /// </summary>
            public IList<TResult> Items { get; set; }
            /// <summary>
            /// Whether there is a previous page
            /// </summary>
            public bool HasPreviousPage => PageIndex - IndexFrom > 0;
            /// <summary>
            /// Whether there is a next page
            /// </summary>
            public bool HasNextPage => PageIndex - IndexFrom + 1 < TotalPages;
    
    
            /// <summary>
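            /// Initializes a new instance by paging the source and converting the page items
            /// </summary>
            /// <param name="source">The data source to page.</param>
            /// <param name="converter">Converts the items of the current page to <typeparamref name="TResult"/>.</param>
            /// <param name="pageIndex">The index of the page to return.</param>
            /// <param name="pageSize">The number of items per page.</param>
            /// <param name="indexFrom">The index of the first page (e.g. 0 or 1).</param>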
            public PagedList(IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom)
            {
                if (indexFrom > pageIndex)
                {
                    throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");
                }
    
                // IQueryable source: count and page in the database, then convert only the page items
                if (source is IQueryable<TSource> queryable)
                {
                    PageIndex = pageIndex;
                    PageSize = pageSize;
                    IndexFrom = indexFrom;
                    TotalCount = queryable.Count();
                    TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
    
                    var items = queryable.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();
    
                    Items = new List<TResult>(converter(items));
                }
                else
                {
                    // Plain IEnumerable source: page in memory, then convert only the page items
                    PageIndex = pageIndex;
                    PageSize = pageSize;
                    IndexFrom = indexFrom;
                    TotalCount = source.Count();
                    TotalPages = (int)Math.Ceiling(TotalCount / (double)PageSize);
    
                    var items = source.Skip((PageIndex - IndexFrom) * PageSize).Take(PageSize).ToArray();
    
                    Items = new List<TResult>(converter(items));
                }
            }
    
            /// <summary>
            /// Initializes a new instance of the <see cref="PagedList{TSource, TResult}" /> class.
            /// </summary>
            /// <param name="source">An existing paged list whose items are converted.</param>
            /// <param name="converter">Converts the items of the source page to <typeparamref name="TResult"/>.</param>
            public PagedList(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter)
            {
                PageIndex = source.PageIndex;
                PageSize = source.PageSize;
                IndexFrom = source.IndexFrom;
                TotalCount = source.TotalCount;
                TotalPages = source.TotalPages;
    
                Items = new List<TResult>(converter(source.Items));
            }
        }
    
        /// <summary>
        /// Provides some help methods for <see cref="IPagedList{T}"/> interface.
        /// </summary>
        public static class PagedList
        {
            /// <summary>
            /// Creates an empty instance of <see cref="IPagedList{T}"/>.
            /// </summary>
            /// <typeparam name="T">The type for paging </typeparam>
            /// <returns>An empty instance of <see cref="IPagedList{T}"/>.</returns>
            public static IPagedList<T> Empty<T>() => new PagedList<T>();
            /// <summary>
            /// Creates a new instance of <see cref="IPagedList{TResult}"/> from source of <see cref="IPagedList{TSource}"/> instance.
            /// </summary>
            /// <typeparam name="TResult">The type of the result.</typeparam>
            /// <typeparam name="TSource">The type of the source.</typeparam>
            /// <param name="source">The source.</param>
            /// <param name="converter">The converter.</param>
            /// <returns>An instance of <see cref="IPagedList{TResult}"/>.</returns>
            public static IPagedList<TResult> From<TResult, TSource>(IPagedList<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter) => new PagedList<TSource, TResult>(source, converter);
        }
    }
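    The paging arithmetic in PagedList<T> is easy to misread when indexFrom is not 0. The sketch below copies the same formulas into a hypothetical PagingMath helper, which is not part of the project, so they can be checked in isolation.

```csharp
using System;

// Hypothetical helper that mirrors the arithmetic used by PagedList<T>;
// it is not part of MS.UnitOfWork.
public static class PagingMath
{
    // Number of items to skip before the requested page.
    public static int Skip(int pageIndex, int pageSize, int indexFrom) =>
        (pageIndex - indexFrom) * pageSize;

    // Ceiling division done in floating point, as in PagedList<T>.
    public static int TotalPages(int totalCount, int pageSize) =>
        (int)Math.Ceiling(totalCount / (double)pageSize);

    // Any page after the first one has a predecessor.
    public static bool HasPreviousPage(int pageIndex, int indexFrom) =>
        pageIndex - indexFrom > 0;

    // Converts pageIndex to a 0-based position, adds 1 and compares with the page count.
    public static bool HasNextPage(int pageIndex, int indexFrom, int totalPages) =>
        pageIndex - indexFrom + 1 < totalPages;
}
```

    Take 45 items with pageSize 20, which gives TotalPages = 3. With indexFrom = 1, page 3 skips (3 - 1) * 20 = 40 items and holds the last 5. For that page HasPreviousPage is true and HasNextPage is false, because 3 - 1 + 1 < 3 fails. Requesting pageIndex = 2 with indexFrom = 0 addresses the same page. Neither PagedList<T> nor this helper validates pageSize. With pageSize = 0, the floating-point division yields infinity (or NaN for an empty source), and the cast to int then produces an unspecified TotalPages value.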
    
    

    IEnumerablePagedListExtensions.cs

    using System;
    using System.Collections.Generic;
    
    namespace MS.UnitOfWork.Collections
    {
        /// <summary>
        /// Adds paging extension methods to <see cref="IEnumerable{T}"/>
        /// </summary>
        public static class IEnumerablePagedListExtensions
        {
            /// <summary>
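            /// Returns the requested page of the data
            /// </summary>
            /// <typeparam name="T">Type of the data</typeparam>
            /// <param name="source">The data source</param>
            /// <param name="pageIndex">The current page</param>
            /// <param name="pageSize">The page size</param>
            /// <param name="indexFrom">The index of the first page</param>
            /// <returns>The requested page as an <see cref="IPagedList{T}"/></returns>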
            public static IPagedList<T> ToPagedList<T>(this IEnumerable<T> source, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<T>(source, pageIndex, pageSize, indexFrom);
    
            /// <summary>
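            /// Returns the requested page of the data, converted to the specified type
            /// </summary>
            /// <typeparam name="TSource">Type of the source data</typeparam>
            /// <typeparam name="TResult">Type of the output data</typeparam>
            /// <param name="source">The data source</param>
            /// <param name="converter">Converts the items of the requested page to <typeparamref name="TResult"/></param>
            /// <param name="pageIndex">The current page</param>
            /// <param name="pageSize">The page size</param>
            /// <param name="indexFrom">The index of the first page</param>
            /// <returns>The requested page as an <see cref="IPagedList{TResult}"/></returns>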
            public static IPagedList<TResult> ToPagedList<TSource, TResult>(this IEnumerable<TSource> source, Func<IEnumerable<TSource>, IEnumerable<TResult>> converter, int pageIndex, int pageSize, int indexFrom = 1) => new PagedList<TSource, TResult>(source, converter, pageIndex, pageSize, indexFrom);
        }
    }
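    The converter overload pages first and converts afterwards. PagedList<TSource, TResult> calls converter only on the items returned by Skip/Take, so an expensive mapping (for example entity to DTO) runs for one page rather than for the whole source. A minimal sketch of that ordering uses a hypothetical PageThenConvert helper:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch (not part of MS.UnitOfWork) of the "page first, then convert"
// order used by PagedList<TSource, TResult>: the converter only sees the current page.
public static class PageConvertDemo
{
    public static List<TResult> PageThenConvert<TSource, TResult>(
        IEnumerable<TSource> source,
        Func<IEnumerable<TSource>, IEnumerable<TResult>> converter,
        int pageIndex,
        int pageSize,
        int indexFrom = 1)
    {
        // Slice the page out of the source first...
        TSource[] page = source.Skip((pageIndex - indexFrom) * pageSize).Take(pageSize).ToArray();

        // ...then hand only that slice to the converter.
        return new List<TResult>(converter(page));
    }
}
```

    With Enumerable.Range(1, 100), pageIndex 2 and pageSize 10, the converter receives 10 items (11 to 20), not 100.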
    
    

    IQueryablePageListExtensions.cs

    using Microsoft.EntityFrameworkCore;
    using System;
    using System.Linq;
    using System.Threading;
    using System.Threading.Tasks;
    
    namespace MS.UnitOfWork.Collections
    {
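        /// <summary>
        /// Adds asynchronous paging extension methods to <see cref="IQueryable{T}"/>
        /// </summary>
        public static class IQueryablePageListExtensions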
        {
            /// <summary>
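            /// Returns the requested page of the data (asynchronous)
            /// </summary>
            /// <typeparam name="T">Type of the data</typeparam>
            /// <param name="source">The data source</param>
            /// <param name="pageIndex">The current page</param>
            /// <param name="pageSize">The page size</param>
            /// <param name="indexFrom">The index of the first page</param>
            /// <param name="cancellationToken">Token used to cancel the asynchronous operation</param>
            /// <returns>The requested page as an <see cref="IPagedList{T}"/></returns>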
            public static async Task<IPagedList<T>> ToPagedListAsync<T>(this IQueryable<T> source, int pageIndex, int pageSize, int indexFrom = 1, CancellationToken cancellationToken = default(CancellationToken))
            {
                if (indexFrom > pageIndex)
                {
                    throw new ArgumentException($"indexFrom: {indexFrom} > pageIndex: {pageIndex}, indexFrom must be less than or equal to pageIndex");
                }
    
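                // The count and the page query are awaited one after the other:
                // a DbContext does not support concurrent operations on the same instance
                var count = await source.CountAsync(cancellationToken).ConfigureAwait(false);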
                var items = await source.Skip((pageIndex - indexFrom) * pageSize)
                                        .Take(pageSize).ToListAsync(cancellationToken).ConfigureAwait(false);
    
                var pagedList = new PagedList<T>()
                {
                    PageIndex = pageIndex,
                    PageSize = pageSize,
                    IndexFrom = indexFrom,
                    TotalCount = count,
                    Items = items,
                    TotalPages = (int)Math.Ceiling(count / (double)pageSize)
                };
    
                return pagedList;
            }
        }
    }
    
    

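    These paging extension methods cover both IQueryable and IEnumerable data. They are mainly used to page results when reading from the database. For an IQueryable source, the count and the Skip/Take are executed by the query provider (as SQL, in the case of EF Core). For an in-memory IEnumerable, they run in memory. Apply an ordering before paging: without OrderBy, Skip/Take do not guarantee stable page contents.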
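    PagedList<T> checks source is IQueryable<T> explicitly because extension methods bind to the static type. Calling Skip on a variable typed IEnumerable<T> therefore always runs Enumerable.Skip in memory, even when the object is really a query. The hypothetical DescribePaging helper below, which is illustrative only, shows the two paths. On an IQueryable<T>, the Skip/Take calls become part of the expression tree that a provider such as EF Core would translate. On a plain list, they are executed directly.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical demo (not part of MS.UnitOfWork) of why PagedList<T> checks
// "source is IQueryable<T>": the static type decides which Skip/Take runs.
public static class QueryableDispatchDemo
{
    public static string DescribePaging<T>(IEnumerable<T> source, int skip, int take)
    {
        if (source is IQueryable<T> queryable)
        {
            // Queryable.Skip/Take: the calls are recorded in the expression tree,
            // which a provider such as EF Core would translate into SQL.
            return queryable.Skip(skip).Take(take).Expression.ToString();
        }

        // Enumerable.Skip/Take: items are pulled through in memory.
        return "in-memory: " + string.Join(",", source.Skip(skip).Take(take));
    }
}
```

    For a list wrapped with AsQueryable(), the returned expression text contains the Skip and Take calls. The exact formatting of Expression.ToString() should not be relied on.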

    Generic repository

    In the MS.UnitOfWork project, add a Repository folder and create the classes IRepository.cs and Repository.cs in it.

    IRepository.cs

    using MS.UnitOfWork.Collections;
    using Microsoft.EntityFrameworkCore.ChangeTracking;
    using Microsoft.EntityFrameworkCore.Query;
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Threading;
    using System.Threading.Tasks;
    
    namespace MS.UnitOfWork
    {
        /// <summary>
        /// Generic repository interface
        /// </summary>
        /// <typeparam name="TEntity">Entity type</typeparam>
        public interface IRepository<TEntity> where TEntity : class
        {
            #region GetAll
            /// <summary>
            /// Gets all entities
            /// Mind performance: materializing this query without a filter loads the entire table!
            /// </summary>
            /// <returns>The <see cref="IQueryable{TEntity}"/>.</returns>
            IQueryable<TEntity> GetAll();
    
            /// <summary>
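            /// Gets all entities
            /// </summary>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <returns>A query over the matching entities</returns>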
            IQueryable<TEntity> GetAll(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false);
    
            /// <summary>
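            /// Gets all entities projected with the given selector (the selector is required)
            /// </summary>
            /// <typeparam name="TResult">Type of the output data</typeparam>
            /// <param name="selector">Projection selector</param>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <returns>A query over the projected results</returns>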
            IQueryable<TResult> GetAll<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false
                ) where TResult : class;
    
            /// <summary>
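            /// Gets all entities (asynchronous)
            /// </summary>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <returns>The list of matching entities</returns>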
            Task<IList<TEntity>> GetAllAsync(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false);
            #endregion
    
            #region GetPagedList
            /// <summary>
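            /// Gets a page of data
            /// Change tracking is disabled by default (changes to the returned entities are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
            /// <param name="pageSize">Page size. Defaults to 20 items</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <returns>The requested page</returns>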
            IPagedList<TEntity> GetPagedList(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false);
    
            /// <summary>
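            /// Gets a page of data (asynchronous)
            /// Change tracking is disabled by default (changes to the returned entities are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
            /// <param name="pageSize">Page size. Defaults to 20 items</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <param name="cancellationToken">Token used to cancel the asynchronous operation</param>
            /// <returns>The requested page</returns>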
            Task<IPagedList<TEntity>> GetPagedListAsync(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default);
    
            /// <summary>
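            /// Gets a page of projected data
            /// Change tracking is disabled by default (changes to the returned entities are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <typeparam name="TResult">Type of the output data</typeparam>
            /// <param name="selector">Projection selector</param>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
            /// <param name="pageSize">Page size. Defaults to 20 items</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <returns>The requested page of projected results</returns>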
            IPagedList<TResult> GetPagedList<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false
                ) where TResult : class;
    
    
            /// <summary>
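            /// Gets a page of projected data (asynchronous)
            /// Change tracking is disabled by default (changes to the returned entities are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <typeparam name="TResult">Type of the output data</typeparam>
            /// <param name="selector">Projection selector</param>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="pageIndex">Current page. Defaults to the first page (1)</param>
            /// <param name="pageSize">Page size. Defaults to 20 items</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <param name="cancellationToken">Token used to cancel the asynchronous operation</param>
            /// <returns>The requested page of projected results</returns>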
            Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default) where TResult : class;
    
            #endregion
    
            #region GetFirstOrDefault
            /// <summary>
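            /// Gets the first element of the sequence that satisfies the condition
            /// Returns the default value if no element satisfies the condition
            /// Change tracking is disabled by default (changes to the returned entity are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <returns>The first matching entity, or null</returns>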
            TEntity GetFirstOrDefault(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false);
    
            /// <summary>
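            /// Gets the first element of the sequence that satisfies the condition (asynchronous)
            /// Returns the default value if no element satisfies the condition
            /// Change tracking is disabled by default (changes to the returned entity are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <param name="cancellationToken">Token used to cancel the asynchronous operation</param>
            /// <returns>The first matching entity, or null</returns>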
            Task<TEntity> GetFirstOrDefaultAsync(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default);
    
            /// <summary>
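            /// Gets the first element of the sequence that satisfies the condition, projected with the selector
            /// Returns the default value if no element satisfies the condition
            /// Change tracking is disabled by default (changes to the returned entity are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <typeparam name="TResult">Type of the output data</typeparam>
            /// <param name="selector">Projection selector</param>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <returns>The first matching projected result, or the default value</returns>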
            TResult GetFirstOrDefault<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false);
    
            /// <summary>
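            /// Gets the first element of the sequence that satisfies the condition, projected with the selector (asynchronous)
            /// Returns the default value if no element satisfies the condition
            /// Change tracking is disabled by default (changes to the returned entity are not tracked)
            /// Global query filters are applied by default
            /// </summary>
            /// <typeparam name="TResult">Type of the output data</typeparam>
            /// <param name="selector">Projection selector</param>
            /// <param name="predicate">Filter condition</param>
            /// <param name="orderBy">Ordering</param>
            /// <param name="include">Navigation properties to include</param>
            /// <param name="disableTracking">Set to true to disable change tracking. Defaults to true</param>
            /// <param name="ignoreQueryFilters">Set to true to ignore global query filters. Defaults to false</param>
            /// <param name="cancellationToken">Token used to cancel the asynchronous operation</param>
            /// <returns>The first matching projected result, or the default value</returns>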
            Task<TResult> GetFirstOrDefaultAsync<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default);
    
    
            #endregion
    
            #region Find
            /// <summary>
            /// Finds an entity with the given primary key values. If found, it is attached to the context and returned. If no entity is found, then null is returned.
            /// </summary>
            /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
            /// <returns>The found entity or null.</returns>
            TEntity Find(params object[] keyValues);
    
            /// <summary>
            /// Finds an entity with the given primary key values. If found, it is attached to the context and returned. If no entity is found, then null is returned.
            /// </summary>
            /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
            /// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
            ValueTask<TEntity> FindAsync(params object[] keyValues);
    
            /// <summary>
            /// Finds an entity with the given primary key values. If found, it is attached to the context and returned. If no entity is found, then null is returned.
            /// </summary>
            /// <param name="keyValues">The values of the primary key for the entity to be found.</param>
            /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
            /// <returns>A <see cref="ValueTask{TResult}"/> that represents the asynchronous find operation. The task result contains the found entity or null.</returns>
            ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken);
            #endregion
    
            #region sql、count、exist
            /// <summary>
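            /// Queries data with a raw SQL statement
            /// </summary>
            /// <param name="sql">The raw SQL query</param>
            /// <param name="parameters">Values for the query parameters</param>
            /// <returns>A query over the entities returned by the SQL statement</returns>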
            IQueryable<TEntity> FromSql(string sql, params object[] parameters);
    
            /// <summary>
            /// Counts the entities
            /// </summary>
            /// <param name="predicate">Optional filter condition</param>
            /// <returns>The number of matching entities</returns>
            int Count(Expression<Func<TEntity, bool>> predicate = null);
    
            /// <summary>
            /// Counts the entities (asynchronous)
            /// </summary>
            /// <param name="predicate">Optional filter condition</param>
            /// <returns>The number of matching entities</returns>
            Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null);
    
            /// <summary>
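            /// Determines whether any entity satisfies the specified condition
            /// </summary>
            /// <param name="predicate">Optional filter condition</param>
            /// <returns>True if a matching entity exists</returns>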
            bool Exists(Expression<Func<TEntity, bool>> predicate = null);
            #endregion
    
            #region Insert
            /// <summary>
            /// Inserts a new entity synchronously.
            /// </summary>
            /// <param name="entity">The entity to insert.</param>
            /// <returns>The inserted entity.</returns>
            TEntity Insert(TEntity entity);
    
            /// <summary>
            /// Inserts a range of entities synchronously.
            /// </summary>
            /// <param name="entities">The entities to insert.</param>
            void Insert(params TEntity[] entities);
    
    
            /// <summary>
            /// Inserts a range of entities synchronously.
            /// </summary>
            /// <param name="entities">The entities to insert.</param>
            void Insert(IEnumerable<TEntity> entities);
    
    
            /// <summary>
            /// Inserts a new entity asynchronously.
            /// </summary>
            /// <param name="entity">The entity to insert.</param>
            /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
            /// <returns>A <see cref="ValueTask{TResult}"/> whose result is the <see cref="EntityEntry{TEntity}"/> of the added entity.</returns>
            ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default);
    
            /// <summary>
            /// Inserts a range of entities asynchronously.
            /// </summary>
            /// <param name="entities">The entities to insert.</param>
            /// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
            Task InsertAsync(params TEntity[] entities);
    
            /// <summary>
            /// Inserts a range of entities asynchronously.
            /// </summary>
            /// <param name="entities">The entities to insert.</param>
            /// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
            /// <returns>A <see cref="Task"/> that represents the asynchronous insert operation.</returns>
            Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default);
            #endregion
    
            #region Update
            /// <summary>
            /// Updates the specified entity.
            /// </summary>
            /// <param name="entity">The entity.</param>
            void Update(TEntity entity);
    
            /// <summary>
            /// Updates the specified entities.
            /// </summary>
            /// <param name="entities">The entities.</param>
            void Update(params TEntity[] entities);
    
            /// <summary>
            /// Updates the specified entities.
            /// </summary>
            /// <param name="entities">The entities.</param>
            void Update(IEnumerable<TEntity> entities);
            #endregion
    
            #region Delete
            /// <summary>
            /// Deletes the entity by the specified primary key.
            /// </summary>
            /// <param name="id">The primary key value.</param>
            void Delete(object id);
    
            /// <summary>
            /// Deletes the specified entity.
            /// </summary>
            /// <param name="entity">The entity to delete.</param>
            void Delete(TEntity entity);
    
            /// <summary>
            /// Deletes the specified entities.
            /// </summary>
            /// <param name="entities">The entities.</param>
            void Delete(params TEntity[] entities);
    
            /// <summary>
            /// Deletes the specified entities.
            /// </summary>
            /// <param name="entities">The entities.</param>
            void Delete(IEnumerable<TEntity> entities);
            #endregion
        }
    }
    
    

    Repository.cs

    using MS.UnitOfWork.Collections;
    using Microsoft.EntityFrameworkCore;
    using Microsoft.EntityFrameworkCore.ChangeTracking;
    using Microsoft.EntityFrameworkCore.Query;
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Reflection;
    using System.Threading;
    using System.Threading.Tasks;
    
    namespace MS.UnitOfWork
    {
        /// <summary>
        /// Default implementation of the generic repository
        /// </summary>
        /// <typeparam name="TEntity"></typeparam>
        public class Repository<TEntity> : IRepository<TEntity> where TEntity : class
        {
            protected readonly DbContext _dbContext;
            protected readonly DbSet<TEntity> _dbSet;
    
            public Repository(DbContext dbContext)
            {
                _dbContext = dbContext ?? throw new ArgumentNullException(nameof(dbContext));
                _dbSet = _dbContext.Set<TEntity>();
            }
    
            #region GetAll
            public IQueryable<TEntity> GetAll() => _dbSet;
    
            public IQueryable<TEntity> GetAll(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false)
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return orderBy(query);
                }
                else
                {
                    return query;
                }
            }
    
            public IQueryable<TResult> GetAll<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false) where TResult : class
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return orderBy(query).Select(selector);
                }
                else
                {
                    return query.Select(selector);
                }
            }
    
            public async Task<IList<TEntity>> GetAllAsync(Expression<Func<TEntity, bool>> predicate = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null, bool disableTracking = true, bool ignoreQueryFilters = false)
            {
                IQueryable<TEntity> query = _dbSet;
    
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return await orderBy(query).ToListAsync();
                }
                else
                {
                    return await query.ToListAsync();
                }
            }
            #endregion
    
            #region GetPagedList
            public virtual IPagedList<TEntity> GetPagedList(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false)
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
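                if (orderBy != null)
                {
                    // Pass indexFrom = 1 explicitly, as GetPagedListAsync does, so that pageIndex = 1 is the first page
                    return orderBy(query).ToPagedList(pageIndex, pageSize, 1);
                }
                else
                {
                    return query.ToPagedList(pageIndex, pageSize, 1);
                }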
            }
    
            public virtual async Task<IPagedList<TEntity>> GetPagedListAsync(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default)
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return await orderBy(query).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
                }
                else
                {
                    return await query.ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
                }
            }
    
            public virtual IPagedList<TResult> GetPagedList<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false)
                where TResult : class
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    // Pass indexFrom = 1 explicitly, matching GetPagedListAsync<TResult>
                    return orderBy(query).Select(selector).ToPagedList(pageIndex, pageSize, 1);
                }
                else
                {
                    return query.Select(selector).ToPagedList(pageIndex, pageSize, 1);
                }
            }
    
            public virtual async Task<IPagedList<TResult>> GetPagedListAsync<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                int pageIndex = 1,
                int pageSize = 20,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default)
                where TResult : class
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return await orderBy(query).Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
                }
                else
                {
                    return await query.Select(selector).ToPagedListAsync(pageIndex, pageSize, 1, cancellationToken);
                }
            }
            #endregion
    
            #region GetFirstOrDefault 
    
            public virtual TEntity GetFirstOrDefault(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false)
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return orderBy(query).FirstOrDefault();
                }
                else
                {
                    return query.FirstOrDefault();
                }
            }
    
    
            public virtual async Task<TEntity> GetFirstOrDefaultAsync(
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default)
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return await orderBy(query).FirstOrDefaultAsync(cancellationToken);
                }
                else
                {
                    return await query.FirstOrDefaultAsync(cancellationToken);
                }
            }
    
            public virtual TResult GetFirstOrDefault<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false)
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return orderBy(query).Select(selector).FirstOrDefault();
                }
                else
                {
                    return query.Select(selector).FirstOrDefault();
                }
            }
    
            public virtual async Task<TResult> GetFirstOrDefaultAsync<TResult>(
                Expression<Func<TEntity, TResult>> selector,
                Expression<Func<TEntity, bool>> predicate = null,
                Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
                Func<IQueryable<TEntity>, IIncludableQueryable<TEntity, object>> include = null,
                bool disableTracking = true,
                bool ignoreQueryFilters = false,
                CancellationToken cancellationToken = default)
            {
                IQueryable<TEntity> query = _dbSet;
                if (disableTracking)
                {
                    query = query.AsNoTracking();
                }
    
                if (include != null)
                {
                    query = include(query);
                }
    
                if (predicate != null)
                {
                    query = query.Where(predicate);
                }
    
                if (ignoreQueryFilters)
                {
                    query = query.IgnoreQueryFilters();
                }
    
                if (orderBy != null)
                {
                    return await orderBy(query).Select(selector).FirstOrDefaultAsync(cancellationToken);
                }
                else
                {
                    return await query.Select(selector).FirstOrDefaultAsync(cancellationToken);
                }
            }
            #endregion
    
            #region Find
    
            public virtual TEntity Find(params object[] keyValues) => _dbSet.Find(keyValues);
    
            public virtual ValueTask<TEntity> FindAsync(params object[] keyValues) => _dbSet.FindAsync(keyValues);
    
            public virtual ValueTask<TEntity> FindAsync(object[] keyValues, CancellationToken cancellationToken) => _dbSet.FindAsync(keyValues, cancellationToken);
            #endregion
    
            #region sql、count、exist
            public virtual IQueryable<TEntity> FromSql(string sql, params object[] parameters) => _dbSet.FromSqlRaw(sql, parameters);
    
            public virtual int Count(Expression<Func<TEntity, bool>> predicate = null)
            {
                if (predicate == null)
                {
                    return _dbSet.Count();
                }
                else
                {
                    return _dbSet.Count(predicate);
                }
            }
    
            public virtual async Task<int> CountAsync(Expression<Func<TEntity, bool>> predicate = null)
            {
                if (predicate == null)
                {
                    return await _dbSet.CountAsync();
                }
                else
                {
                    return await _dbSet.CountAsync(predicate);
                }
            }
            public virtual bool Exists(Expression<Func<TEntity, bool>> predicate = null)
            {
                if (predicate == null)
                {
                    return _dbSet.Any();
                }
                else
                {
                    return _dbSet.Any(predicate);
                }
            }
            #endregion
    
            #region Insert
            public virtual TEntity Insert(TEntity entity)
            {
                return _dbSet.Add(entity).Entity;
            }
    
            public virtual void Insert(params TEntity[] entities) => _dbSet.AddRange(entities);
    
            public virtual void Insert(IEnumerable<TEntity> entities) => _dbSet.AddRange(entities);
    
            public virtual ValueTask<EntityEntry<TEntity>> InsertAsync(TEntity entity, CancellationToken cancellationToken = default(CancellationToken))
            {
                return _dbSet.AddAsync(entity, cancellationToken);
    
                // Shadow properties?
                //var property = _dbContext.Entry(entity).Property("Created");
                //if (property != null) {
                //property.CurrentValue = DateTime.Now;
                //}
            }
    
            public virtual Task InsertAsync(params TEntity[] entities) => _dbSet.AddRangeAsync(entities);
    
            public virtual Task InsertAsync(IEnumerable<TEntity> entities, CancellationToken cancellationToken = default(CancellationToken)) => _dbSet.AddRangeAsync(entities, cancellationToken);
    
            #endregion
    
            #region Update
            public virtual void Update(TEntity entity)
            {
                _dbSet.Update(entity);
            }
    
            public virtual void Update(params TEntity[] entities) => _dbSet.UpdateRange(entities);
    
            public virtual void Update(IEnumerable<TEntity> entities) => _dbSet.UpdateRange(entities);
            #endregion
    
            #region Delete
    
            public virtual void Delete(TEntity entity) => _dbSet.Remove(entity);
    
            public virtual void Delete(object id)
            {
                var entity = _dbSet.Find(id);
                if (entity != null)
                {
                    Delete(entity);
                }
            }
    
            public virtual void Delete(params TEntity[] entities) => _dbSet.RemoveRange(entities);
    
            public virtual void Delete(IEnumerable<TEntity> entities) => _dbSet.RemoveRange(entities);
    
            #endregion
    
        }
    }
    
    

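    Notes

    • Wraps the common create, read, update and delete operations
    • Methods whose names end in Async are asynchronous
    • The method documentation comments are in the interface
    • Queries:
      • GetAll returns every entity matching the condition (mind performance: without a predicate it covers the whole table)
      • GetPagedList returns one page of results
      • GetFirstOrDefault returns the first entity matching the condition
      • Find looks an entity up by primary key, e.g. by an Id value
      • FromSql runs a raw SQL query
      • Count returns the number of matching rows
      • Exists checks whether any matching row exists
    • Query options and their defaults:
      • Paged queries default to pageIndex = 1 and 20 rows per page
      • Change tracking is disabled by default (disableTracking = true, i.e. AsNoTracking)
      • Global query filters are applied by default (ignoreQueryFilters = false)
      • The selector parameter projects the query results into another type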
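    Every query method in Repository<TEntity> repeats the same pipeline: optional AsNoTracking, Include, Where and IgnoreQueryFilters, then an optional orderBy, and finally a projection, a page or a single row. The sketch below reproduces the filter, sort and page steps with plain LINQ over an in-memory IQueryable so that it runs without a database. DemoRole and QuerySketch.GetPage are hypothetical names, the EF-specific steps are left out, and the paging arithmetic assumes PagedList skips (pageIndex - indexFrom) * pageSize rows, as in the Arch/UnitOfWork code this chapter is based on.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity, used only for this illustration.
public class DemoRole
{
    public long Id { get; set; }
    public string Name { get; set; }
}

public static class QuerySketch
{
    // Filter first, then sort, then take one page, mirroring the order used in Repository<TEntity>.
    // pageIndex is counted from indexFrom, so with indexFrom = 1 the first page is pageIndex = 1.
    public static (List<T> Items, int TotalCount, int TotalPages) GetPage<T>(
        IQueryable<T> source,
        Expression<Func<T, bool>> predicate = null,
        Func<IQueryable<T>, IOrderedQueryable<T>> orderBy = null,
        int pageIndex = 1,
        int pageSize = 20,
        int indexFrom = 1)
    {
        if (pageSize <= 0) throw new ArgumentOutOfRangeException(nameof(pageSize));
        if (pageIndex < indexFrom) throw new ArgumentOutOfRangeException(nameof(pageIndex));

        IQueryable<T> query = source;
        if (predicate != null) query = query.Where(predicate);
        if (orderBy != null) query = orderBy(query);

        // The total is counted over the filtered query, before Skip/Take.
        int totalCount = query.Count();
        int totalPages = (int)Math.Ceiling(totalCount / (double)pageSize);
        List<T> items = query.Skip((pageIndex - indexFrom) * pageSize).Take(pageSize).ToList();
        return (items, totalCount, totalPages);
    }
}
```

    Against EF Core, Count and Skip/Take are translated into SQL instead of running in memory. Without an orderBy the database does not guarantee row order, so paged queries should normally pass one.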

    Unit of work

    In the MS.UnitOfWork project, add a UnitOfWork folder and, inside it, the IUnitOfWork.cs and UnitOfWork.cs classes.

    IUnitOfWork.cs

    using Microsoft.EntityFrameworkCore;
    using Microsoft.EntityFrameworkCore.Storage;
    using System;
    using System.Linq;
    using System.Threading.Tasks;
    
    namespace MS.UnitOfWork
    {
        /// <summary>
        /// Defines the unit-of-work interface
        /// </summary>
        public interface IUnitOfWork<TContext> : IDisposable where TContext : DbContext
        {
            /// <summary>
            /// Gets the DbContext
            /// </summary>
            /// <returns></returns>
            TContext DbContext { get; }
            /// <summary>
            /// Begins a transaction
            /// </summary>
            /// <returns></returns>
            IDbContextTransaction BeginTransaction();

            /// <summary>
            /// Gets the repository for the specified entity type
            /// </summary>
            /// <typeparam name="TEntity"></typeparam>
            /// <param name="hasCustomRepository">Set to true if a custom repository exists</param>
            /// <returns></returns>
            IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class;

            /// <summary>
            /// Saves the changes tracked by the DbContext
            /// </summary>
            /// <returns></returns>
            int SaveChanges();

            /// <summary>
            /// Saves the changes tracked by the DbContext (asynchronously)
            /// </summary>
            /// <returns></returns>
            Task<int> SaveChangesAsync();

            /// <summary>
            /// Executes a raw SQL command
            /// </summary>
            /// <param name="sql">SQL statement</param>
            /// <param name="parameters">Parameters</param>
            /// <returns></returns>
            int ExecuteSqlCommand(string sql, params object[] parameters);

            /// <summary>
            /// Queries data with a raw SQL query
            /// </summary>
            /// <typeparam name="TEntity"></typeparam>
            /// <param name="sql"></param>
            /// <param name="parameters">Parameters</param>
            /// <returns></returns>
            IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class;
        }
    }
    

    UnitOfWork.cs

    using Microsoft.EntityFrameworkCore;
    using Microsoft.EntityFrameworkCore.Infrastructure;
    using Microsoft.EntityFrameworkCore.Storage;
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Threading.Tasks;
    
    namespace MS.UnitOfWork
    {
        /// <summary>
        /// Default implementation of the unit of work.
        /// </summary>
        /// <typeparam name="TContext"></typeparam>
        public class UnitOfWork<TContext> : IUnitOfWork<TContext> where TContext : DbContext
        {
            protected readonly TContext _context;
            protected bool _disposed = false;
            protected Dictionary<Type, object> _repositories;
    
            public UnitOfWork(TContext context)
            {
                _context = context ?? throw new ArgumentNullException(nameof(context));
            }
    
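            /// <summary>
            /// Gets the DbContext
            /// </summary>
            public TContext DbContext => _context;
            /// <summary>
            /// Begins a transaction
            /// </summary>
            /// <returns></returns>
            public IDbContextTransaction BeginTransaction()
            {
                return _context.Database.BeginTransaction();
            }

            /// <summary>
            /// Gets the repository for the specified entity type; repositories are cached per type and share this unit of work's DbContext
            /// </summary>
            /// <typeparam name="TEntity"></typeparam>
            /// <param name="hasCustomRepository">Not used by this simplified implementation, which always returns a generic Repository&lt;TEntity&gt;</param>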
            /// <returns></returns>
            public IRepository<TEntity> GetRepository<TEntity>(bool hasCustomRepository = false) where TEntity : class
            {
                if (_repositories == null)
                {
                    _repositories = new Dictionary<Type, object>();
                }
    
                Type type = typeof(IRepository<TEntity>);
                if (!_repositories.TryGetValue(type, out object repo))
                {
                    IRepository<TEntity> newRepo = new Repository<TEntity>(_context);
                    _repositories.Add(type, newRepo);
                    return newRepo;
                }
                return (IRepository<TEntity>)repo;
            }
    
            /// <summary>
            /// Executes a raw SQL command
            /// </summary>
            /// <param name="sql">SQL statement</param>
            /// <param name="parameters">Parameters</param>
            /// <returns></returns>
            public int ExecuteSqlCommand(string sql, params object[] parameters) => _context.Database.ExecuteSqlRaw(sql, parameters);

            /// <summary>
            /// Queries data with a raw SQL query
            /// </summary>
            /// <typeparam name="TEntity"></typeparam>
            /// <param name="sql"></param>
            /// <param name="parameters">Parameters</param>
            /// <returns></returns>
            public IQueryable<TEntity> FromSql<TEntity>(string sql, params object[] parameters) where TEntity : class => _context.Set<TEntity>().FromSqlRaw(sql, parameters);

            /// <summary>
            /// Saves the changes tracked by the DbContext
            /// </summary>
            /// <returns></returns>
            public int SaveChanges()
            {
                return _context.SaveChanges();
            }

            /// <summary>
            /// Saves the changes tracked by the DbContext (asynchronously)
            /// </summary>
            /// <returns></returns>
            public async Task<int> SaveChangesAsync()
            {
                return await _context.SaveChangesAsync();
            }
    
    
            public void Dispose()
            {
                Dispose(true);
    
                GC.SuppressFinalize(this);
            }
            protected virtual void Dispose(bool disposing)
            {
                if (!_disposed)
                {
                    if (disposing)
                    {
                        // clear repositories
                        if (_repositories != null)
                        {
                            _repositories.Clear();
                        }
    
                        // dispose the db context.
                        _context.Dispose();
                    }
                }
    
                _disposed = true;
            }
        }
    }
    
    

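    Notes

    • Repositories and the DbContext are obtained from the unit of work
    • Transactions are also started from the unit of work
    • After changing data through repositories, call SaveChanges on the unit of work to commit the changes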
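    A minimal in-memory sketch of the rules above, assuming a list-based stand-in for DbContext. InMemoryContext, SketchRepository and SketchUnitOfWork are hypothetical names, not part of MS.UnitOfWork or EF Core. Like UnitOfWork<TContext>, the sketch caches one repository per entity type in a Dictionary<Type, object>, every repository shares the same context, and Insert only stages data until SaveChanges is called.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for DbContext: staged entities vs. saved entities.
public class InMemoryContext
{
    public List<object> Pending { get; } = new List<object>();
    public List<object> Saved { get; } = new List<object>();

    // Moves everything staged so far into Saved and returns how many were saved.
    public int SaveChanges()
    {
        int count = Pending.Count;
        Saved.AddRange(Pending);
        Pending.Clear();
        return count;
    }
}

// Hypothetical stand-in for Repository<TEntity>.
public class SketchRepository<TEntity> where TEntity : class
{
    private readonly InMemoryContext _context;
    public SketchRepository(InMemoryContext context) => _context = context;

    // Like DbSet.Add: only stages the entity; nothing is saved until SaveChanges.
    public void Insert(TEntity entity) => _context.Pending.Add(entity);
}

// Hypothetical stand-in for UnitOfWork<TContext>.
public class SketchUnitOfWork
{
    private readonly InMemoryContext _context = new InMemoryContext();
    private readonly Dictionary<Type, object> _repositories = new Dictionary<Type, object>();

    public InMemoryContext Context => _context;

    // One cached repository per entity type, all wrapping the same context.
    public SketchRepository<TEntity> GetRepository<TEntity>() where TEntity : class
    {
        if (!_repositories.TryGetValue(typeof(TEntity), out object repo))
        {
            repo = new SketchRepository<TEntity>(_context);
            _repositories.Add(typeof(TEntity), repo);
        }
        return (SketchRepository<TEntity>)repo;
    }

    public int SaveChanges() => _context.SaveChanges();
}
```

    The cache itself is cheap bookkeeping; what matters is that every repository wraps the same context, so one SaveChanges (or one transaction) covers changes made through several repositories.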

    Wrapping the IoC registration

    In the MS.UnitOfWork project, add the UnitOfWorkServiceExtensions.cs class:

    using Microsoft.EntityFrameworkCore;
    using Microsoft.Extensions.DependencyInjection;
    
    namespace MS.UnitOfWork
    {
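        /// <summary>
        /// Extension methods for registering the unit of work with <see cref="IServiceCollection"/>
        /// </summary>
        public static class UnitOfWorkServiceExtensions
        {
            /// <summary>
            /// Registers the unit of work for the given context in <see cref="IServiceCollection"/>.
            /// The DbContext itself is registered as well.
            /// </summary>
            /// <typeparam name="TContext"></typeparam>
            /// <param name="services"></param>
            /// <param name="action">Configures the DbContext options (database provider, connection string, etc.)</param>
            /// <remarks>Registers IUnitOfWork&lt;TContext&gt; for the given context type; call it once per DbContext type.</remarks>
            /// <returns></returns>
            public static IServiceCollection AddUnitOfWorkService<TContext>(this IServiceCollection services, System.Action<DbContextOptionsBuilder> action) where TContext : DbContext
            {
                // Register the DbContext
                services.AddDbContext<TContext>(action);
                // Register the unit of work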
                services.AddScoped<IUnitOfWork<TContext>, UnitOfWork<TContext>>();
                return services;
            }
        }
    }
    

    With this in place, any project that wants to use the unit of work only needs to call AddUnitOfWorkService in Startup to register it.

    When the project is complete, it looks like the figure below:

    Usage examples

    using (var tran = _unitOfWork.BeginTransaction()) // begin a transaction
    {
        Role newRow = _mapper.Map<Role>(viewModel);
        newRow.Id = _idWorker.NextId(); // get a snowflake ID
        newRow.Creator = 1219490056771866624; // login is not implemented yet, so the current user is unknown; hard-coded placeholder for now
        newRow.CreateTime = DateTime.Now;
        _unitOfWork.GetRepository<Role>().Insert(newRow);
        await _unitOfWork.SaveChangesAsync();
        await tran.CommitAsync(); // commit the transaction
    }
    

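    The example above starts a transaction from the unit of work and wraps it in a using block. SaveChangesAsync sends the INSERT inside the transaction, but nothing is committed until tran.CommitAsync() runs; if an exception is thrown before that, the transaction is disposed at the end of the using block without being committed and is rolled back automatically.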
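    The rollback-on-dispose behaviour can be modelled in a few lines. SketchTransaction below is a hypothetical stand-in for IDbContextTransaction: it buffers writes, applies them to a list-based "store" only on Commit, and discards them when disposed without a Commit. This is a simplification; a real database executes the statements inside the transaction and undoes them on rollback.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-in for IDbContextTransaction (not EF Core's implementation).
public sealed class SketchTransaction : IDisposable
{
    private readonly List<string> _store;
    private readonly List<string> _buffer = new List<string>();
    private bool _completed;

    public SketchTransaction(List<string> store) => _store = store;

    // Writes stay invisible to the store until Commit.
    public void Write(string row)
    {
        if (_completed) throw new InvalidOperationException("Transaction already completed.");
        _buffer.Add(row);
    }

    public void Commit()
    {
        if (_completed) throw new InvalidOperationException("Transaction already completed.");
        _store.AddRange(_buffer);
        _completed = true;
    }

    // Called at the end of a using block: without a prior Commit, the buffered writes are discarded (rollback).
    public void Dispose()
    {
        if (!_completed) _buffer.Clear();
        _completed = true;
    }
}
```

    Because Dispose always runs at the end of a using block, even when an exception escapes it, forgetting or skipping Commit leaves the store unchanged.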

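    // Load the record from the database
    var row = await _unitOfWork.GetRepository<Role>().FindAsync(viewModel.Id); // viewModel.CheckField already loaded this entity once for validation, so FindAsync can return it from the context's tracked entities instead of querying the database again
    // Update the relevant values
    row.Name = viewModel.Name;
    row.DisplayName = viewModel.DisplayName;
    row.Remark = viewModel.Remark;
    row.Modifier = 1219490056771866624; // login is not implemented yet, so the current user is unknown; hard-coded placeholder for now
    row.ModifyTime = DateTime.Now;
    _unitOfWork.GetRepository<Role>().Update(row);
    await _unitOfWork.SaveChangesAsync(); // commit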
    
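    • The example above loads a record by its primary key Id and then updates it.
    • Alternatively, fetch the data with GetFirstOrDefault and set disableTracking to false. The returned entity is then tracked, so after modifying it you can call SaveChangesAsync directly without calling Update. The same holds for Find/FindAsync, which always return tracked entities; calling Update on an already tracked entity marks all of its properties as modified, so every column is written.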
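    Why a tracked entity needs no Update: EF Core's default change-tracking strategy keeps a snapshot of each tracked entity's original property values and, during SaveChanges, compares the current values against it. An entity loaded with AsNoTracking has no snapshot, so EF Core cannot tell what changed, which is why Update (marking the entity as Modified) is needed in that case. SnapshotTracker and RoleEntity below are hypothetical; this is a heavily simplified, reflection-based model of the idea, not EF Core's implementation.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical entity for the illustration.
public class RoleEntity
{
    public long Id { get; set; }
    public string Name { get; set; }
    public string Remark { get; set; }
}

// Hypothetical model of snapshot change tracking: Track copies the property values,
// DetectChanges compares current values with that copy and reports the differences.
public class SnapshotTracker<TEntity> where TEntity : class
{
    private static readonly PropertyInfo[] Properties =
        typeof(TEntity).GetProperties(BindingFlags.Public | BindingFlags.Instance)
                       .Where(p => p.CanRead && p.GetIndexParameters().Length == 0)
                       .ToArray();

    // Keyed by instance (RoleEntity does not override Equals, so this is reference identity).
    private readonly Dictionary<TEntity, object[]> _snapshots = new Dictionary<TEntity, object[]>();

    public void Track(TEntity entity) =>
        _snapshots[entity] = Properties.Select(p => p.GetValue(entity)).ToArray();

    public IReadOnlyList<string> DetectChanges(TEntity entity)
    {
        if (!_snapshots.TryGetValue(entity, out object[] original))
            throw new InvalidOperationException("Entity is not tracked; an explicit Update would be needed.");

        var changed = new List<string>();
        for (int i = 0; i < Properties.Length; i++)
        {
            if (!Equals(original[i], Properties[i].GetValue(entity)))
                changed.Add(Properties[i].Name);
        }
        return changed;
    }
}
```

    After Track(role) and role.Name = "administrator", DetectChanges(role) reports only Name; in the same way EF Core can limit the generated UPDATE to the columns that actually changed.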
    Original article: https://www.cnblogs.com/kasnti/p/12238521.html