  • [Caffe study notes][03][Generating configuration files]

    Description:

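    Caffe describes a network's structure in a prototxt configuration file, and generating these files through its Python interface (pycaffe) is fairly simple. Here we generate train.prototxt and test.prototxt, used for the training phase and the validation phase respectively.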
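    To show the kind of text the Python interface saves us from writing by hand, here is a small stdlib-only sketch that emits the protobuf text format used by prototxt files. The helper to_prototxt is hypothetical and is not part of pycaffe (pycaffe produces this text through to_proto); it only handles strings, numbers, booleans, nested messages and repeated fields, not unquoted enum values such as pool: MAX, and the layer and blob names in the example are illustrative.

```python
# Hypothetical helper for illustration only; pycaffe's to_proto() does the
# real work. It mimics protobuf text format, which prototxt files use.


def to_prototxt(msg, indent=0):
    """Serialize a nested dict to protobuf text format.

    dict  -> nested message `key { ... }`
    list  -> repeated field (one line per element)
    bool  -> true / false
    str   -> quoted string
    other -> written as-is (ints, floats)
    Enum values such as `pool: MAX` (unquoted) are not supported.
    """
    pad = "  " * indent
    lines = []
    for key, value in msg.items():
        for v in value if isinstance(value, list) else [value]:
            if isinstance(v, dict):
                lines.append(f"{pad}{key} {{")
                lines.append(to_prototxt(v, indent + 1))
                lines.append(f"{pad}}}")
            elif isinstance(v, bool):  # check before int: bool is an int subclass
                lines.append(f"{pad}{key}: {'true' if v else 'false'}")
            elif isinstance(v, str):
                lines.append(f'{pad}{key}: "{v}"')
            else:
                lines.append(f"{pad}{key}: {v}")
    return "\n".join(lines)


# Illustrative names; pycaffe may generate different layer/blob names.
conv1 = {
    "name": "conv1",
    "type": "Convolution",
    "bottom": "data",
    "top": "conv1",
    "convolution_param": {
        "num_output": 16,
        "pad": 2,
        "kernel_size": 5,
        "stride": 1,
        "weight_filler": {"type": "xavier"},
    },
}

print(to_prototxt({"layer": conv1}))
```

    Running it prints a layer { ... } block with a nested convolution_param and weight_filler, similar in structure to the layer entries in the generated files. Writing such blocks by hand for every layer is tedious and error-prone, which is why generating them from Python is convenient.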


    Steps:

    1. Generate the configuration files

    touch create_train_val_prototxt.py

    spyder create_train_val_prototxt.py

    # -*- coding: utf-8 -*-
    """
    yuandanfei Editor

    Generate train.prototxt and test.prototxt for MNIST with pycaffe.
    """

    from caffe import layers as L, params as P, to_proto

    path = '/home/yuandanfei/work/caffe/mnist/'  # root path
    train_lmdb = path + 'train_lmdb'             # training set (LMDB)
    test_lmdb = path + 'test_lmdb'               # test set (LMDB)
    mean_file = path + 'mean.binaryproto'        # mean image
    train_proto = path + 'train.prototxt'        # output: training network
    test_proto = path + 'test.prototxt'          # output: test network


    def create_net(lmdb, batch_size, include_acc=False):
        # Input layer: ntop=2 yields the image blob and the label blob.
        # crop_size=28 equals the MNIST image size, so cropping is a no-op;
        # mirror=True randomly flips images horizontally.
        data, label = L.Data(source=lmdb, backend=P.Data.LMDB, batch_size=batch_size, ntop=2,
                             transform_param=dict(crop_size=28, mean_file=mean_file, mirror=True))
        # conv1: blob shape n*c*h*w; c1 = num_output;
        #   h1 = floor((h0 + 2*pad - kernel_size) / stride) + 1 (same for w).
        #   With stride=1 and pad=(kernel_size-1)/2 the size is kept: 28x28 -> 28x28.
        conv1 = L.Convolution(data, kernel_size=5, stride=1, pad=2, num_output=16, weight_filler=dict(type='xavier'))
        # relu1 (computed in place to save memory)
        relu1 = L.ReLU(conv1, in_place=True)
        # pool1: c1 = c0; h1 = ceil((h0 + 2*pad - kernel_size) / stride) + 1,
        #   since Caffe rounds up for pooling: 28x28 -> 14x14.
        pool1 = L.Pooling(relu1, pool=P.Pooling.MAX, kernel_size=3, stride=2)
        # conv2: 14x14 -> 14x14
        conv2 = L.Convolution(pool1, kernel_size=3, stride=1, pad=1, num_output=32, weight_filler=dict(type='xavier'))
        # relu2
        relu2 = L.ReLU(conv2, in_place=True)
        # pool2: 14x14 -> 7x7
        pool2 = L.Pooling(relu2, pool=P.Pooling.MAX, kernel_size=3, stride=2)
        # fc3: fully connected layer with 1024 outputs
        fc3 = L.InnerProduct(pool2, num_output=1024, weight_filler=dict(type='xavier'))
        # relu3
        relu3 = L.ReLU(fc3, in_place=True)
        # drop3 (default dropout_ratio is 0.5)
        drop3 = L.Dropout(relu3, in_place=True)
        # fc4: one output per digit class
        fc4 = L.InnerProduct(drop3, num_output=10, weight_filler=dict(type='xavier'))
        # softmax + multinomial logistic loss
        loss = L.SoftmaxWithLoss(fc4, label)
        # The accuracy layer is only added to the test network.
        if include_acc:  # test
            acc = L.Accuracy(fc4, label)
            return to_proto(loss, acc)
        else:            # train
            return to_proto(loss)


    def write_net():
        # write train.prototxt
        with open(train_proto, 'w') as f:
            f.write(str(create_net(train_lmdb, batch_size=64)))

        # write test.prototxt
        with open(test_proto, 'w') as f:
            f.write(str(create_net(test_lmdb, batch_size=32, include_acc=True)))


    if __name__ == '__main__':
        write_net()
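    The output-size formulas in the comments can be checked without Caffe installed. The sketch below assumes a 1x28x28 MNIST input and copies the layer parameters from create_net(); conv_out, pool_out and trace are illustrative helpers, not Caffe functions. It shows why the rounding matters: with floor rounding for pooling, pool1 would give 13x13 and pool2 6x6, so fc3 would see 32x6x6 inputs instead of 32x7x7 = 1568.

```python
import math


def conv_out(size, kernel_size, stride=1, pad=0):
    """Convolution output size: Caffe rounds down (floor)."""
    return (size + 2 * pad - kernel_size) // stride + 1


def pool_out(size, kernel_size, stride=1, pad=0):
    """Pooling output size: Caffe rounds up (ceil).

    Caffe's extra adjustment for padded pooling is not modelled here.
    """
    return int(math.ceil((size + 2 * pad - kernel_size) / stride)) + 1


# (name, kind, num_output, kernel_size, stride, pad), copied from create_net()
LAYERS = [
    ("conv1", "conv", 16, 5, 1, 2),
    ("pool1", "pool", None, 3, 2, 0),
    ("conv2", "conv", 32, 3, 1, 1),
    ("pool2", "pool", None, 3, 2, 0),
]


def trace(channels=1, size=28):
    """Return (name, c, h, w) after each layer, assuming square inputs."""
    shapes = []
    for name, kind, num_output, k, s, p in LAYERS:
        if kind == "conv":
            channels, size = num_output, conv_out(size, k, s, p)
        else:  # pooling keeps the channel count
            size = pool_out(size, k, s, p)
        shapes.append((name, channels, size, size))
    return shapes


shapes = trace()
for name, c, h, w in shapes:
    print(f"{name}: {c}x{h}x{w}")
# conv1: 16x28x28
# pool1: 16x14x14
# conv2: 32x14x14
# pool2: 32x7x7
_, c, h, w = shapes[-1]
print("fc3 input features:", c * h * w)  # fc3 input features: 1568
```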


    2. Draw the network model

    touch draw_net.sh

    vim draw_net.sh

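    #!/usr/bin/env bash

    DATA=train   # set to test to draw test.prototxt instead
    DRAW_NET=/home/yuandanfei/caffe/python/draw_net.py

    # --rankdir=BT lays the graph out from bottom to top
    python "$DRAW_NET" "../out/$DATA.prototxt" "../out/$DATA.png" --rankdir=BT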
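    The shell script draws one prototxt per run, selected by DATA. A Python version can render both networks in one go. This is a sketch: the draw_net.py path and the ../out directory are copied from the script above, draw_net_cmd and draw_all are hypothetical helpers, and draw_net.py itself needs pydot and Graphviz installed.

```python
import subprocess
import sys

# Paths copied from draw_net.sh; adjust them to your own layout.
DRAW_NET = "/home/yuandanfei/caffe/python/draw_net.py"
OUT_DIR = "../out"


def draw_net_cmd(phase, rankdir="BT"):
    """Hypothetical helper: build the draw_net.py command for one prototxt."""
    return [
        sys.executable,                 # same interpreter as this script
        DRAW_NET,
        f"{OUT_DIR}/{phase}.prototxt",  # input network definition
        f"{OUT_DIR}/{phase}.png",       # output image
        f"--rankdir={rankdir}",         # BT: draw the graph bottom to top
    ]


def draw_all(phases=("train", "test")):
    """Render every phase; raises CalledProcessError if draw_net.py fails."""
    for phase in phases:
        subprocess.run(draw_net_cmd(phase), check=True)
```

    Calling draw_all() should produce ../out/train.png and ../out/test.png. Using sys.executable instead of a bare python runs draw_net.py under the same interpreter as the helper, which avoids picking up a different Python from PATH.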

    References:

    https://www.cnblogs.com/denny402/p/5679037.html

    https://www.cnblogs.com/denny402/p/5106764.html

  • Original post: https://www.cnblogs.com/d442130165/p/12722308.html