zoukankan      html  css  js  c++  java
  • pytorch使用不完全文档

    1. 利用tensorboard看loss:

    tensorflow和pytorch环境是好的的话,链接中的logger.py拉到自己的工程里,train.py里添加相应代码,直接能用。

    关于环境,小小折腾了下,大概一小时:

    大概一年前用过tensorflow, mac里的环境还在,当时装的虚拟环境,由于工程中用到了caffe,在虚拟环境中编译caffe,去掉之前工程中用的caffe路径,把虚拟环境中的caffe路径添加到虚拟环境下的pythonpath

    export PYTHONPATH=/Users/tensorflow虚拟环境中的caffe路径/python:$PYTHONPATH
    pip install torch torchvision

    然后进入python,依次import caffe torch tensorflow没问题就可以了,一个粗陋的loss曲线到手:

    2. 多进程load数据

    这个折腾了一天,最开始打算用Queue实现,陷在里面半天,返回label没问题,一旦加入图像就死掉了,自己生成一个超大的数据也不行,怀疑是Queue容量有限,暂且存疑;

    考虑pytorch的__getitem__就是多进程的,打算在里面判断,如果self.src_map没有对应的key就读入,有就直接用。结果在mac上一个worker没啥问题,放到服务器上拆开多个子线程工作,会一直重新读入数据,因为self.src_map是属于主进程的,主进程并不会跟子进程共享这个字典,所以对每个子进程来说self.src_map都是空的,定位到这个问题就好办,最后是用链接里面“进程之间共享数据”方法实现的:

    class FaceDataSet(data.Dataset):
        def __init__(self, root, list, dst_size = 128, n_worker = 6):
            super(DataSet, self).__init__()
            self.all_data = []
            self.dst_size = dst_size
            self.src_map =  {}
            self.n_worker = n_worker
    
            fread =  open(root +'/'+ list, 'r')
            for line in fread.readlines():
                img_filename = line.strip()
                pt_filename = img_filename.replace('.jpg', '.txt')
                imgfile_fullpath = os.path.join(root, img_filename)
                ptfile_fullpath = os.path.join(root, pt_filename)
                label = np.loadtxt(ptfile_fullpath, dtype=float)
                #路径保存到all_data
                if os.path.exists(ptfile_fullpath) and os.path.exists(imgfile_fullpath):
                    self.all_data.append(DataPath(imgfile_fullpath=imgfile_fullpath, ptfile_fullpath=ptfile_fullpath))
    
            print time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) +
                  ' All data length: %d' % len(self.all_data) + 
                  " load workers:%d"%(self.n_worker)
            data_length = len(self.all_data)
            q_in = [[] for i in range(self.n_worker)]
            for data_idx in range(data_length):
                q_in[data_idx % len(q_in)].append(self.all_data[data_idx])
            with multiprocessing.Manager() as MG:
                p_map = [ multiprocessing.Manager().dict() for i in range(self.n_worker) ]
                readers = [multiprocessing.Process( target=self.load_func, args=(q_in[i], p_map[i]) )
                       for i in range(self.n_worker) ]
                for p in readers:
                    p.start()
                for p in readers:
                    p.join()
                # 至此,主程序会等待最后一个进程执行完
                for map in p_map:         #把每个子进程读取结果拼到一起
                    self.src_map.update(map)
            print time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + " src_map length:%d"%( len(self.src_map.keys()) )
    
        def load_func(self, qin, pmap):
            for item in qin:
                img_path, pt_path = item
                img = cv2.imread(img_path)   #读入数据,每个子进程读入的先放在自己的字典里
                label = np.loadtxt(pt_path).astype(np.float32)
                pmap[img_path] = self.__get_small_img(img, label, dst_size=self.dst_size, data_aug=False)

    3. 打印网络每层输出形状

    pip install torchsummary #命令行安装
    from torchsummary import summary  #代码
    summary(model, (1, 112, 112))

    4. view()的用法

    define __init__:
        self.conv = conv2d(512,512,kernel_size=7,stride=1,pad=0,group=512)
        self.fc = nn.Linear(512,128)
    define forward(self, x):        #假设batchsize = 128
        x = self.conv               #(128,512,1,1,)
        x = x.view(x.size(0), -1)   #(128,512) - x.size(0)=batchsize, 按batchsize拉平
        x = self.fc(x)              #(128,128)
        return x

     5. 单机多卡显存占用不均衡: 不是很明显,应该没有抓住主要矛盾,聊胜于无。

    checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))  #map到cpu只能起到一点点效果

    6. 在gpu卡n上训练的网络换卡加载报错:Attempting to deserialize object on CUDA device 4 but torch.cuda.device_count()

     torch.load(model_path, map_location='cuda:0' )

    7. Apex半精度训练:

    虽然没有宣称的3行代码那么简单,也是很容易了,过一遍示例代码,花不到一小时可以跑起来。参考官方git quick start,完全安装没有成功,只安装了python版,然后执行下main_amp.py可以正常import了,再按示例代码添加几行即可。其他参考这里

    8. pytorch和numpy默认数据格式不一致导致的CUDNN错误:

    一个大坑,从自己写的dataloader加载数据莫名其妙的报错:RuntimeError: cuDNN error: CUDNN_STATUS_BAD_PARAM;

      1) 训练代码中用torch.randn()生成的数据作为input没问题,生成的数据格式是troch.float32;

      2) np.loadtxt()加载数据默认是np.float64格式,pytroch网络是float,所以在执行input.to(Device)之前先input = input.float()强制转换一下;

      3) np.load()加载数据保留了保存数据时候的np.float32格式,送给pytorch的dataloader之前需要astype成float64,否则会报错"RuntimmeError: Expected object of scalar type Double but got scalar type Float ....",所以在送给dataloader之前先转成np.float64,训练代码中再和第二条一样,强制转为float

    总结就是dataloader需要double型的数据,网络需要float型数据,匹配一致即可。

    -------------2019.4.22----------

    关于环境, cuda 9.1 + anaconda2用pip install和conda install都失败,最后分别下载pytorch和vision源码

    python setup.py install

    装完pytorch需要重启下才能找到

    nn.ReLU和F.ReLU的区别

    torch.where和np.where的区别,另外,np.where可以缺省后面两个参数返回满足条件的索引,而torch.where省略参数会报错

    打印网络参数

    for name, parameters in TEACHER.named_parameters():
          print(name,parameters.size())
  • 相关阅读:
    权值线段树&&可持久化线段树&&主席树
    扩展中国剩余定理(EXCRT)快速入门
    jquery学习记录
    隐藏vbs执行cmd命令的窗口
    eclipse打开出错 Error: opening registry key 'SoftwareJavaSoftJava Runtime Environment'
    正则表达式学习总结
    什么是xss攻击?
    什么是浏览器的同源策略?
    关于axios的封装
    关于递归。
  • 原文地址:https://www.cnblogs.com/zhengmeisong/p/10130779.html
Copyright © 2011-2022 走看看