Learning MXNet from Scratch (Part 2): dataiter

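MXNet's design uses C++ as the computational backend, with Python, R and other languages as front ends. This keeps execution efficient while making the library much more convenient to use. To train on your own dataset with MXNet end to end, you need to understand several pieces; today we start with data iterators.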

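A data iterator in MXNet is very similar to a Python iterator: each time its next() method is called, it returns one data batch. A data batch holds the network's inputs and labels, typically image input in (n, c, h, w) format plus labels that are either (n, h, w) maps or one scalar per sample. Let's go straight to a simple example from the official documentation.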

      

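```python
import mxnet as mx
import numpy as np


class SimpleIter(object):
    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        # list(...) so the (name, shape) pairs can be read more than once in Python 3
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen      # one callable per input: shape -> array
        self.label_gen = label_gen    # one callable per label: shape -> array
        self.cur_batch = 0

    def __iter__(self):
        return self

    def reset(self):
        # Rewind so the next epoch starts from the first batch again.
        self.cur_batch = 0

    def __next__(self):
        return self.next()

    @property
    def provide_data(self):
        return self._provide_data

    @property
    def provide_label(self):
        return self._provide_label

    def next(self):
        if self.cur_batch < self.num_batches:
            self.cur_batch += 1
            # Generate one array per input/label from its declared shape.
            data = [mx.nd.array(g(d[1])) for d, g in zip(self._provide_data, self.data_gen)]
            assert len(data) > 0, "Empty batch data."
            label = [mx.nd.array(g(d[1])) for d, g in zip(self._provide_label, self.label_gen)]
            assert len(label) > 0, "Empty batch label."
            return mx.io.DataBatch(data=data, label=label)
        else:
            # All batches of this epoch have been returned.
            raise StopIteration


# Example usage (shapes are illustrative): a 100-dim input and a 10-class label.
n = 32
data_iter = SimpleIter(['data'], [(n, 100)], [lambda s: np.random.uniform(-1, 1, s)],
                       ['softmax_label'], [(n,)], [lambda s: np.random.randint(0, 10, s)])
```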

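The code above is about the simplest data iterator possible: there is no preprocessing, and it doesn't even read real data (each batch is produced by the data_gen and label_gen functions), but it captures the idea. A data iterator must implement the methods shown above. provide_data returns a list of (data_name, shape) pairs; for image input the shape is (batch_size, channels, height, width). provide_label likewise returns (label_name, shape) pairs, e.g. (batch_size, height, width) for a per-pixel label. reset() rewinds the iterator so the next epoch starts from the beginning; it is also the usual place to shuffle the reading order, typically by shuffling the entries of your .lst file (the list used to read images in the previous post), because random sampling tends to train a bit better. next() simply returns the next data batch, and once all batches of the epoch have been returned it must raise StopIteration. Note that the data and label inside a data batch are lists of mx.nd.NDArray.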
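The reset()/next() contract and the (name, shape) format of provide_data can be sketched without MXNet at all. This is a hypothetical, pure-Python sketch (the class name ShuffledIndexIter, the default image_shape and the seed are illustrative assumptions): reset() rewinds the cursor and reshuffles the sample indices, next() hands out one batch of indices per call and raises StopIteration at the end, and an outer loop calls reset() once per epoch so each epoch visits every sample exactly once, in a new order.

```python
import random


class ShuffledIndexIter:
    """Hypothetical sketch of the reset()/next() contract, without MXNet.

    provide_data follows the (name, shape) convention, with
    shape = (batch_size, channels, height, width).
    """

    def __init__(self, num_samples, batch_size, image_shape=(3, 32, 32), seed=0):
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.image_shape = tuple(image_shape)
        self.rng = random.Random(seed)  # seeded so runs are reproducible
        self.order = list(range(num_samples))
        self.reset()

    @property
    def provide_data(self):
        return [("data", (self.batch_size,) + self.image_shape)]

    def reset(self):
        # Rewind to the first batch and reshuffle the order for the new epoch.
        self.cursor = 0
        self.rng.shuffle(self.order)

    def __iter__(self):
        return self

    def __next__(self):
        if self.cursor >= self.num_samples:
            raise StopIteration  # this epoch is finished
        # A real iterator would load the samples at these indices.
        batch = self.order[self.cursor:self.cursor + self.batch_size]
        self.cursor += self.batch_size
        return batch

    next = __next__  # MXNet-style method name


it = ShuffledIndexIter(num_samples=6, batch_size=2)
epochs = []
for epoch in range(2):
    it.reset()  # the training loop resets the iterator once per epoch
    epochs.append([idx for batch in it for idx in batch])
print(it.provide_data)  # [('data', (2, 3, 32, 32))]
```

Keeping the shuffle inside reset() means the consumer never has to know how the order is randomized; it only has to call reset() between epochs.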

Below is a somewhat more complex data iterator I wrote recently for semantic segmentation; it adds preprocessing, shuffling and a few other steps:

      

```python
# pylint: skip-file
import os
import random

import cv2
import mxnet as mx
import numpy as np
from mxnet.io import DataIter, DataBatch


class FileIter(DataIter):  # custom iterators usually subclass DataIter
    """FileIter adapted from the fcn-xs example: builds a data iterator from a list file.

    The fcn-xs original trains on whole images with batch_size fixed to 1. This
    version randomly scales and crops every image to output_size, so all images
    in a batch share one shape and batch_size can be larger than 1.

    Parameters
    ----------
    root_dir : string
        the root directory that the images and labels live in
    flist_name : string
        the list file of images and labels (relative to root_dir); every line
        holds index, image_data_path and image_label_path separated by tabs
    rgb_mean : tuple of 3 numbers
        per-channel (R, G, B) mean subtracted from every image
    data_name : string
        the data name used in the symbol (default "data")
    label_name : string
        the label name used in the symbol (default "softmax_label")
    p : object
        user-defined config providing fac, batch_size, random_crop, random_flip,
        random_color, random_scale, output_size, color_aug_range, use_rnn and
        num_hidden
    """

    def __init__(self, root_dir, flist_name, rgb_mean=(117, 117, 117),
                 data_name="data", label_name="softmax_label", p=None):
        super(FileIter, self).__init__()
        assert p is not None, "FileIter needs the config object p"

        self.fac = p.fac  # p is my own config object; fac = label downsampling factor
        self.root_dir = root_dir
        self.flist_name = os.path.join(self.root_dir, flist_name)
        self.mean = np.array(rgb_mean, dtype=np.float32)  # (R, G, B)
        self.data_name = data_name
        self.label_name = label_name
        self.batch_size = p.batch_size
        self.random_crop = p.random_crop
        self.random_flip = p.random_flip
        self.random_color = p.random_color
        self.random_scale = p.random_scale          # (min_scale, max_scale)
        self.output_size = p.output_size            # (height, width)
        self.color_aug_range = p.color_aug_range    # max jitter for (H, S, V)
        self.use_rnn = p.use_rnn
        self.num_hidden = p.num_hidden
        if self.use_rnn:
            self.init_h_name = 'init_h'
            self.init_h = mx.nd.zeros((self.batch_size, self.num_hidden))
        # Start one batch before 0 so the first iter_next() lands on index 0.
        self.cursor = -self.batch_size

        self.data_list = []
        self.label_list = []
        self.cache = {}  # path -> decoded image, see _read_img
        with open(self.flist_name) as f:
            lines = f.read().splitlines()
        for line in lines:  # parse the .lst file so the images can be read later
            if not line.strip():
                continue
            _, data_img_name, label_img_name = line.strip().split("\t")
            self.data_list.append(data_img_name)
            self.label_list.append(label_img_name)
        self.num_data = len(self.data_list)
        self.order = list(range(self.num_data))
        self._shuffle()

    def _shuffle(self):
        random.shuffle(self.order)

    def _read_img(self, img_name, label_name):
        img_path = os.path.join(self.root_dir, img_name)
        label_path = os.path.join(self.root_dir, label_name)
        # On our server the dataset was small and colleagues kept hogging disk I/O,
        # so every decoded image is kept in memory after its first read.
        if img_path not in self.cache:
            self.cache[img_path] = cv2.imread(img_path)          # BGR, (h, w, 3)
        if label_path not in self.cache:
            self.cache[label_path] = cv2.imread(label_path, 0)  # grayscale class ids
        img = self.cache[img_path]
        label = self.cache[label_path]
        assert img is not None and label is not None, "failed to read " + img_path

        # A series of preprocessing steps after reading the image.
        if self.random_flip and random.randint(0, 1) == 1:
            # Horizontal flip, applied to image and label together.
            img = cv2.flip(img, 1)
            label = cv2.flip(label, 1)
        # Scale jittering.
        scale = random.uniform(self.random_scale[0], self.random_scale[1])
        new_width = int(img.shape[1] * scale)
        new_height = int(img.shape[0] * scale)
        img = cv2.resize(img, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
        label = cv2.resize(label, (new_width, new_height), interpolation=cv2.INTER_NEAREST)
        if self.random_crop:
            # The scaled image must be at least output_size, or randint fails.
            start_w = np.random.randint(0, img.shape[1] - self.output_size[1] + 1)
            start_h = np.random.randint(0, img.shape[0] - self.output_size[0] + 1)
            img = img[start_h:start_h + self.output_size[0], start_w:start_w + self.output_size[1], :]
            label = label[start_h:start_h + self.output_size[0], start_w:start_w + self.output_size[1]]
        if self.random_color:
            # Jitter hue, saturation and value in HSV space.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            hue = random.uniform(-self.color_aug_range[0], self.color_aug_range[0])
            sat = random.uniform(-self.color_aug_range[1], self.color_aug_range[1])
            val = random.uniform(-self.color_aug_range[2], self.color_aug_range[2])
            img = np.array(img, dtype=np.float32)
            img[..., 0] = np.clip(img[..., 0] + hue, 0, 179)  # 8-bit OpenCV hue is [0, 179]
            img[..., 1] = np.clip(img[..., 1] + sat, 0, 255)
            img[..., 2] = np.clip(img[..., 2] + val, 0, 255)
            img = cv2.cvtColor(img.astype('uint8'), cv2.COLOR_HSV2BGR)

        img = np.array(img, dtype=np.float32)  # (h, w, c), BGR order from OpenCV
        img = img[:, :, [2, 1, 0]]             # BGR -> RGB so it matches rgb_mean
        img = img - self.mean.reshape(1, 1, 3)
        img = img.transpose(2, 0, 1)           # (h, w, c) -> (c, h, w)

        # Downsample the label by fac; nearest neighbour keeps class ids intact,
        # and the explicit size always matches provide_label.
        label_h = self.output_size[0] // self.fac
        label_w = self.output_size[1] // self.fac
        label_zoomed = cv2.resize(label, (label_w, label_h), interpolation=cv2.INTER_NEAREST)
        label_zoomed = label_zoomed.astype('uint8')
        return (img, label_zoomed)

    @property
    def provide_data(self):
        """The name and shape of data provided by this iterator"""
        data_shape = (self.batch_size, 3, self.output_size[0], self.output_size[1])
        if self.use_rnn:
            return [(self.data_name, data_shape),
                    (self.init_h_name, (self.batch_size, self.num_hidden))]
        return [(self.data_name, data_shape)]

    @property
    def provide_label(self):
        """The name and shape of label provided by this iterator"""
        return [(self.label_name, (self.batch_size,
                                   self.output_size[0] // self.fac,
                                   self.output_size[1] // self.fac))]

    def get_batch_size(self):
        return self.batch_size

    def reset(self):
        # Rewind to just before the first batch and reshuffle for the new epoch.
        self.cursor = -self.batch_size
        self._shuffle()

    def iter_next(self):
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def _getpad(self):
        if self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        else:
            return 0

    def _getdata(self):
        """Load data from underlying arrays, internal use only"""
        assert self.cursor < self.num_data, "DataIter needs reset."
        data = np.zeros((self.batch_size, 3, self.output_size[0], self.output_size[1]))
        label = np.zeros((self.batch_size, self.output_size[0] // self.fac, self.output_size[1] // self.fac))
        for i in range(self.batch_size):
            # Past the end of the epoch, wrap around to the start of self.order;
            # _getpad() reports how many of these filler samples were added.
            idx = self.order[(self.cursor + i) % self.num_data]
            data[i], label[i] = self._read_img(self.data_list[idx], self.label_list[idx])
        return mx.nd.array(data), mx.nd.array(label)

    def next(self):
        """Return one DataBatch holding "data" (plus init_h for RNNs) and "label"."""
        if self.iter_next():
            data, label = self._getdata()
            data = [data, self.init_h] if self.use_rnn else [data]
            label = [label]
            return DataBatch(data=data, label=label,
                             pad=self._getpad(), index=None,
                             provide_data=self.provide_data,
                             provide_label=self.provide_label)
        else:
            raise StopIteration
```
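The wrap-around logic in _getdata and _getpad can be isolated into a small pure-Python function. The helper batch_indices below is hypothetical, not part of MXNet; it only shows which samples end up in each batch and how pad is computed when the last batch runs past the end of the shuffled order.

```python
def batch_indices(order, cursor, batch_size):
    """Hypothetical helper mirroring FileIter._getdata/_getpad.

    Returns (indices, pad): the sample indices for the batch that starts at
    cursor, and how many of them are filler taken from the start of order.
    """
    num_data = len(order)
    # Past the end of order, wrap around to its beginning.
    indices = [order[(cursor + i) % num_data] for i in range(batch_size)]
    pad = max(0, cursor + batch_size - num_data)
    return indices, pad


order = [3, 0, 4, 1, 2]  # a shuffled order of 5 samples
for cursor in range(0, len(order), 2):
    print(cursor, batch_indices(order, cursor, 2))
# 0 ([3, 0], 0)
# 2 ([4, 1], 0)
# 4 ([2, 3], 1)
```

MXNet's DataBatch carries this count in its pad field, so code consuming the batch can tell how many trailing samples are filler rather than new data.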

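Two pieces of arithmetic in _read_img are easy to get wrong: the random crop offset and the downsampled label shape. The hypothetical helpers below isolate them in plain Python. Python's random.randint includes its upper bound, whereas np.random.randint excludes it, which is why the iterator's crop adds + 1; and in Python 3, / always returns a float, so integer label shapes need //.

```python
import random


def random_crop_window(img_h, img_w, out_h, out_w, rng=random):
    """Hypothetical helper: choose a crop offset the way FileIter._read_img does.

    Offsets are uniform over [0, img - out] inclusive, so the crop always fits.
    """
    if img_h < out_h or img_w < out_w:
        # The iterator's np.random.randint call would also fail in this case.
        raise ValueError("scaled image is smaller than output_size")
    start_h = rng.randint(0, img_h - out_h)  # random.randint includes both ends
    start_w = rng.randint(0, img_w - out_w)
    return start_h, start_w


def label_shape(batch_size, out_h, out_w, fac):
    """Shape of a label batch downsampled by fac; // keeps the sizes integers."""
    return (batch_size, out_h // fac, out_w // fac)


print(random_crop_window(500, 500, 500, 500))  # (0, 0): the crop is the whole image
print(label_shape(4, 320, 480, 8))             # (4, 40, 60)
```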
At this point we can basically start ordinary training. But once you have new ideas, new problems appear... for example, how do you handle multiple inputs or outputs?

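It's actually simple; only a few places need changes:

1. provide_data and provide_label: both already return a list, so just append another (name, shape) entry in the same format as the existing one.

2. next(): if you need two inputs, data and depth, just pass input = sum([[data], [depth]], []), which is simply the list [data, depth], as the data of the DataBatch; labels work the same way. The arrays must be in the same order as the entries in provide_data (or provide_label).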
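For the multi-input case, what matters is that the arrays passed as the DataBatch data appear in the same order as the (name, shape) entries in provide_data. A hypothetical helper, make_inputs, can enforce that by looking arrays up by name; strings stand in for NDArrays so the sketch runs without MXNet. It also shows that sum() only concatenates lists when it is given [] as its start value.

```python
def make_inputs(provide_data, arrays):
    """Hypothetical helper: order input arrays to match provide_data.

    provide_data is a list of (name, shape) pairs; arrays maps each name to
    its batch array. The result can be passed as the DataBatch data list.
    """
    return [arrays[name] for name, _shape in provide_data]


provide_data = [("data", (2, 3, 4, 4)), ("depth", (2, 1, 4, 4))]
# Strings stand in for NDArrays so the sketch runs without MXNet.
arrays = {"depth": "depth_batch", "data": "rgb_batch"}
print(make_inputs(provide_data, arrays))  # ['rgb_batch', 'depth_batch']

# sum() concatenates lists only with an explicit [] start value;
# without it Python evaluates 0 + [...] and raises TypeError.
print(sum([["rgb_batch"], ["depth_batch"]], []))  # ['rgb_batch', 'depth_batch']
```

Looking inputs up by name means the dictionary can be built in any order, while the list handed to the network always follows provide_data.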

One more thing worth mentioning: multiple losses in MXNet need a little care when you write the network's symbol. Suppose you have softmax_loss and regression_loss; then just return mx.symbol.Group([softmax_loss, regression_loss]) at the end.

Anyway... that's all~

      

Original post: https://www.cnblogs.com/daihengchen/p/6367701.html