zoukankan      html  css  js  c++  java
  • pytorch中调用C进行扩展

    pytorch中调用C进行扩展,使得某些功能在CPU上运行更快;

    第一步:编写头文件

    /* src/my_lib.h */
    int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
    int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

    第二步:编写源文件

    /* src/my_lib.c */
    #include <TH/TH.h>
    
    int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
    THFloatTensor *output)
    {
        if (!THFloatTensor_isSameSizeAs(input1, input2))
            return 0;
        THFloatTensor_resizeAs(output, input1);
        THFloatTensor_cadd(output, input1, 1.0, input2);
        return 1;
    }
    
    int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
    {
        THFloatTensor_resizeAs(grad_input, grad_output);
        THFloatTensor_fill(grad_input, 1);
        return 1;
    }

    注意:头文件TH就是pytorch底层代码的接口头文件,它是CPU模式,GPU下则为THC;

     第三步:在同级目录下创建一个.py文件(比如叫“build.py”)

    该文件用于对该C扩展模块进行编译(使用torch.util.ffi模块进行扩展编译);

    # build.py
    from torch.utils.ffi import create_extension
    ffi = create_extension(
    name='_ext.my_lib',        # 输出文件地址及名称
    headers='src/my_lib.h',    # 编译.h文件地址及名称
    sources=['src/my_lib.c'],  # 编译.c文件地址及名称
    with_cuda=False            # 不使用cuda
    )
    ffi.build()

    第四步:编写.py脚本调用编译好的C扩展模块

    import torch
    from torch.autograd import Function
    from _ext import my_lib
    import torch.nn as nn
    
    class MyAddFunction(Function):
        def forward(self, input1, input2):
            output = torch.FloatTensor()
            my_lib.my_lib_add_forward(input1, input2, output)
            return output
    
        def backward(self, grad_output):
            grad_input = torch.FloatTensor()
            my_lib.my_lib_add_backward(grad_input, grad_output)
            return grad_input
    
    class MyAddModule(nn.Module):
        def forward(self, input1, input2):
            return MyAddFunction()(input1, input2)
    
    class MyNetWork(nn.Module):
        def __init__(self):
            super(MyNetWork, self).__init__()
            self.add = MyAddModule()
    
        def forward(self, input1, input2):
            return self.add(input1, input2)
    
    model = MyNetWork()
    input1, input2 = torch.randn(5, 5), torch.randn(5, 5)
    print(model(input1, input2))
    print(input1 + input2)

    至此,用这个简单的例子抛砖引玉~

  • 相关阅读:
    Eclipse 开发过程中利用 JavaRebel 提高效率
    数字转化为大写中文
    网页变灰
    解决QQ截图无法在PS中粘贴
    ORACLE操作表时”资源正忙,需指定nowait"的解锁方法
    网页常用代码
    SQL Server 2000 删除注册的服务器
    GridView 显示序号
    读取Excel数据到DataTable
    清除SVN版本控制
  • 原文地址:https://www.cnblogs.com/zf-blog/p/11857580.html
Copyright © 2011-2022 走看看