zoukankan      html  css  js  c++  java
  • go语言依赖注入实现

    最近做项目中,生成对象还是使用比较原始的New和简单工厂的方式,使用过程中感觉不太爽快(依赖紧密,有点改动就比较麻烦),还是比较喜欢使用依赖注入的方式。

    然后网上没有找到比较好用的依赖注入包,就自己动手写了一个,也不要求啥,能用就会,把我从繁琐的New方法中解脱出来。

    先说一下简单实现原理

    1. 通过反射读取对象的依赖(golang是通过tag实现)
    2. 在容器中查找有无该对象实例
    3. 如果有该对象实例或者创建对象的工厂方法,则注入对象或使用工厂创建对象并注入
    4. 如果无该对象实例,则报错

    需要注意的地方:

    1、注入的对象首字母需要大写,小写的话,在go中代表私有,通过反射无法修改值

    2、go反射无法通过读取配置文件信息动态创建对象

    首先,介绍一下项目层次结构

     主要解决:数据库-》仓储(读写分离)-》服务-》控制器 这几层的依赖注入问题

    数据库,我这里为了简化数据库细节,采用模拟数据的办法来实现,实际项目中是需要读取真是数据库的,代码如下

    //准备用户数据,实际开发一般从数据库读取
    var users []entities.UserEntity
    
    func init() {
        users = append(users, entities.UserEntity{ID: 1, Name: "小明", NickName: "无敌", Gender: 1, Age: 13, Tel: "18886588086", Address: "中国,广东,深圳"})
        users = append(users, entities.UserEntity{ID: 2, Name: "小红", NickName: "傻妞", Gender: 0, Age: 13, Tel: "1888658809", Address: "中国,广东,广州"})
    }
    
    type MockDB struct {
        Host  string
        User  string
        Pwd   string
        Alias string
    }
    
    func (db *MockDB) Connect() bool {
        return true
    }
    
    func (db *MockDB) Users() []entities.UserEntity {
        return users
    }
    
    func (db *MockDB) Close() {
    
    }

    数据仓储,为了实现读写分离,分离了两个接口,例如user仓储分为i_user_reader和i_user_repository,其中i_user_repository包含i_user_reader(即继承了i_user_reader)

    接口定义如下:

    type IUserReader interface {
        GetUsers() []dtos.UserDto
        GetUser(id int64) *dtos.UserDto
        GetMaxUserId() int64
    }
    
    type IUserRepository interface {
        IUserReader
        AddUser(user *inputs.UserInput) error
        UpdateUserNickName(id int64, nickName string) error
    }

    仓储实现如下:

    user_read

    type UserRead struct {
        ReadDb *db.MockDB `inject:"MockDBRead"`
    }
    
    func (r *UserRead) GetUsers() []dtos.UserDto {
        if r.ReadDb.Connect() {
            users := r.ReadDb.Users()
            var list []dtos.UserDto
            for _, user := range users {
                list = append(list, dtos.UserDto{ID: user.ID, Name: user.Name, NickName: user.NickName, Gender: user.Gender, Age: user.Age, Tel: user.Tel, Address: user.Address})
            }
            return list
        }
        return nil
    }
    
    func (r *UserRead) GetUser(id int64) *dtos.UserDto {
        if r.ReadDb.Connect() {
            users := r.ReadDb.Users()
            for _, user := range users {
                if user.ID == id {
                    return &dtos.UserDto{ID: user.ID, Name: user.Name, NickName: user.NickName, Gender: user.Gender, Age: user.Age, Tel: user.Tel, Address: user.Address}
                }
            }
            return &dtos.UserDto{}
        }
        return nil
    }
    
    func (r *UserRead) GetMaxUserId() int64 {
        var maxId int64
        if r.ReadDb.Connect() {
            users := r.ReadDb.Users()
            for _, user := range users {
                if user.ID > maxId {
                    maxId = user.ID
                }
            }
        }
        return maxId
    }
    UserRepository:
    type UserRepository struct {
        UserRead
        WriteDb *db.MockDB `inject:"MockDBWrite"`
    }
    
    func (w *UserRepository) AddUser(user *inputs.UserInput) error {
        model := entities.UserEntity{}
        model.ID = w.GetMaxUserId() + 1
        model.Name = user.Name
        model.NickName = user.NickName
        model.Gender = user.Gender
        model.Age = user.Age
        model.Address = user.Address
        if w.ReadDb.Connect() {
            users := w.ReadDb.Users()
            users = append(users, model)
        }
        return nil
    }
    
    func (w *UserRepository) UpdateUserNickName(id int64, nickName string) error {
        user := w.GetUser(id)
        if user.ID > 0 {
            user.NickName = nickName
            return nil
        } else {
            return errors.New("未找到用户信息")
        }
    }

    注意,user_read依赖注入的是读db:ReadDB,user_repository依赖注入的是写db:WriteDB

    服务的接口和实现

    i_user_service:

    type IUserService interface {
        GetUsers() []dtos.UserDto
        GetUser(id int64) *dtos.UserDto
        AddUser(user *inputs.UserInput) error
    }

    user_service:

    type UserService struct {
        UserRepository repositories.IUserRepository `inject:"UserRepository"`
    }
    
    func (s *UserService) AddUser(user *inputs.UserInput) error {
        return s.UserRepository.AddUser(user)
    }
    
    func (s *UserService) GetUsers() []dtos.UserDto {
        return s.UserRepository.GetUsers()
    }
    
    func (s *UserService) GetUser(id int64) *dtos.UserDto {
        return s.UserRepository.GetUser(id)
    }

    UserService依赖注入UserRepository,另外,项目中,特意把仓储接口定义和服务放在同一层,是为了让服务只依赖仓储接口,不依赖仓储具体实现。这算是设计模式原则的依赖倒置原则的体现吧。

    控制器实现:

    type UserController struct {
        UserService user.IUserService `inject:"UserService"`
    }
    
    func (ctrl *UserController) GetUsers(ctx *gin.Context) {
        users := ctrl.UserService.GetUsers()
        Ok(Response{Code: Success, Msg: "获取用户成功!", Data: users}, ctx)
    }
    
    func (ctrl *UserController) GetUser(ctx *gin.Context) {
        idStr := ctx.Param("id")
        id, err := strconv.ParseInt(idStr, 10, 64)
        if err != nil {
            BadRequestError("id参数格式错误", ctx)
            return
        }
        users := ctrl.UserService.GetUser(id)
        Ok(Response{Code: Success, Msg: "获取用户成功!", Data: users}, ctx)
    }
    
    func (ctrl *UserController) AddUser(ctx *gin.Context) {
        input := inputs.UserInput{}
        err := ctx.ShouldBindJSON(&input)
        if err != nil {
            BadRequestError("参数错误", ctx)
            return
        }
        err = ctrl.UserService.AddUser(&input)
        if err != nil {
            Ok(Response{Code: Failed, Msg: err.Error()}, ctx)
            return
        }
        Ok(Response{Code: Success, Msg: "添加用户成功!"}, ctx)
    }

    UserController依赖注入UserService

    接下来是实现依赖注入的核心代码,容器的实现

    Container:

    var injectTagName = "inject" //依赖注入tag名
    
    //生命周期
    // singleton:单例 单一实例,每次使用都是该实例
    // transient:瞬时实例,每次使用都创建新的实例
    type Container struct {
        sync.Mutex
        singletons map[string]interface{}
        transients map[string]factory
    }
    
    type factory = func() (interface{}, error)
    
    //注册单例对象
    func (c *Container) SetSingleton(name string, singleton interface{}) {
        c.Lock()
        c.singletons[name] = singleton
        c.Unlock()
    }
    
    func (c *Container) GetSingleton(name string) interface{} {
        return c.singletons[name]
    }
    
    //注册瞬时实例创建工厂方法
    func (c *Container) SetTransient(name string, factory factory) {
        c.Lock()
        c.transients[name] = factory
        c.Unlock()
    }
    
    func (c *Container) GetTransient(name string) interface{} {
        factory := c.transients[name]
        instance, _ := factory()
        return instance
    }
    
    //注入实例
    func (c *Container) Entry(instance interface{}) error {
        err := c.entryValue(reflect.ValueOf(instance))
        if err != nil {
            return err
        }
        return nil
    }
    
    func (c *Container) entryValue(value reflect.Value) error {
        if value.Kind() != reflect.Ptr {
            return errors.New("必须为指针")
        }
        elemType, elemValue := value.Type().Elem(), value.Elem()
        for i := 0; i < elemType.NumField(); i++ {
            if !elemValue.Field(i).CanSet() { //不可设置 跳过
                continue
            }
    
            fieldType := elemType.Field(i)
            if fieldType.Anonymous {
                //fmt.Println(fieldType.Name + "是匿名字段")
                item := reflect.New(elemValue.Field(i).Type())
                c.entryValue(item) //递归注入
                elemValue.Field(i).Set(item.Elem())
            } else {
                if elemValue.Field(i).IsZero() { //零值才注入
                    //fmt.Println(elemValue.Field(i).Interface())
                    //fmt.Println(fieldType.Name)
                    tag := fieldType.Tag.Get(injectTagName)
                    injectInstance, err := c.getInstance(tag)
                    if err != nil {
                        return err
                    }
                    c.entryValue(reflect.ValueOf(injectInstance)) //递归注入
    
                    elemValue.Field(i).Set(reflect.ValueOf(injectInstance))
                } else {
                    fmt.Println(fieldType.Name)
                }
            }
        }
        return nil
    }
    
    func (c *Container) getInstance(tag string) (interface{}, error) {
        var injectName string
        tags := strings.Split(tag, ",")
        if len(tags) == 0 {
            injectName = ""
        } else {
            injectName = tags[0]
        }
    
        if c.isTransient(tag) {
            factory, ok := c.transients[injectName]
            if !ok {
                return nil, errors.New("transient factory not found")
            } else {
                return factory()
            }
        } else { //默认单例
            instance, ok := c.singletons[injectName]
            if !ok || instance == nil {
                return nil, errors.New(injectName + " dependency not found")
            } else {
                return instance, nil
            }
        }
    }
    
    // transient:瞬时实例,每次使用都创建新的实例
    func (c *Container) isTransient(tag string) bool {
        tags := strings.Split(tag, ",")
        for _, name := range tags {
            if name == "transient" {
                return true
            }
        }
        return false
    }
    
    func (c *Container) String() string {
        lines := make([]string, 0, len(c.singletons)+len(c.transients)+2)
        lines = append(lines, "singletons:")
        for key, value := range c.singletons {
            line := fmt.Sprintf("    %s: %x %s", key, c.singletons[key], reflect.TypeOf(value).String())
            lines = append(lines, line)
        }
    
        lines = append(lines, "transients:")
        for key, value := range c.transients {
            line := fmt.Sprintf("    %s: %x %s", key, c.transients[key], reflect.TypeOf(value).String())
            lines = append(lines, line)
        }
        return strings.Join(lines, "
    ")
    }

    这里使用了两种生命周期的实例:单例和瞬时(其他生命周期,水平有限哈)

    简单说下原理,容器主要包含两个map对象,用来存储对象和创建对方方法,然后依赖注入实现,就是通过反射获取tag信息,再去容器map中获取对象,通过反射把获取的对象赋值到字段中。

    我这里采用了递归注入的方式,所以本项目中,只用注入UserController对象即可,因为实际项目中多点是有多个Controller对象,所以我这里使用了个简单工厂来创建Controller对象,然后只用注入工厂方法即可

    工厂方法实现如下:

    type CtrlFactory struct {
        UserCtrl *controllers.UserController `inject:"UserController"`
    }

    使用容器前,需要先初始化好容器对象,这里使用一个全局对象,然后初始化好需要注入的对象,实现代码如下:

    var GContainer = &Container{
        singletons: make(map[string]interface{}),
        transients: make(map[string]factory),
    }
    
    func Init() {
        //db
        GContainer.SetSingleton("MockDBRead", &db.MockDB{Host: "192.168.1.12:3036", User: "root", Pwd: "123456", Alias: "Read"})
        GContainer.SetSingleton("MockDBWrite", &db.MockDB{Host: "192.168.1.25:3036", User: "root", Pwd: "123456", Alias: "Write"})
    
        //仓储
        GContainer.SetSingleton("UserRepository", &user.UserRepository{})
    
        //服务
        GContainer.SetSingleton("UserService", &userDomain.UserService{})
    
        //控制器
        GContainer.SetSingleton("UserController", &controllers.UserController{})
    
        //控制器工厂
        ctlFactory := &CtrlFactory{}
        GContainer.SetSingleton("CtrlFactory", ctlFactory)
    
        GContainer.Entry(ctlFactory) //注入
    
        fmt.Println(GContainer.String())
    }

    依赖注入代码实现讲完了,然后就是具体使用了,使用时,先在main方法中调用容器出事化方法Init() (注意,这里Init特意大写,要和go包的init区分,go包的init是自动调用,这里大写的Init是需要手动调用的,至于为啥呢,注意是可以控制调用时机,go包的init调用顺序有点莫名其妙,特别是包引用复杂的时候),main代码如下:

    func main() {
        Init()
        Run()
    }
    
    func Init() {
        inject.Init()
    }
    
    func Run() {
        router := router.Init()
    
        s := &http.Server{
            Addr:           ":8080",
            Handler:        router,
            ReadTimeout:    time.Duration(10) * time.Second,
            WriteTimeout:   time.Duration(10) * time.Second,
            MaxHeaderBytes: 1 << 20,
        }
        go func() {
            log.Println("Server Listen at:8080")
            if err := s.ListenAndServe(); err != nil {
                log.Printf("Listen:%s
    ", err)
            }
        }()
    
        quit := make(chan os.Signal)
        signal.Notify(quit, os.Interrupt)
        <-quit
    
        log.Println("Shutdown Server...")
        ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
        defer cancel()
        if err := s.Shutdown(ctx); err != nil {
            log.Fatal("Server Shutdown:", err)
        }
        log.Println("Server exiting")
    }

    我这里使用了gin框架来构建http服务

    初始化话完毕后,就是在路由中使用controller了,先从容器中获取工厂对象,然后通过go类型推断转化为具体类型,代码如下:

    func Init() *gin.Engine {
        // Creates a router without any middleware by default
        r := gin.New()
        r.Use(gin.Logger())
        // Recovery middleware recovers from any panics and writes a 500 if there was one.
        r.Use(gin.Recovery())
    
        r.GET("/ping", func(c *gin.Context) {
            c.JSON(200, gin.H{
                "message": "pong",
            })
        })
    
        factory := inject.GContainer.GetSingleton("CtrlFactory")
        ctrlFactory := factory.(*inject.CtrlFactory)
    
        apiV1 := r.Group("/api/v1")
        //users
        userRg := apiV1.Group("/user")
        {
            userRg.POST("", ctrlFactory.UserCtrl.AddUser)
            userRg.GET("", ctrlFactory.UserCtrl.GetUsers)
            userRg.GET("/:id", ctrlFactory.UserCtrl.GetUser)
        }
    
        gin.SetMode("debug")
        return r
    }

    核心代码就是:

    factory := inject.GContainer.GetSingleton("CtrlFactory")
    ctrlFactory := factory.(*inject.CtrlFactory)

    ok,介绍完了。初始弄这个依赖注入可能觉得有点麻烦,但这是一劳永逸的办法,后面有啥增加修改的就比较简单

    具体代码放在github上了,有兴趣可以关注一下:https://github.com/marshhu/ma-inject

  • 相关阅读:
    Keyboarding题解
    埃及分数 解题报告
    小木棍加强版解题报告
    扩展欧几里得
    luoguP4999 烦人的数学作业
    中国剩余定理
    20201115gryz模拟赛解题报告
    扩展欧几里得算法
    斐蜀定理
    CSP2020-S游记
  • 原文地址:https://www.cnblogs.com/marshhu/p/12955754.html
Copyright © 2011-2022 走看看