zoukankan      html  css  js  c++  java
  • Caffe实现概述

    Caffe实现概述

    目录

    一、caffe配置文件介绍

    二、标准层的定义

     三、网络微调技巧

    四、Linux脚本使用及LMDB文件生成

    五、带你设计一个Caffe网络,用于分类任务

    一、caffe配置文件介绍

     

     

     

       

     

     

     二、标准层的定义

     

     

     三、网络微调技巧

     

     

     其中,在各种学习率策略(lr_policy)中,multistep 最为常用。

     

     

     四、Linux脚本使用及LMDB文件生成

     

     

     五、带你设计一个Caffe网络,用于分类任务

     

     

     下面:

    使用pycaffe生成solver配置

    使用pycaffe生成caffe测试网络和训练网络

     

    数据集下载

    # demoCaffe

    数据集下载(CIFAR、MNIST):
    百度云盘:

    链接: https://pan.baidu.com/s/1bHFQUz7Q6BMBZv25AhsXKQ 密码: dva9
    链接: https://pan.baidu.com/s/1rPRjf2hanlYYjBQQDmIjNQ 密码: 5nhv

    1. lmdb数据制作:

    手动实现: https://blog.csdn.net/yx2017/article/details/72953537   

                   https://www.jianshu.com/p/9d7ed35960cb

    代码实现:https://www.cnblogs.com/leemo-o/p/4990021.html

                     https://www.jianshu.com/p/ef84715e0fdc

    以下仅供对比阅读:

    demo_lmdb.py:  生成lmdb格式数据

     

    import lmdb
    import numpy as np
    import cv2
    import caffe
    from caffe.proto import caffe_pb2


    def write(lmdb_file='lmdb_data', batch_size=256):
        """Write ``batch_size`` dummy Datum records into an LMDB database.

        Each record is a 3x64x64 uint8 array of ones whose label is its loop
        index. Keys are the index zero-padded to 8 digits, so LMDB's
        lexicographic key order matches insertion order.

        Args:
            lmdb_file: path of the LMDB environment to create/open.
            batch_size: number of records to write.
        """
        # map_size caps how large the database may grow (1e12 bytes here).
        lmdb_env = lmdb.open(lmdb_file, map_size=int(1e12))
        try:
            # The write transaction commits when the block exits normally and
            # is aborted if an exception escapes it.
            with lmdb_env.begin(write=True) as lmdb_txn:
                for x in range(batch_size):
                    data = np.ones((3, 64, 64), np.uint8)
                    label = x
                    datum = caffe.io.array_to_datum(data, label)
                    # py-lmdb requires bytes keys on Python 3; on Python 2
                    # str.encode('ascii') simply returns the same str.
                    keystr = "{:0>8d}".format(x).encode("ascii")
                    lmdb_txn.put(keystr, datum.SerializeToString())
        finally:
            lmdb_env.close()


    def read(lmdb_file='lmdb_data'):
        """Print the label and array of every Datum stored in an LMDB database.

        Args:
            lmdb_file: path of the LMDB environment (e.g. the one written by
                ``write``).
        """
        lmdb_env = lmdb.open(lmdb_file)
        try:
            with lmdb_env.begin() as lmdb_txn:
                datum = caffe_pb2.Datum()
                # Iterate over every (key, value) pair in the database.
                for key, value in lmdb_txn.cursor():
                    # ParseFromString clears the message before parsing, so
                    # reusing one Datum object across records is safe.
                    datum.ParseFromString(value)
                    label = datum.label
                    data = caffe.io.datum_to_array(datum)
                    print(label)
                    print(data)
        finally:
            lmdb_env.close()


    if __name__ == '__main__':
        write()
        read()

    demo_create_solver.py:  生成solver配置文件

    import os

    from caffe.proto import caffe_pb2

    # Build a SolverParameter message and dump it as a text prototxt.
    s = caffe_pb2.SolverParameter()

    # Train / test network definitions.
    s.train_net = "train.prototxt"
    s.test_net.append("test.prototxt")

    # Run the test net every test_interval training iterations,
    # performing test_iter test batches each time.
    s.test_interval = 100
    s.test_iter.append(10)

    s.max_iter = 1000
    s.base_lr = 0.1
    s.weight_decay = 5e-4

    # The "step" policy uses lr = base_lr * gamma ** floor(iter / stepsize).
    # Caffe's SGD solver aborts (CHECK_GT(stepsize, 0)) when stepsize is left
    # unset, and an unset gamma is 0, which would drop the rate to zero after
    # the first step, so both are set explicitly.
    s.lr_policy = "step"
    s.stepsize = 500
    s.gamma = 0.1

    s.display = 10
    # NOTE: with max_iter = 1000 this writes a snapshot every 10 iterations
    # (100 snapshots); raise it if that is too many files.
    s.snapshot = 10
    s.snapshot_prefix = "model"

    s.type = "SGD"
    s.solver_mode = caffe_pb2.SolverParameter.GPU

    # Make sure the output directory exists before writing the solver file.
    if not os.path.isdir("net"):
        os.makedirs("net")
    with open("net/s.prototxt", "w") as f:
        f.write(str(s))

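    结果如下(注:下面的输出来自另一次运行,其中的路径以及 test_iter、display、max_iter、snapshot 等取值与上面脚本中的不同;base_lr、weight_decay 显示为 0.10000000149、0.000500000023749,是因为 SolverParameter 中这两个字段是单精度 float 类型。另外,lr_policy 为 "step" 时还必须设置 stepsize(以及 gamma),否则 Caffe 训练时会报错):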

    train_net: "/home/kuan/PycharmProjects/demo_cnn_net/net/train.prototxt"
    test_net: "/home/kuan/PycharmProjects/demo_cnn_net/net/test.prototxt"
    test_iter: 1000
    test_interval: 100
    base_lr: 0.10000000149
    display: 100
    max_iter: 100000
    lr_policy: "step"
    weight_decay: 0.000500000023749
    snapshot: 100
    snapshot_prefix: "/home/kuan/PycharmProjects/demo_cnn_net/cnn_model/mnist/lenet/"
    solver_mode: GPU
    type: "SGD"

    demo_creat_net.py:    创建网络

    import os

    import caffe


    def create_net(source="data.lmdb", batch_size=32,
                   output_path="net/tt.prototxt"):
        """Build a small CNN classifier with NetSpec and save it as a prototxt.

        Architecture: LMDB Data -> conv1(20, 5x5) -> ReLU -> max-pool(3, s2)
        -> conv2(32, 3x3, pad 1) -> ReLU -> max-pool(3, s2) -> fc3(1024)
        -> ReLU -> Dropout(0.5) -> fc4(10) -> SoftmaxWithLoss.

        Args:
            source: path of the LMDB database read by the Data layer.
            batch_size: number of samples per batch.
            output_path: file the generated prototxt is written to; its parent
                directory is created if it does not exist.
        """
        net = caffe.NetSpec()

        # ntop=2: the Data layer produces two tops, the images and the labels.
        # NOTE: crop_size=40 requires the stored images to be at least 40x40.
        net.data, net.label = caffe.layers.Data(
            source=source,
            backend=caffe.params.Data.LMDB,
            batch_size=batch_size,
            ntop=2,
            transform_param=dict(crop_size=40, mirror=True),
        )

        # NOTE(review): a "xavier" bias filler is unusual (a constant 0 filler
        # is the common choice); kept to match the original tutorial output.
        net.conv1 = caffe.layers.Convolution(
            net.data, num_output=20, kernel_size=5,
            weight_filler={"type": "xavier"},
            bias_filler={"type": "xavier"},
        )
        net.relu1 = caffe.layers.ReLU(net.conv1, in_place=True)
        net.pool1 = caffe.layers.Pooling(
            net.relu1, pool=caffe.params.Pooling.MAX, kernel_size=3, stride=2)

        net.conv2 = caffe.layers.Convolution(
            net.pool1, num_output=32, kernel_size=3, pad=1,
            weight_filler={"type": "xavier"},
            bias_filler={"type": "xavier"},
        )
        net.relu2 = caffe.layers.ReLU(net.conv2, in_place=True)
        net.pool2 = caffe.layers.Pooling(
            net.relu2, pool=caffe.params.Pooling.MAX, kernel_size=3, stride=2)

        # Fully connected head.
        net.fc3 = caffe.layers.InnerProduct(
            net.pool2, num_output=1024, weight_filler=dict(type='xavier'))
        net.relu3 = caffe.layers.ReLU(net.fc3, in_place=True)
        net.drop = caffe.layers.Dropout(
            net.relu3, dropout_param=dict(dropout_ratio=0.5))
        net.fc4 = caffe.layers.InnerProduct(
            net.drop, num_output=10, weight_filler=dict(type='xavier'))

        net.loss = caffe.layers.SoftmaxWithLoss(net.fc4, net.label)

        out_dir = os.path.dirname(output_path)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        with open(output_path, 'w') as f:
            f.write(str(net.to_proto()))


    if __name__ == '__main__':
        create_net()

    生成结果如下(注:下面的输出来自另一次运行,data 层的 source 路径与脚本中的 "data.lmdb" 不同):

    layer {
      name: "data"
      type: "Data"
      top: "data"
      top: "label"
      transform_param {
        mirror: true
        crop_size: 40
      }
      data_param {
        source: "/home/kuan/PycharmProjects/demo_cnn_net/lmdb_data"
        batch_size: 32
        backend: LMDB
      }
    }
    layer {
      name: "conv1"
      type: "Convolution"
      bottom: "data"
      top: "conv1"
      convolution_param {
        num_output: 20
        kernel_size: 5
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "xavier"
        }
      }
    }
    layer {
      name: "relu1"
      type: "ReLU"
      bottom: "conv1"
      top: "conv1"
    }
    layer {
      name: "pool1"
      type: "Pooling"
      bottom: "conv1"
      top: "pool1"
      pooling_param {
        pool: MAX
        kernel_size: 3
        stride: 2
      }
    }
    layer {
      name: "conv2"
      type: "Convolution"
      bottom: "pool1"
      top: "conv2"
      convolution_param {
        num_output: 32
        pad: 1
        kernel_size: 3
        weight_filler {
          type: "xavier"
        }
        bias_filler {
          type: "xavier"
        }
      }
    }
    layer {
      name: "relu2"
      type: "ReLU"
      bottom: "conv2"
      top: "conv2"
    }
    layer {
      name: "pool2"
      type: "Pooling"
      bottom: "conv2"
      top: "pool2"
      pooling_param {
        pool: MAX
        kernel_size: 3
        stride: 2
      }
    }
    layer {
      name: "fc3"
      type: "InnerProduct"
      bottom: "pool2"
      top: "fc3"
      inner_product_param {
        num_output: 1024
        weight_filler {
          type: "xavier"
        }
      }
    }
    layer {
      name: "relu3"
      type: "ReLU"
      bottom: "fc3"
      top: "fc3"
    }
    layer {
      name: "drop"
      type: "Dropout"
      bottom: "fc3"
      top: "drop"
      dropout_param {
        dropout_ratio: 0.5
      }
    }
    layer {
      name: "fc4"
      type: "InnerProduct"
      bottom: "drop"
      top: "fc4"
      inner_product_param {
        num_output: 10
        weight_filler {
          type: "xavier"
        }
      }
    }
    layer {
      name: "loss"
      type: "SoftmaxWithLoss"
      bottom: "fc4"
      bottom: "label"
      top: "loss"
    }

    demo_train.py训练网络:

    import sys
    # Make the pycaffe package under this checkout importable.
    sys.path.append('/home/kuan/AM-softmax_caffe/python')
    import caffe

    SOLVER_PATH = "/home/kuan/PycharmProjects/demo_cnn_net/cnn_net/alexnet/solver.prototxt"

    # pycaffe does not act on the solver_mode field of the solver file (only
    # the `caffe train` command-line tool does), so select GPU 0 explicitly
    # before the solver is constructed; otherwise training runs on the CPU.
    caffe.set_device(0)
    caffe.set_mode_gpu()

    # Build an SGD solver from the solver prototxt and run it to max_iter.
    solver = caffe.SGDSolver(SOLVER_PATH)
    solver.solve()

    demo_test.py:测试网络

    import sys
    # Make the pycaffe package under this checkout importable.
    sys.path.append('/home/kuan/AM-softmax_caffe/python')
    import caffe
    import numpy as np

    # Inference needs the deploy network definition plus trained weights.
    deploy = "/home/kuan/PycharmProjects/demo_cnn_net/cnn_net/alexnet/deploy.prototxt"
    model = "/home/kuan/PycharmProjects/demo_cnn_net/cnn_model/cifar/alexnet/alexnet_iter_110.caffemodel"

    net = caffe.Net(deploy, model, caffe.TEST)

    # Feed a dummy image of ones. Assumes the deploy input blob is
    # (N, 3, 32, 32), so numpy broadcasting fills every batch item — confirm
    # against deploy.prototxt.
    input_blob = net.blobs["data"]
    input_blob.data[...] = np.ones((3, 32, 32), np.uint8)

    net.forward()

    # Output of the "prob" blob for the first item in the batch.
    prob = net.blobs["prob"].data[0]
    print(prob)

     

     

     

     

     

     

    人工智能芯片与自动驾驶
  • 相关阅读:
    Linux之HugePages快速配置
    Bug 5323844-IMPDP无法导入远程数据库同义词的同义词
    Oracle之SQL优化专题02-稳固SQL执行计划的方法
    使用COE脚本绑定SQL Profile
    AIX挂载NFS写入效率低效解决
    javaWeb项目配置自定义404错误页
    eclipse Referenced file contains errors (http://www.springframework.org/schema/context/spring-context-3.0.xsd)
    tomcat Invalid character found in the request target. The valid characters are defined in RFC 7230 and RFC 3986
    orcle not like不建议使用(not like所踩过的坑!)
    eclipse debug调试 class文件 Source not found.
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/14391739.html
Copyright © 2011-2022 走看看