zoukankan      html  css  js  c++  java
  • pytorch 学习笔记之编写 C 扩展,又涨姿势了

    pytorch利用 CFFI 进行 C 语言扩展。包括两个基本的步骤(docs):

    1. 编写 C 代码;

    2. python 调用 C 代码,实现相应的 Function 或 Module。

    在之前的文章中,我们已经了解了如何自定义 Module。至于 [py]torch 的 C 代码库的结构,我们留待之后讨论; 这里,重点关注,如何在 pytorch C 代码库高层接口的基础上,编写 C 代码,以及如何调用自己编写的 C 代码。

    官方示例了如何定义一个加法运算(见 repo)。这里我们定义ReLU函数(见 repo)。

    1. C 代码

    pytorch C 的基本数据结构是 THTensor(THFloatTensor、THByteTensor等)。我们以简单的 ReLU 函数为例,示例编写 C 。

    y=ReLU(x)=max(x,0)

    Function 需要定义前向和后向两个方向的操作,因此,C 代码要实现相应的功能。

    1.1 头文件声明

    /* ext_lib.h */
    int relu_forward(THFloatTensor *input, THFloatTensor *output);
    int relu_backward(THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *grad_input);
    

    1.2 函数实现

    TH/TH.h 包括了 pytorch C 代码数据结构和函数的声明,这是唯一需要添加的 include 依赖。

    /* ext_lib.c */
    
    #include <TH/TH.h>
    
    int relu_forward(THFloatTensor *input, THFloatTensor *output)
    {
      THFloatTensor_resizeAs(output, input);
      THFloatTensor_clamp(output, input, 0, INFINITY);
      return 1;
    }
    
    int relu_backward(THFloatTensor *grad_output, THFloatTensor *input, THFloatTensor *grad_input)
    {
      THFloatTensor_resizeAs(grad_input, grad_output);
      THFloatTensor_zero(grad_input);
    
      THLongStorage* size = THFloatTensor_newSizeOf(grad_output);
      THLongStorage *stride = THFloatTensor_newStrideOf(grad_output);
      THByteTensor *mask = THByteTensor_newWithSize(size, stride);
    
      THFloatTensor_geValue(mask, input, 0);
      THFloatTensor_maskedCopy(grad_input, mask, grad_output);
      return 1;
    }
    

    2. 编译代码

    2.1 依赖

    由于 pytorch 的代码是纯 C 的,因此没有过多的依赖,只需要安装:

    • pytorch - 安装方法见官网

    • cffi - pip install cffi

    编译文件非常简单,主要是添加头文件和实现文件,以及相关的宏定义; 同时文件还指定了编译后的调用位置(此外为_ext.ext_lib):

    # build.py
    import os
    import torch
    from torch.utils.ffi import create_extension
    
    
    sources = ['src/ext_lib.c']
    headers = ['src/ext_lib.h']
    defines = []
    with_cuda = False
    
    if torch.cuda.is_available():
        print('Including CUDA code.')
        sources += ['src/ext_lib_cuda.c']
        headers += ['src/ext_lib_cuda.h']
        defines += [('WITH_CUDA', None)]
        with_cuda = True
    
    ffi = create_extension(
        '_ext.ext_lib',
        headers=headers,
        sources=sources,
        define_macros=defines,
        relative_to=__file__,
        with_cuda=with_cuda
    )
    
    if __name__ == '__main__':
        ffi.build()
    python build.py
    

    3. python 调用

    3.1 编写配置文件

    python 的调用非常简单——pytorch 的 tensor 对象,对应 C 代码的 THTensor 对象,以此作参数进行调用即可。配置文件如下:

    import torch
    from torch.autograd import Function
    from _ext import ext_lib
    
    class ReLUF(Function):
        def forward(self, input):
            self.save_for_backward(input)
    
            output = input.new()
            if not input.is_cuda:
                ext_lib.relu_forward(input, output)
            else:
                raise Exception, "No CUDA Implementation"
            return output
    
        def backward(self, grad_output):
            input, = self.saved_tensors
    
            grad_input = grad_output.new()
            if not grad_output.is_cuda:
                ext_lib.relu_backward(grad_output, input, grad_input)
            else:
                raise Exception, "No CUDA Implementation"
            return grad_input
    

    3.2 测试

    此处省略 Module 的定义。下面测试下新定义的基于 C 的 ReLU 函数。

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    
    from modules.relu import ReLUM
    
    torch.manual_seed(1111)
    
    class MyNetwork(nn.Module):
        def __init__(self):
            super(MyNetwork, self).__init__()
            self.relu = ReLUM()
    
        def forward(self, input):
            return self.relu(input)
    
    model = MyNetwork()
    x = torch.randn(1, 25).view(5, 5)
    input = Variable(x, requires_grad=True)
    output = model(input)
    print(output)
    print(input.clamp(min=0))
    
    output.backward(torch.ones(input.size()))
    print(input.grad.data)
    

    输出结果如下:

    Variable containing:
     0.8749  0.5990  0.6844  0.0000  0.0000
     0.6516  0.0000  1.5117  0.5734  0.0072
     0.1286  1.4171  0.0796  1.0355  0.0000
     0.0000  0.0000  0.0312  0.0999  0.0000
     1.0401  1.0599  0.0000  0.0000  0.0000
    [torch.FloatTensor of size 5x5]
    
    Variable containing:
     0.8749  0.5990  0.6844  0.0000  0.0000
     0.6516  0.0000  1.5117  0.5734  0.0072
     0.1286  1.4171  0.0796  1.0355  0.0000
     0.0000  0.0000  0.0312  0.0999  0.0000
     1.0401  1.0599  0.0000  0.0000  0.0000
    [torch.FloatTensor of size 5x5]
    
    
     1  1  1  0  0
     1  0  1  1  1
     1  1  1  1  0
     0  0  1  1  0
     1  1  0  0  0

    原文出自腾讯云技术社区

    原文链接https://www.qcloud.com/community/article/314920

     
  • 相关阅读:
    MySQL连接数过多登录不上
    Linux中盘符的两种挂载方法
    Linux杀毒软件ClamAV初次体验
    VS2013开发asmx接口根据ID查询对象
    VS2013开发asmx接口返回一个自定义XML
    VS2013开发一个简单的asmx接口程序
    Java编译过程(传送门)
    凡人和神学习和使用软件的七个层次
    CentOS7 限制SSH密码尝试次数
    马云是如何招聘到多隆这样的牛人的?(转)
  • 原文地址:https://www.cnblogs.com/hongge66/p/6815114.html
Copyright © 2011-2022 走看看