zoukankan      html  css  js  c++  java
  • pytorch中调用C进行扩展

    pytorch中调用C进行扩展,使得某些功能在CPU上运行更快;

    第一步:编写头文件

    /* src/my_lib.h */
    int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
    int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);

    第二步:编写源文件

    /* src/my_lib.c */
    #include <TH/TH.h>
    
    int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
    THFloatTensor *output)
    {
        if (!THFloatTensor_isSameSizeAs(input1, input2))
            return 0;
        THFloatTensor_resizeAs(output, input1);
        THFloatTensor_cadd(output, input1, 1.0, input2);
        return 1;
    }
    
    int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
    {
        THFloatTensor_resizeAs(grad_input, grad_output);
        THFloatTensor_fill(grad_input, 1);
        return 1;
    }

    注意:头文件TH就是pytorch底层代码的接口头文件,它是CPU模式,GPU下则为THC;

     第三步:在同级目录下创建一个.py文件(比如叫“build.py”)

    该文件用于对该C扩展模块进行编译(使用torch.util.ffi模块进行扩展编译);

    # build.py
    from torch.utils.ffi import create_extension
    ffi = create_extension(
    name='_ext.my_lib',        # 输出文件地址及名称
    headers='src/my_lib.h',    # 编译.h文件地址及名称
    sources=['src/my_lib.c'],  # 编译.c文件地址及名称
    with_cuda=False            # 不使用cuda
    )
    ffi.build()

    第四步:编写.py脚本调用编译好的C扩展模块

    import torch
    from torch.autograd import Function
    from _ext import my_lib
    import torch.nn as nn
    
    class MyAddFunction(Function):
        def forward(self, input1, input2):
            output = torch.FloatTensor()
            my_lib.my_lib_add_forward(input1, input2, output)
            return output
    
        def backward(self, grad_output):
            grad_input = torch.FloatTensor()
            my_lib.my_lib_add_backward(grad_input, grad_output)
            return grad_input
    
    class MyAddModule(nn.Module):
        def forward(self, input1, input2):
            return MyAddFunction()(input1, input2)
    
    class MyNetWork(nn.Module):
        def __init__(self):
            super(MyNetWork, self).__init__()
            self.add = MyAddModule()
    
        def forward(self, input1, input2):
            return self.add(input1, input2)
    
    model = MyNetWork()
    input1, input2 = torch.randn(5, 5), torch.randn(5, 5)
    print(model(input1, input2))
    print(input1 + input2)

    至此,用这个简单的例子抛砖引玉~

  • 相关阅读:
    基础知识---抽象类和接口
    基础知识---数组和链表
    基础知识---枚举
    基础知识---IEnumerable、ICollection、IList、IQueryable
    [翻译]微软 Build 2019 正式宣布 .NET 5
    基础知识---const、readonly、static
    简说设计模式
    Java修行之路
    简说设计模式——迭代器模式
    简说设计模式——备忘录模式
  • 原文地址:https://www.cnblogs.com/zf-blog/p/11857580.html
Copyright © 2011-2022 走看看