zoukankan      html  css  js  c++  java
  • Kaldi CUDA矩阵类与CPU矩阵类的互相转换

    CUDA类转换为CPU类

    示例1

    // Example 1: convert a GPU matrix to a CPU matrix at construction time.
    // A gets its row and column counts from two RandInt(1, 50) calls
    // (presumably each in the range 1..50 -- confirm against RandInt's docs).
    CuMatrix<BaseFloat> A(RandInt(1, 50), RandInt(1, 50));

    // Fill A with random values (presumably normally distributed -- confirm).
    A.SetRandn();

    // Uses the explicit Matrix(const CuMatrixBase<OtherReal>&, trans) constructor
    // shown in the implementation section: it sizes A2 from A and then calls
    // A.CopyToMat(&A2, trans).
    Matrix<BaseFloat> A2(A);

    示例2

    // Example 2: copy a GPU matrix into an already-existing CPU matrix.
    CuMatrix<BaseFloat> A(RandInt(1, 50), RandInt(1, 50));
    A.SetRandn();
    // Note: `Matrix<BaseFloat> A2();` would declare a function, not an object
    // (most vexing parse). CopyToMat also asserts that the destination already
    // has the same dimensions as the source, so allocate A2 with A's shape.
    Matrix<BaseFloat> A2(A.NumRows(), A.NumCols());
    // CopyToMat takes the destination by pointer.
    A.CopyToMat(&A2);

    代码实现

    src/matrix/kaldi-matrix.h

    // Converting constructor: builds a CPU Matrix from a CuMatrixBase whose
    // element type (OtherReal) may differ from this matrix's own element type.
    // `trans` defaults to kNoTrans. Declared explicit, so the conversion must be
    // spelled out at the call site.
    template<typename OtherReal>

    explicit Matrix(const CuMatrixBase<OtherReal> &cu,

    MatrixTransposeType trans = kNoTrans);

    src/cudamatrix/cu-matrix.h

    // Builds a CPU matrix from the CUDA matrix M.
    // The new matrix is allocated with M's dimensions (swapped when a transpose
    // is requested) using the default stride, and is then filled by
    // M.CopyToMat(this, trans).
    template<typename Real>
    template<typename OtherReal>
    Matrix<Real>::Matrix(const CuMatrixBase<OtherReal> &M,
                         MatrixTransposeType trans) {
      const bool same_layout = (trans == kNoTrans);
      if (same_layout) {
        this->Init(M.NumRows(), M.NumCols(), kDefaultStride);
      } else {
        // Any other transpose setting: rows and columns trade places.
        this->Init(M.NumCols(), M.NumRows(), kDefaultStride);
      }
      M.CopyToMat(this, trans);
    }

    // Copies this matrix into the CPU matrix pointed to by dst.
    // OtherReal is the destination's element type and may differ from this
    // matrix's element type. `trans` defaults to kNoTrans.
    template<typename OtherReal>

    void CopyToMat(MatrixBase<OtherReal> *dst,

    MatrixTransposeType trans = kNoTrans) const;

      

    src/cudamatrix/cu-matrix.cc

    // Copies this matrix into the CPU matrix *dst.
    //
    // When CUDA is compiled in and the device is enabled:
    //   * if a transpose is requested, or the destination element type has a
    //     different size, a temporary CuMatrix<OtherReal> is built from
    //     (*this, trans) and that temporary is copied out with kNoTrans;
    //   * otherwise dst must already have this matrix's dimensions (asserted),
    //     and the data is transferred with one 2-D device-to-host copy on the
    //     per-thread stream, followed by a stream synchronize.
    // Otherwise the copy is delegated to dst->CopyFromMat(Mat(), trans).
    template<typename Real>
    template<typename OtherReal>
    void CuMatrixBase<Real>::CopyToMat(MatrixBase<OtherReal> *dst,
                                       MatrixTransposeType trans) const {
    #if HAVE_CUDA == 1
      if (CuDevice::Instantiate().Enabled()) {
        const bool needs_staging =
            (trans == kTrans) || (sizeof(OtherReal) != sizeof(Real));
        if (needs_staging) {
          // Build the transposed / converted matrix first, then copy that
          // temporary out without any further transformation.
          CuMatrix<OtherReal> staged(*this, trans);
          staged.CopyToMat(dst, kNoTrans);
          return;
        }

        KALDI_ASSERT(dst->NumRows() == NumRows() && dst->NumCols() == NumCols());
        if (num_rows_ == 0) return;  // Nothing to transfer.
        CuTimer tim;

        // Pitches and row width are expressed in bytes. sizeof(Real) equals
        // sizeof(OtherReal) on this path, so it is valid for both sides.
        const MatrixIndexT device_pitch = stride_ * sizeof(Real);
        const MatrixIndexT host_pitch = dst->Stride() * sizeof(Real);
        const MatrixIndexT row_bytes = NumCols() * sizeof(Real);
        CU_SAFE_CALL(cudaMemcpy2DAsync(dst->Data(), host_pitch,
                                       this->data_, device_pitch,
                                       row_bytes, this->num_rows_,
                                       cudaMemcpyDeviceToHost,
                                       cudaStreamPerThread));
        // Wait for the copy to finish before returning to the caller.
        CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread));
        CuDevice::Instantiate().AccuProfile("CuMatrix::CopyToMatD2H", tim);
        return;
      }
    #endif
      // CUDA not compiled in, or device not enabled.
      dst->CopyFromMat(Mat(), trans);
    }

      

       

    CPU类转换为CUDA类

    示例1

    // Example 1: convert a CPU matrix to a GPU matrix at construction time.
    // A gets its row and column counts from two RandInt(1, 50) calls
    // (presumably each in the range 1..50 -- confirm against RandInt's docs).
    Matrix<BaseFloat> A(RandInt(1, 50), RandInt(1, 50));

    // Fill A with random values (presumably normally distributed -- confirm).
    A.SetRandn();

    // Uses the explicit CuMatrix(const MatrixBase<OtherReal>&, trans) constructor
    // shown in the implementation section: it resizes A2 to A's shape and then
    // calls CopyFromMat(A, trans).
    CuMatrix<BaseFloat> A2(A);

    示例2

    // Example 2: copy a CPU matrix into an already-existing GPU matrix.
    Matrix<BaseFloat> A(RandInt(1, 50), RandInt(1, 50));
    A.SetRandn();
    // Note: `CuMatrix<BaseFloat> A2();` would declare a function, not an object
    // (most vexing parse). CopyFromMat also asserts that the destination already
    // has the same dimensions as the source, so allocate A2 with A's shape.
    CuMatrix<BaseFloat> A2(A.NumRows(), A.NumCols());
    A2.CopyFromMat(A);

    代码实现

    src/cudamatrix/cu-matrix.h

    // Converting constructor: builds a CuMatrix from a CPU MatrixBase whose
    // element type (OtherReal) may differ from this matrix's own element type.
    // `trans` defaults to kNoTrans. Declared explicit, so the conversion must be
    // spelled out at the call site.
    template<typename OtherReal>

    explicit CuMatrix(const MatrixBase<OtherReal> &other,

    MatrixTransposeType trans = kNoTrans);

    src/cudamatrix/cu-matrix.cc

    // Builds a CuMatrix from the CPU matrix `other`.
    // The matrix is first resized to other's dimensions (swapped when a
    // transpose is requested). kUndefined is passed as the resize type because
    // every element is overwritten by the CopyFromMat call that follows.
    template<typename Real>
    template<typename OtherReal>
    CuMatrix<Real>::CuMatrix(const MatrixBase<OtherReal> &other,
                             MatrixTransposeType trans) {
      const bool same_layout = (trans == kNoTrans);
      if (same_layout) {
        this->Resize(other.NumRows(), other.NumCols(), kUndefined);
      } else {
        // Any other transpose setting: rows and columns trade places.
        this->Resize(other.NumCols(), other.NumRows(), kUndefined);
      }
      this->CopyFromMat(other, trans);
    }

       

    // Copies the CPU matrix `src` into this matrix.
    //
    // When CUDA is compiled in and the device is enabled:
    //   * with kNoTrans, src must have exactly this matrix's dimensions
    //     (asserted), and the data is transferred with one 2-D host-to-device
    //     copy on the per-thread stream, followed by a stream synchronize;
    //   * otherwise src is first copied untransposed into a temporary CuMatrix,
    //     and this matrix is filled from that temporary with kTrans.
    // Otherwise the copy is delegated to Mat().CopyFromMat(src, trans).
    template<typename Real>
    void CuMatrixBase<Real>::CopyFromMat(const MatrixBase<Real> &src,
                                         MatrixTransposeType trans) {
    #if HAVE_CUDA == 1
      if (CuDevice::Instantiate().Enabled()) {
        if (trans != kNoTrans) {
          // Upload as-is, then do the transpose on the GPU board.
          CuMatrix<Real> uploaded(src);
          this->CopyFromMat(uploaded, kTrans);
          return;
        }

        KALDI_ASSERT(src.NumRows() == num_rows_ && src.NumCols() == num_cols_);
        CuTimer tim;

        // Pitches and row width are expressed in bytes.
        const MatrixIndexT device_pitch = stride_ * sizeof(Real);
        const MatrixIndexT host_pitch = src.Stride() * sizeof(Real);
        const MatrixIndexT row_bytes = src.NumCols() * sizeof(Real);
        CU_SAFE_CALL(cudaMemcpy2DAsync(data_, device_pitch,
                                       src.Data(), host_pitch,
                                       row_bytes, src.NumRows(),
                                       cudaMemcpyHostToDevice,
                                       cudaStreamPerThread));
        // Wait for the copy to finish before returning to the caller.
        CU_SAFE_CALL(cudaStreamSynchronize(cudaStreamPerThread));

        CuDevice::Instantiate().AccuProfile("CuMatrixBase::CopyFromMat(from CPU)", tim);
        return;
      }
    #endif
      // CUDA not compiled in, or device not enabled.
      Mat().CopyFromMat(src, trans);
    }

     

  • 相关阅读:
    移动端开发基础【4】uniapp项目发布
    移动端开发案例【3】通讯录开发
    移动端开发基础【2】uni-app项目调试
    np.cross, np.count_nonzeros, np.isnan, np.transpose
    numpy中用None扩充维度
    NTU RGB+D数据集,骨架数据可视化
    文件映射,mmap
    转:Python pickle模块:实现Python对象的持久化存储
    Temporal Convolutional Networks (TCN)资料,扩张卷积
    梯度消失和爆炸,RNN,LSTM
  • 原文地址:https://www.cnblogs.com/JarvanWang/p/11759255.html
Copyright © 2011-2022 走看看