zoukankan      html  css  js  c++  java
  • CUDA实例练习(七):点积运算

      1 #include <stdio.h>
      2 #include <cuda_runtime.h>
      3 #include <device_launch_parameters.h>
      4 #include <book.h>
      5 #define imin(a,b) (a<b?a:b)
      6 
      7 const int N = 33 * 1024;
      8 //const int N = 3;
      9 const int threadsPerBlock = 256;
     10 const int blocksPerGrid = imin(32, (N + threadsPerBlock - 1) / threadsPerBlock);
     11 /*理解该算法,假设有1个数据,每个线程块的线程数为128,那么仍至少需要1个线程块,即1+128-1/128*/
     12 
     13 __global__ void dot(float *a, float *b, float *c) {
     14     __shared__ float cache[threadsPerBlock];
     15     int tid = threadIdx.x + blockIdx.x * blockDim.x;
     16     int cacheIndex = threadIdx.x;
     17     /*通过线程块索引和线程索引计算出输入数组中的一个全局偏移tid。共享内存缓存中的偏移cacheIndex
     18     就等于线程索引.线程块索引与这个偏移无关,因为每个线程块都拥有该共享内存的私有副本*/
     19 
     20     float temp = 0;
     21     while (tid < N) {/*防止索引越过数组边界*/
     22         temp += a[tid] * b[tid];
     23         /*在每个线程计算完当前索引上的任务后,接着就需要对索引进行递增,其中递增的步长为线程格中正在运行
     24         的线程数量。这个数值等于每个线程块中的线程数量乘以线程中线程块的数量。*/
     25         tid += blockDim.x * gridDim.x;
     26     }
     27 
     28     /*设置cache中相应位置上的值。*/
     29     cache[cacheIndex] = temp;
     30 
     31     /*对线程块中的线程进行同步。确保所有对共享数组cache[]的写入操作在读取cache之前完成。*/
     32     __syncthreads();
     33 
     34     //对于归约运算来说,以下代码要求threadPerBlock必须是2的指数
     35     int i = blockDim.x / 2;
     36     while (i != 0){
     37         if (cacheIndex < i)
     38             cache[cacheIndex] += cache[cacheIndex + i];
     39         /*在读取cache[]中的值之前,首先需要确保每个写入cache[]的线程都已经执行完毕。*/
     40         __syncthreads();
     41         i /= 2;
     42     }
     43     if (cacheIndex == 0)
     44         c[blockIdx.x] = cache[0];
     45     /*在结束了while循环后,每个线程块都得到了一个值,这个值位于cache[]的第一个元素中。因为只有一个值
     46     写入到全局内存,因此只需要一个线程来执行这个操作。当然,每个线程都可以执行这个写入操作,但这么做将
     47     使得在写入单个值时带来不必要的内存通信量。为了简单,选择了索引为0的线程。最后,由于每个线程块都
     48     只写入一个值到全局数据c[]中,因此可以通过blockIdx来索引这个值。*/
     49 }
     50 
     51 int main(void){
     52     float *a, *b, c, *partial_c;
     53     float *dev_a, *dev_b, *dev_partial_c;
     54 
     55     //在CPU上分配内存
     56     a = (float *)malloc(N*sizeof(float));
     57     b = (float *)malloc(N*sizeof(float));
     58     partial_c = (float *)malloc(blocksPerGrid*sizeof(float));
     59 
     60     //在GPU上分配内存
     61     HANDLE_ERROR(cudaMalloc((void**)&dev_a, N * sizeof(float)));
     62     HANDLE_ERROR(cudaMalloc((void**)&dev_b, N * sizeof(float)));
     63     HANDLE_ERROR(cudaMalloc((void**)&dev_partial_c, blocksPerGrid * sizeof(float)));
     64 
     65     //填充主机内存
     66     for (int i = 0; i < N; i++){
     67         a[i] = i;
     68         b[i] = i * 2;
     69     }
     70 
     71     //将数组'a'和'b'复制到GPU
     72     HANDLE_ERROR(cudaMemcpy(dev_a, a, N*sizeof(float), cudaMemcpyHostToDevice));
     73     HANDLE_ERROR(cudaMemcpy(dev_b, b, N*sizeof(float), cudaMemcpyHostToDevice));
     74     dot << <blocksPerGrid, threadsPerBlock >> >(dev_a, dev_b, dev_partial_c);
     75 
     76     //将数组'c'从GPU复制到CPU
     77     HANDLE_ERROR(cudaMemcpy(partial_c, dev_partial_c, blocksPerGrid * sizeof(float),
     78         cudaMemcpyDeviceToHost));
     79 
     80     //在CPU上完成最终的求和运算
     81     c = 0;
     82     for (int i = 0; i < blocksPerGrid; i++){
     83         c += partial_c[i];
     84     }
     85 
     86 #define sum_squares(x) (x*(x+1)*(2*x+1)/6)
     87     printf("Does GPU value %.6g = %.6g?
    ", c, 2 * sum_squares((float)(N - 1)));
     88 
     89     //释放GPU上的内存
     90     cudaFree(dev_a);
     91     cudaFree(dev_b);
     92     cudaFree(dev_partial_c);
     93 
     94     //释放CPU上的内存
     95     free(a);
     96     free(b);
     97     free(partial_c);
     98 }
     99 
    100    

     

  • 相关阅读:
    try_files $uri $uri/ /index.php?$query_string;
    关于declare(strict_types=1)的有效范围
    SVN客户端安装与使用
    Java日志框架中真的需要判断log.isDebugEnabled()吗?
    Spring4自动装配(default-autowire)
    java的@PostConstruct注解
    Google Guava之--cache
    Java类加载机制与Tomcat类加载器架构
    搞懂JVM类加载机制
    Java 类加载机制
  • 原文地址:https://www.cnblogs.com/zhangshuwen/p/7309375.html
Copyright © 2011-2022 走看看