zoukankan      html  css  js  c++  java
  • pytorch中使用cuda扩展

    以下面这个例子作为教程,实现功能是element-wise add

    (pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)

    第一步:cuda编程的源文件和头文件

    // mathutil_cuda_kernel.cu
    // 头文件,最后一个是cuda特有的
    #include <curand.h>
    #include <stdio.h>
    #include <math.h>
    #include <float.h>
    #include "mathutil_cuda_kernel.h"
    
    // 获取GPU线程通道信息
    dim3 cuda_gridsize(int n)
    {
        int k = (n - 1) / BLOCK + 1;
        int x = k;
        int y = 1;
        if(x > 65535) {
            x = ceil(sqrt(k));
            y = (n - 1) / (x * BLOCK) + 1;
        }
        dim3 d(x, y, 1);
        return d;
    }
    // 这个函数是cuda执行函数,可以看到细化到了每一个元素
    __global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
    {
        int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
        if(i >= size) return;
        int j = i % x; i = i / x;
        int k = i % y;
        a[IDX2D(j, k, y)] += b[k];
    }
    
    
    // 这个函数是与c语言函数链接的接口函数
    void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
    {
        int size = x * y;
        cudaError_t err;
        
        // 上面定义的函数
        broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);
    
        err = cudaGetLastError();
        if (cudaSuccess != err)
        {
            fprintf(stderr, "CUDA kernel failed : %s
    ", cudaGetErrorString(err));
            exit(-1);
        }
    }
    #ifndef _MATHUTIL_CUDA_KERNEL
    #define _MATHUTIL_CUDA_KERNEL
    
    #define IDX2D(i, j, dj) (dj * i + j)
    #define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))
    
    #define BLOCK 512
    #define MAX_STREAMS 512
    
    #ifdef __cplusplus
    extern "C" {
    #endif
    
    void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);
    
    #ifdef __cplusplus
    }
    #endif
    
    #endif

    第二步:C编程的源文件和头文件(接口函数)

    // mathutil_cuda.c
    // THC是pytorch底层GPU库
    #include <THC/THC.h>
    #include "mathutil_cuda_kernel.h"
    
    extern THCState *state;
    
    int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
    {
        float *a = THCudaTensor_data(state, a_tensor);
        float *b = THCudaTensor_data(state, b_tensor);
        cudaStream_t stream = THCState_getCurrentStream(state);
    
        // 这里调用之前在cuda中编写的接口函数
        broadcast_sum_cuda(a, b, x, y, stream);
    
        return 1;
    }
    int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);

    第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)

    nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
    import os
    import torch
    from torch.utils.ffi import create_extension
    
    this_file = os.path.dirname(__file__)
    
    sources = []
    headers = []
    defines = []
    with_cuda = False
    
    if torch.cuda.is_available():
        print('Including CUDA code.')
        sources += ['src/mathutil_cuda.c']
        headers += ['src/mathutil_cuda.h']
        defines += [('WITH_CUDA', None)]
        with_cuda = True
    
    this_file = os.path.dirname(os.path.realpath(__file__))
    
    extra_objects = ['src/mathutil_cuda_kernel.cu.o']   # 这里是编译好后的.o文件位置
    extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
    
    
    ffi = create_extension(
        '_ext.cuda_util',
        headers=headers,
        sources=sources,
        define_macros=defines,
        relative_to=__file__,
        with_cuda=with_cuda,
        extra_objects=extra_objects
    )
    
    if __name__ == '__main__':
        ffi.build()

    第四步:调用cuda模块

    from _ext import cuda_util  #从对应路径中调用编译好的模块
    
    a = torch.randn(3, 5).cuda()
    b = torch.randn(3, 1).cuda()
    mathutil.broadcast_sum(a, b, *map(int, a.size()))
    
    # 上面等价于下面的效果:
    
    a = torch.randn(3, 5)
    b = torch.randn(3, 1)
    a += b
  • 相关阅读:
    锤子科技官网:问题整理及注意事项
    springboot中文文档
    Spring Framework 开发参考手册中文(在线HTML)
    .is() 全选复选的判断
    c:forEach用法
    SSM框架——详细整合教程(Spring+SpringMVC+MyBatis)
    火狐浏览器下载文件保存文件名的乱码问题
    多线程安全的解决方法
    MySQL的concat以及group_concat的用法
    mysql 将时间转换成时间戳
  • 原文地址:https://www.cnblogs.com/zf-blog/p/11883166.html
Copyright © 2011-2022 走看看