zoukankan      html  css  js  c++  java
  • caffe源代码分析--Blob类代码研究


    作者:linger

    转自须注明转自:http://blog.csdn.net/lingerlanlan/article/details/24379689



    数据成员

    shared_ptr<SyncedMemory>data_;//data数据。指向SyncedMemory的智能指针

    shared_ptr<SyncedMemory>diff_;//表示“差”。用于更新data_

    intnum_;

    intchannels_;

    intheight_;

    intwidth_;

    intcount_;



    构造函数

    Blob():num_(0),channels_(0),height_(0),width_(0),count_(0),data_(),diff_(){}

    功能:简单的初始化


    explicitBlob(constintnum,constintchannels,constintheight,constintwidth);

    功能:调用Reshape函数。初始化数据成员

    template<typenameDtype>

    Blob<Dtype>::Blob(constintnum,constintchannels,constintheight,

    constintwidth) {

    Reshape(num,channels, height, width);

    }


    析构函数

    virtual~Blob(){}

    功能:啥都没做?






    voidReshape(constintnum,constintheight,

    constintwidth,constintchannels);

    功能:初始化数据成员,智能指针指向SyncedMemory对象。此时SyncedMemory对象事实上并没有为自己的“数据”申请内存,仅仅是自己“数据”的大小(size)。

    template<typenameDtype>

    voidBlob<Dtype>::Reshape(constintnum,constintchannels,constintheight,

    constintwidth) {

    CHECK_GE(num,0);

    CHECK_GE(channels,0);

    CHECK_GE(height,0);

    CHECK_GE(width,0);

    num_= num;

    channels_= channels;

    height_= height;

    width_= width;

    count_=num_*channels_*height_*width_;

    if(count_){

    data_.reset(newSyncedMemory(count_*sizeof(Dtype)));

    diff_.reset(newSyncedMemory(count_*sizeof(Dtype)));

    }else{

    data_.reset(reinterpret_cast<SyncedMemory*>(NULL));

    diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));

    }

    }



    成员訪问函数

    功能:就是返回一些成员变量

    inlineintnum()const{returnnum_;}

    inlineintchannels()const{returnchannels_;}

    inlineintheight()const{returnheight_;}

    inlineintwidth()const{returnwidth_;}

    inlineintcount()const{returncount_;}

    inlineintoffset(constintn,constintc = 0, constinth = 0,constintw = 0) const{

    return((n * channels_+ c) *height_+ h) *width_+ w;

    //计算偏移量,由于数据在内存是一维数组形式的,所以须要计算偏移量来訪问

    }


    数据”指针返回函数

    功能:事实上这些函数就是调用SyncedMemory的函数,来返回数据的指针

    constDtype*cpu_data()const;

    constDtype*gpu_data()const;

    constDtype*cpu_diff()const;

    constDtype*gpu_diff()const;

    Dtype*mutable_cpu_data();

    Dtype*mutable_gpu_data();

    Dtype*mutable_cpu_diff();

    Dtype*mutable_gpu_diff();


    inlineDtypedata_at(constintn,constintc,constinth,

    constintw)const{

    //cpu訪问数据data

    return*(cpu_data()+ offset(n, c, h, w));

    }


    inlineDtypediff_at(constintn,constintc,constinth,

    constintw)const{

    //cpu訪问数据diff

    return*(cpu_diff() + offset(n, c, h, w));

    }



    函数voidUpdate()

    功能:更新data_的数据,就是减去diff_的数据。



    template<typenameDtype>

    voidBlob<Dtype>::Update(){

    //We will perform update based on where the data is located.

    switch(data_->head()){

    caseSyncedMemory::HEAD_AT_CPU:

    //perform computation on CPU

    caffe_axpy<Dtype>(count_,Dtype(-1),

    reinterpret_cast<constDtype*>(diff_->cpu_data()),

    reinterpret_cast<Dtype*>(data_->mutable_cpu_data()));

    //math_functions.cpp能够找到该函数的实现。事实上这函数也是封装了mkl的函数。这里调用是为了实现了两个向量的减法。

    break;

    caseSyncedMemory::HEAD_AT_GPU:

    caseSyncedMemory::SYNCED:

    //perform computation on GPU

    caffe_gpu_axpy<Dtype>(count_,Dtype(-1),

    reinterpret_cast<constDtype*>(diff_->gpu_data()),

    reinterpret_cast<Dtype*>(data_->mutable_gpu_data()));

    //math_functions.cpp能够找到该函数的实现。事实上这函数也是封装了cublas的函数。这里调用是为了实现了两个向量的减法。

    break;

    default:

    LOG(FATAL)<<"Syncedmemnot initialized.";

    }

    }



    函数voidCopyFrom(constBlob<Dtype>&source,boolcopy_diff = false,boolreshape = false);

    功能:从source拷贝数据。copy_diff作为标志来区分是拷贝data还是拷贝diff

    template<typenameDtype>

    voidBlob<Dtype>::CopyFrom(constBlob&source,boolcopy_diff,boolreshape) {

    if(num_!= source.num() || channels_!= source.channels() ||

    height_!= source.height() || width_!= source.width()) {

    if(reshape) {

    Reshape(source.num(),source.channels(), source.height(), source.width());

    }else{

    LOG(FATAL)<<"Tryingto copy blobs of different sizes.";

    }

    }

    switch(Caffe::mode()){

    caseCaffe::GPU:

    if(copy_diff){

    CUDA_CHECK(cudaMemcpy(diff_->mutable_gpu_data(),source.gpu_diff(),

    sizeof(Dtype)*count_,cudaMemcpyDeviceToDevice));

    }else{

    CUDA_CHECK(cudaMemcpy(data_->mutable_gpu_data(),source.gpu_data(),

    sizeof(Dtype)*count_,cudaMemcpyDeviceToDevice));

    }

    break;

    caseCaffe::CPU:

    if(copy_diff){

    memcpy(diff_->mutable_cpu_data(),source.cpu_diff(),

    sizeof(Dtype)*count_);

    }else{

    memcpy(data_->mutable_cpu_data(),source.cpu_data(),

    sizeof(Dtype)*count_);

    }

    break;

    default:

    LOG(FATAL)<<"Unknowncaffemode.";

    }

    }




    函数voidFromProto(constBlobProto&proto);

    功能:从proto读数据进来,事实上就是反序列化

    template<typenameDtype>

    voidBlob<Dtype>::FromProto(constBlobProto&proto){

    Reshape(proto.num(),proto.channels(),proto.height(),proto.width());

    //copy data

    Dtype*data_vec = mutable_cpu_data();

    for(inti = 0; i < count_;++i) {

    data_vec[i]=proto.data(i);

    }

    if(proto.diff_size()> 0) {

    Dtype*diff_vec = mutable_cpu_diff();

    for(inti = 0; i < count_;++i) {

    diff_vec[i]=proto.diff(i);

    }

    }

    }



    函数voidToProto(BlobProto*proto,boolwrite_diff = false)const;

    功能:序列化到proto保存

    template<typenameDtype>

    voidBlob<Dtype>::ToProto(BlobProto*proto,boolwrite_diff)const{

    proto->set_num(num_);

    proto->set_channels(channels_);

    proto->set_height(height_);

    proto->set_width(width_);

    proto->clear_data();

    proto->clear_diff();

    constDtype*data_vec = cpu_data();

    for(inti = 0; i < count_;++i) {

    proto->add_data(data_vec[i]);

    }

    if(write_diff) {

    constDtype*diff_vec = cpu_diff();

    for(inti = 0; i < count_;++i) {

    proto->add_diff(diff_vec[i]);

    }

    }

    }


  • 相关阅读:
    【React Native】某个页面禁用物理返回键
    【React Native】DeviceEventEmitter监听通知及带参数传值
    转载【React Native代码】手写验证码倒计时组件
    【React Native】 中设置 APP 名称、应用图标、为安卓添加启动图
    【React Native错误集】* What went wrong: Execution failed for task ':app:installDebug'.
    【React Native错误集】Import fails with "Failed to execute 'ImportScripts' on 'WorkerGlobalScope'"
    【React Native错误集】Android error “Could not get BatchedBridge, make sure your bundle is packaged properly” on start of app
    「React Native笔记」在React的 setState 中操作数组和对象的多种方法(合集)
    【React Native】Error: Attribute application@allowBackup value=(false) from AndroidManifest.xml
    坚果云如何使用二次验证码/谷歌身份验证器/两步验证/虚拟MFA?
  • 原文地址:https://www.cnblogs.com/tlnshuju/p/6752284.html
Copyright © 2011-2022 走看看