zoukankan      html  css  js  c++  java
  • How to create own operator with python in mxnet?

    继承CustomOp

    • 定义操作符,重写前向后向方法,此时可以通过_init__ 方法传递需要用到的参数
     1 class LossLayer(mxnet.operator.CustomOp):
     2     def __init__(self, *args, **kwargs):
     3         super(LossLayer, self).__init__()
     4         # recipe some arguments for forward or backward calculation
     5         
     6     def forward(self, is_train, req, in_data, out_data, aux):
     7         """
     8         in_data是一个列表,其中tensor的顺序和对应属性类中定义的list_arguments()参数一一对应
     9         out_data输出列表
    10         is_train 是否是训练过程
    11         req [Null, write or inplace, add]指如何处理对应的复制操作
    12         """
    13         pass
    14         # 函数最后一般调用父类的self.assign(dst, req[0], src)进行赋值操作
    15         # 但对于dst或者src是list类型的时候要调用多次assign函数处理,此时也可以直接自己赋值
    16         # dst[:]=src
    17         
    18     def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
    19         """
    20         out_grad 上一层反传的误差
    21         in_data 输入数据,list
    22         out_data 输出的数据,由forward方法确定, 其类型大小和out_grad一致
    23         in_grad 需要计算的回传误差
    24         """
    25         pass
    26         # 其操作值得复制操作类似于forward方法        
    • 定义好操作符之后还需要定义其对应的属性类,并将其注册到operator中
    1 @mx.operator.register('losslayer')  # 注意这里注册的名字将是后面调用该操作符使用的类型名
    • 重写对应的属性类
     1 class LossLayerProp(mx.operator.CustomOpProp): # 这里的名字并非必须对应操作类名称,被@修饰符修饰
     2   def __init__(self, params):
     3     super(LossLayerProp,self).__init__(need_top_grad=False)
     4     # 最后的损失层不需要接收上层的误差,则将need_top_grad设置为False
     5     # 可以传递一些参数用以传递给操作类
     6    
     7   def list_arguments(self):  
     8     # 这个方法非常重要,定义了该操作符的输入参数,当绑定对应操作符时,输入量由该方法指定
     9     return ['data1','data2','data3','label']
    10   
    11   def list_outputs(self):
    12     # 同样返回的是列表,表示输出的量,这个其实是输出变量的后缀suffix
    13     # 若返回的是['output1','output2']则输出为 操作类的名称name加上对应后缀的量[name_output1, name_output2]
    14     return ['output']
    15   
    16   def infer_shape(self, in_shape):
    17     # 给定in_shape,显示每一个变量的对应大小,以判断大小是否一致
    18     return [],[],[]
    19       # 返回的必须是3个列表,即使列表为空,分别对应着输入参数的大小、输出数据的大小、aux参数的大小,一般最后一个为空
    20     
    21     def infer_type(self, in_type):
    22       # 该方法类似于infer_shape,推断数据类型
    23 
    24     def create_operator(self, ctx, shapes, dtypes):
    25       # 该方法真正的创建操作类对象,默认调用
    26       return LossLayer()
    • 自定义操作符的使用
     1 data1=mx.sym.Variable('data1')
     2 data2=mx.sym.Variable('data2')
     3 data3=mx.sym.Variable('data3')
     4 label = mx.sym.Variable('label')
     5 # 下面这句调用很重要,显示指定输入的symbol,然后指定自定义操作符类型
     6 net = mx.sym.Custom(data1=data1, data2=data2, data3=data3, label=label, name='net', op_type='losslayer')  
     7 # 输出操作符的相关属性
     8 print(net.infer_shape(data1=(4,1,10,10), data2=(4,1,10,10),data3=(4,1,10,10) label=(4,)))
     9 # data1=(4,1,10,10)表示对应symbol的shape
    10 print(net.infer_type(data1=np.int, data2=np.int, data3=np.int, label=np.int))
    11 # data1=np.int 标识对应symbol的数据类型
    12 print(net.list_arguments()) # 变量参数
    13 print(net.list_outputs()) #输出的变量参数
    14 
    15 ex = net.simple_bind(ctx=mx.gpu(0), data1=(4,1,10,10), data2=(4,1,10,10),data3=(4,1,10,10) label=(4,)) # simple_bind只需要指定输入参数的大小
    16 ex.forward(data1=data1, data2=data2, label=label))
    17 print(ex.outputs[0])
    • 上面是没有参数的层,创建带有参数的中间层和上面类似, 只是修改下面部分代码
    1 def list_arguments(self):
    2     return ['data','weight', 'bias']
    3     
    4 def infer_shape(self, in_shape):
    5     data_shape = in_shape[0]
    6     weight_shape = ...
    7     bias_shape = ...
    8     output_shape = ...
    9     return [data_shape, weight_shape, bias_shape], [output_shape], []

    调用方式:

    net = mx.symbol.Custom(data, name='newLayer', op_type='myLayer')

     包含参数的layer在定义backward方法时要注意梯度的更新方式,即req的选择

    NOTE:

    有参数的操作符中,一般使用‘weight’和‘bias’作为参数, 该参数会最为后缀加到 opname_weight, opname_bias中,因为mxnet默认的参数初始化方法只认‘weight’, 'bias', 'gamma', 'beta'四个量, 对于自己新定义的量,比如weight2, 需要指定初始化方法

    Default initialization is now limited to "weight", "bias", "gamma" (1.0), and "beta" (0.0).
    Please use mx.sym.Variable(init=mx.init.*) to set initialization pattern
  • 相关阅读:
    Struts上传
    Struts2转换器
    Strust2拦截器
    Strust2标签(转)
    hibernate延迟加载和抓取策略(转)
    hibernate映射(单向双向的一对多、多对一以及一对一、多对一(转)
    struts简单实现新闻的增删改查
    HIbernate 缓存机制(转)
    Hibernate中封装session(静态单例模式)
    使用工具自动生成hibernate的配置文件、实体类与连接数据库
  • 原文地址:https://www.cnblogs.com/YiXiaoZhou/p/7505289.html
Copyright © 2011-2022 走看看