zoukankan      html  css  js  c++  java
  • caffe训练自己的数据集

    https://www.cnblogs.com/wktwj/p/6715110.html

    默认caffe已经编译好了,并且编译好了pycaffe

    1 数据准备

    首先准备训练和测试数据集,这里准备两类数据,分别放在文件夹0和文件夹1中(之所以使用0和1命名数据类别,是因为方便标注数据类别,直接用文件夹的名字即可)。即训练数据集:/data/train/0、/data/train/1  训练数据集:/data/val/0、/data/val/1。

    数据准备好之后,创建记录数据文件和对应标签的txt文件

    (1)创建训练数据集的train.txt

    复制代码
     1 import os
     2 f =open(r'train.txt',"w")
     3 path = os.getcwd()+'/data/train/'
     4 for filename in os.listdir(path) :
     5     count = 0
     6     for file in os.listdir(path+filename) :
     7         count = count + 1
     8         ff='/'+filename+"/"+file+" "+filename+"
    "
     9         f.write(ff)
    10     print '{} class: {}'.format(filename,count)
    11 f.close()
    复制代码

    (2)创建测试数据集val.txt

    复制代码
     1 import os
     2 f =open(r'val.txt',"w")
     3 path = os.getcwd()+'/data/val/'
     4 for filename in os.listdir(path) :
     5     count = 0
     6     for file in os.listdir(path+filename) :
     7         count = count + 1
     8         ff='/'+filename+"/"+file+" "+filename+"
    "
     9         f.write(ff)
    10     print '{} class: {}'.format(filename,count)
    11 f.close()
    复制代码

    注意,txt中文件的路径为: /类别文件夹名/文件名(空格,不能是制表符)类别

    2 创建LMDB数据文件

    创建createlmdb.sh使用caffe自带的(bulid/tools下的)convert_imageset创建LMDB数据文件,主要是注意数据文件以及上一步生成的txt文件的位置,注意数据文件的RESIZE,后边在进行训练和测试的时候还要用到,其余就是文件的路径的问题了。

    复制代码
     1 #!/usr/bin/env sh
     2 
     3 CAFFE_ROOT=/home/caf/object/caffe
     4 TOOLS=$CAFFE_ROOT/build/tools
     5 TRAIN_DATA_ROOT=/home/caf/wk/learn/data/train
     6 VAL_DATA_ROOT=/home/caf/wk/learn/data/val
     7 DATA=/home/caf/wk/learn/data
     8 EXAMPLE=/home/caf/wk/learn/data/lmdb
     9 # Set RESIZE=true to resize the images to 60 x 60. Leave as false if images have
    10 # already been resized using another tool.
    11 RESIZE=true
    12 if $RESIZE; then
    13   RESIZE_HEIGHT=227
    14   RESIZE_WIDTH=227
    15 else
    16   RESIZE_HEIGHT=0
    17   RESIZE_WIDTH=0
    18 fi
    19 
    20 if [ ! -d "$TRAIN_DATA_ROOT" ]; then
    21   echo "Error: TRAIN_DATA_ROOT is not a path to a directory: $TRAIN_DATA_ROOT"
    22   echo "Set the TRAIN_DATA_ROOT variable in create_face_48.sh to the path" 
    23        "where the face_48 training data is stored."
    24   exit 1
    25 fi
    26 
    27 if [ ! -d "$VAL_DATA_ROOT" ]; then
    28   echo "Error: VAL_DATA_ROOT is not a path to a directory: $VAL_DATA_ROOT"
    29   echo "Set the VAL_DATA_ROOT variable in create_face_48.sh to the path" 
    30        "where the face_48 validation data is stored."
    31   exit 1
    32 fi
    33 
    34 echo "Creating train lmdb..."
    35 
    36 GLOG_logtostderr=1 $TOOLS/convert_imageset 
    37     --resize_height=$RESIZE_HEIGHT 
    38     --resize_width=$RESIZE_WIDTH 
    39     --shuffle 
    40     $TRAIN_DATA_ROOT 
    41     $DATA/train.txt 
    42     $EXAMPLE/face_train_lmdb
    43 
    44 echo "Creating val lmdb..."
    45 
    46 GLOG_logtostderr=1 $TOOLS/convert_imageset 
    47     --resize_height=$RESIZE_HEIGHT 
    48     --resize_width=$RESIZE_WIDTH 
    49     --shuffle 
    50     $VAL_DATA_ROOT 
    51     $DATA/val.txt 
    52     $EXAMPLE/face_val_lmdb
    53 
    54 echo "Done."
    复制代码

    3 定义网络

    caffe接受的网络模型是prototxt文件,对于caffe网络的定义语法有详细的解释,本次实验用的是AlexNet,保存在train_val.prototxt

     View Code

    创建超参数文件slover.prototxt,主要定义训练的参数,包括迭代次数,每迭代多少次保存模型文件,学习率等等,net就是刚才定义的训练网络,这里训练和测试使用同一个网络。

    复制代码
     1 net: "train_val.prototxt"
     2 test_iter: 2
     3 test_interval: 10
     4 base_lr: 0.001
     5 lr_policy: "step"
     6 gamma: 0.1
     7 stepsize: 100
     8 display: 20
     9 max_iter: 100
    10 momentum: 0.9
    11 weight_decay: 0.005
    12 solver_mode: GPU
    13 snapshot: 20
    14 snapshot_prefix: "model/"
    复制代码

    4 训练模型

    创建train.sh使用GPU进行训练,否则太慢!!!

    1 #!/usr/bin/env sh
    2 CAFFE_ROOT=/home/caf/object/caffe
    3 SLOVER_ROOT=/home/caf/wk/learn
    4 $CAFFE_ROOT/build/tools/caffe train --solver=$SLOVER_ROOT/slover.prototxt --gpu=0

     在model文件夹下会生成caffemodel文件,使用这些文件用于图像的分类等操作。

    4 测试

    创建deploy.prototxt进行测试,和训练网络一样,只不过用于实际分类的网络并不需要训练网络那些参数了,因此需要重新定义一个模型文件,测试的图片在该模型中进行。

    deploy.prototxt文件和train_val.prototxt文件不同的地方在于:

    (1)输入的数据不再是LMDB,也不分为测试集和训练集,输入的类型为Input,定义的维度,和训练集的数据维度保持一致,227*227,否则会报错;

    (2)去掉weight_filler和bias_filler,这些参数已经存在于caffemodel中了,由caffemodel进行初始化。

    (3)去掉最后的Accuracy层和loss层,换位Softmax层,表示分为某一类的概率。

     View Code

    用于训练的python代码,使用caffe中python的接口,主要定义好自己训练好的参数文件,模型文件的位置,以及均值文件的位置。

    复制代码
     1 import numpy as np
     2 import matplotlib.pyplot as plt
     3 
     4 import sys
     5 caffe_root="/home/caf/object/caffe/"
     6 sys.path.insert(0,caffe_root+'python')
     7 import caffe
     8 caffe.set_device(0)
     9 caffe.set_mode_gpu()
    10 model_def = 'deploy.prototxt'
    11 model_weights = 'model/_iter_100.caffemodel'
    12 net = caffe.Net(model_def,
    13                 model_weights,  
    14                 caffe.TEST)     
    15 mu = np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy')
    16 mu = mu.mean(1).mean(1)
    17 #print 'mean-subtracted values:', zip('BGR', mu)
    18 transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    19 transformer.set_transpose('data', (2,0,1))
    20 transformer.set_mean('data', mu)
    21 transformer.set_raw_scale('data', 255)    
    22 transformer.set_channel_swap('data', (2,1,0))
    23 net.blobs['data'].reshape(3,227, 227)
    24 image = caffe.io.load_image('test.jpg')
    25 transformed_image = transformer.preprocess('data', image)
    26 #plt.imshow(image)
    27 #plt.show()
    28 net.blobs['data'].data[...] = transformed_image
    29 output = net.forward()  
    30 output_prob = output['prob']
    31 print output_prob
    32 print 'predicted class is:', output_prob.argmax()
    复制代码

    遇到的问题

    (1)标签文件不能用制表符,必须是空格,否则会找不到数据文件

    (2)CUDA问题,报一个类似叫CUDASuccess的错误,说明GPU空间不够,需要释放空间,使用  nvidia-smi  命令查看那个程序占用GPU过高,使用   kill -9 PID结束掉即可

    (3)由于caffe版本的问题,层的定义 有layer和layers,使用layer定义,type需要加双引号,是字符格式;使用layers定义,type不用加双引号,变为全大写字母

        
  • 相关阅读:
    js string to int
    有的事情是无可奈何的,有的事情是能够改变的……
    拼接字符串去掉最后多余的串,JSON的遍历
    git入门
    js的闭包
    nodejs系列(二)REPL交互解释 事件循环
    nodejs系列(一)安装和介绍
    学习mongo系列(十一)关系
    学习mongo系列(十)MongoDB 备份(mongodump)与恢复(mongorerstore) 监控(mongostat mongotop)
    学习mongo系列(九)索引,聚合,复制(副本集),分片
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/10097392.html
Copyright © 2011-2022 走看看