zoukankan      html  css  js  c++  java
  • [caffe学习笔记][05][使用LeNet训练mnist数据集]

    说明:

    使用LeNet网络模型训练MNIST数据集,分别采用直接读入图片和通过LMDB读入数据两种方式进行训练。训练结果准确率可达到99.2%。


    步骤:

    1.通过直接读入图片方式

    vim train_mnist.py

    # -*- coding: utf-8 -*-
    """
    yuandanfei Editor
    输入ImageData格式
    """
    import caffe
    from caffe import layers as L, params as P, proto, to_proto
    
    # File locations: every path below is built by appending to `root`.
    root = '/home/yuandanfei/work/caffe/mnist2/' # project root directory
    train_list = root + 'mnist/train/train.txt'  # list file of training images
    test_list = root + 'mnist/test/test.txt'     # list file of test images
    train_proto = root + 'train.prototxt'        # training network definition (written by write_Lenet)
    test_proto = root + 'test.prototxt'          # test network definition (written by write_Lenet)
    solver_proto = root + 'solver.prototxt'      # solver configuration (written by write_solver)
    
    #Lenet网络模型
    def Lenet(img_list, batch_size, include_acc=False):
        """Assemble the LeNet graph on top of an ImageData layer.

        img_list    -- path of the image list file, passed to the layer as `source`
        batch_size  -- number of images per batch
        include_acc -- when true, an Accuracy layer is emitted next to the loss
                       (this is how the test network is produced)

        Returns whatever `to_proto` builds from the loss (and accuracy) tops.
        The shape notes below follow the original author's comments and assume
        28x28 inputs loaded as 3-channel images -- TODO confirm against the data.
        """
        def conv5x5(bottom, channels):
            # 5x5 kernel, stride 1, no padding:
            # out = (in + 2*pad - kernel_size) / stride + 1
            return L.Convolution(bottom, kernel_size=5, stride=1, pad=0,
                                 num_output=channels, weight_filler={'type': 'xavier'})

        def max_pool2x2(bottom):
            # 2x2 max pooling with stride 2; the channel count is unchanged
            return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=2, stride=2)

        def dense(bottom, units):
            # Fully connected layer with Xavier-initialised weights
            return L.InnerProduct(bottom, num_output=units, weight_filler={'type': 'xavier'})

        # Input layer: two tops (images, labels); pixel values are scaled by 1/256
        images, labels = L.ImageData(source=img_list, batch_size=batch_size, ntop=2,
                                     root_folder=root,
                                     transform_param={'scale': 0.00390625})

        stage1 = max_pool2x2(conv5x5(images, 20))   # conv: n*20*24*24 -> pool: n*20*12*12
        stage2 = max_pool2x2(conv5x5(stage1, 50))   # conv: n*50*8*8   -> pool: n*50*4*4

        hidden = L.ReLU(dense(stage2, 500), in_place=True)  # 500 units, ReLU applied in place
        scores = dense(hidden, 10)                           # one score per digit class

        loss = L.SoftmaxWithLoss(scores, labels)
        if not include_acc:
            # Training network: the loss is the only top
            return to_proto(loss)

        # Test network: report accuracy alongside the loss
        accuracy = L.Accuracy(scores, labels)
        return to_proto(loss, accuracy)
    
    #写入网络模型文件
    def write_Lenet():
        """Write the training and the test network definitions to disk.

        The training net uses batches of 64 and no accuracy layer; the test net
        uses batches of 100 and includes the accuracy layer.  Each definition is
        the string form of what `Lenet` returns.
        """
        # (output file, image list, batch size, include accuracy layer?)
        jobs = (
            (train_proto, train_list, 64, False),
            (test_proto, test_list, 100, True),
        )
        for out_path, image_list, batch, with_accuracy in jobs:
            with open(out_path, 'w') as out:
                out.write(str(Lenet(image_list, batch_size=batch, include_acc=with_accuracy)))
    
    #写入参数配置文件
    def write_solver(solver_file, train_net, test_net):
        """Fill in a SolverParameter message and write its text form to `solver_file`.

        solver_file -- path of the solver configuration file to create
        train_net   -- path of the training network definition
        test_net    -- path of the test network definition

        The iteration counts follow the original author's arithmetic, which
        assumes 60000 training and 10000 test images.
        """
        settings = proto.caffe_pb2.SolverParameter(
            train_net=train_net,
            test_net=[test_net],

            # Schedule
            test_interval=938,      # test every 938 iterations (60000/64, rounded up)
            test_iter=[100],        # 100 test batches per test pass (10000/100)
            max_iter=9380,          # stop after 9380 iterations (938*10)

            # Learning rate: base_lr * gamma^floor(iter/stepsize)
            lr_policy='step',
            base_lr=0.01,
            gamma=0.1,
            stepsize=3127,          # roughly 9380/3
            momentum=0.9,
            weight_decay=5e-4,

            # Logging and checkpoints
            display=938,            # print progress every 938 iterations (9380/10)
            snapshot=4690,          # save weights every 4690 iterations (9380/2)
            snapshot_prefix=root + 'mnist/lenet',

            # Optimiser and compute mode
            type='SGD',
            solver_mode=proto.caffe_pb2.SolverParameter.GPU,
        )
        with open(solver_file, 'w') as out:
            out.write(str(settings))
    
    def train_model():
        """Train the network described by `solver_proto` on GPU 0."""
        # Pick device 0, then switch pycaffe to GPU mode (same order as before).
        caffe.set_device(0)
        caffe.set_mode_gpu()
        # Build the SGD solver from the solver file and run it.
        caffe.SGDSolver(solver_proto).solve()
        
    if __name__ == '__main__':
        # Order matters: the solver file refers to the network files by path,
        # and training reads the solver file.
        write_Lenet()
        write_solver(solver_proto, train_proto, test_proto)
        train_model()


    2.通过lmdb读入图片方式

    vim train_mnist.py

    # -*- coding: utf-8 -*-
    """
    yuandanfei Editor
    输入LMDB格式
    """
    import caffe
    from caffe import layers as L, params as P, proto, to_proto
    
    # File locations: every path below is built by appending to `root`.
    root = '/home/yuandanfei/work/caffe/mnist2/' # project root directory
    train_lmdb = root + 'mnist/train_lmdb'       # LMDB with the training data
    test_lmdb = root + 'mnist/test_lmdb'         # LMDB with the test data
    train_proto = root + 'train.prototxt'        # training network definition (written by write_Lenet)
    test_proto = root + 'test.prototxt'          # test network definition (written by write_Lenet)
    solver_proto = root + 'solver.prototxt'      # solver configuration (written by write_solver)
    
    #Lenet网络模型
    def Lenet(lmdb, batch_size, include_acc=False):
        """Assemble the LeNet graph on top of an LMDB-backed Data layer.

        lmdb        -- path of the LMDB database, passed to the layer as `source`
        batch_size  -- number of samples per batch
        include_acc -- when true, an Accuracy layer is emitted next to the loss
                       (this is how the test network is produced)

        Returns whatever `to_proto` builds from the loss (and accuracy) tops.
        The shape notes below follow the original author's comments and assume
        28x28 inputs; the channel count depends on how the LMDB was created
        -- TODO confirm against the data.
        """
        def conv5x5(bottom, channels):
            # 5x5 kernel, stride 1, no padding:
            # out = (in + 2*pad - kernel_size) / stride + 1
            return L.Convolution(bottom, kernel_size=5, stride=1, pad=0,
                                 num_output=channels, weight_filler={'type': 'xavier'})

        def max_pool2x2(bottom):
            # 2x2 max pooling with stride 2; the channel count is unchanged
            return L.Pooling(bottom, pool=P.Pooling.MAX, kernel_size=2, stride=2)

        def dense(bottom, units):
            # Fully connected layer with Xavier-initialised weights
            return L.InnerProduct(bottom, num_output=units, weight_filler={'type': 'xavier'})

        # Input layer: two tops (samples, labels); pixel values are scaled by 1/256
        samples, labels = L.Data(source=lmdb, backend=P.Data.LMDB, batch_size=batch_size,
                                 ntop=2, transform_param={'scale': 0.00390625})

        stage1 = max_pool2x2(conv5x5(samples, 20))  # conv: n*20*24*24 -> pool: n*20*12*12
        stage2 = max_pool2x2(conv5x5(stage1, 50))   # conv: n*50*8*8   -> pool: n*50*4*4

        hidden = L.ReLU(dense(stage2, 500), in_place=True)  # 500 units, ReLU applied in place
        scores = dense(hidden, 10)                           # one score per digit class

        loss = L.SoftmaxWithLoss(scores, labels)
        if not include_acc:
            # Training network: the loss is the only top
            return to_proto(loss)

        # Test network: report accuracy alongside the loss
        accuracy = L.Accuracy(scores, labels)
        return to_proto(loss, accuracy)
    
    #写入网络模型文件
    def write_Lenet():
        """Write the training and the test network definitions to disk.

        The training net uses batches of 64 and no accuracy layer; the test net
        uses batches of 100 and includes the accuracy layer.  Each definition is
        the string form of what `Lenet` returns.
        """
        # (output file, LMDB path, batch size, include accuracy layer?)
        jobs = (
            (train_proto, train_lmdb, 64, False),
            (test_proto, test_lmdb, 100, True),
        )
        for out_path, database, batch, with_accuracy in jobs:
            with open(out_path, 'w') as out:
                out.write(str(Lenet(database, batch_size=batch, include_acc=with_accuracy)))
    
    #写入参数配置文件
    def write_solver(solver_file, train_net, test_net):
        """Fill in a SolverParameter message and write its text form to `solver_file`.

        solver_file -- path of the solver configuration file to create
        train_net   -- path of the training network definition
        test_net    -- path of the test network definition

        The iteration counts follow the original author's arithmetic, which
        assumes 60000 training and 10000 test samples.
        """
        settings = proto.caffe_pb2.SolverParameter(
            train_net=train_net,
            test_net=[test_net],

            # Schedule
            test_interval=938,      # test every 938 iterations (60000/64, rounded up)
            test_iter=[100],        # 100 test batches per test pass (10000/100)
            max_iter=9380,          # stop after 9380 iterations (938*10)

            # Learning rate: base_lr * gamma^floor(iter/stepsize)
            lr_policy='step',
            base_lr=0.01,
            gamma=0.1,
            stepsize=3127,          # roughly 9380/3
            momentum=0.9,
            weight_decay=5e-4,

            # Logging and checkpoints
            display=938,            # print progress every 938 iterations (9380/10)
            snapshot=4690,          # save weights every 4690 iterations (9380/2)
            snapshot_prefix=root + 'mnist/lenet',

            # Optimiser and compute mode
            type='SGD',
            solver_mode=proto.caffe_pb2.SolverParameter.GPU,
        )
        with open(solver_file, 'w') as out:
            out.write(str(settings))
    
    def train_model():
        """Train the network described by `solver_proto` on GPU 0."""
        # Pick device 0, then switch pycaffe to GPU mode (same order as before).
        caffe.set_device(0)
        caffe.set_mode_gpu()
        # Build the SGD solver from the solver file and run it.
        caffe.SGDSolver(solver_proto).solve()
        
    if __name__ == '__main__':
        # Order matters: the solver file refers to the network files by path,
        # and training reads the solver file.
        write_Lenet()
        write_solver(solver_proto, train_proto, test_proto)
        train_model()


    参考资料:

    https://www.cnblogs.com/denny402/p/5684431.html

  • 相关阅读:
    C# DateTimePicker控件详解
    python2.7虚拟环境virtualenv安装及使用
    Python2.7 安装numpy报错解决方法
    关于C语言中递归的一点点小问题
    Drozer--AndroidApp安全评估工具
    Android--native层so文件调试
    New Blog
    小旭讲解 LeetCode 53. Maximum Subarray 动态规划 分治策略
    2017年度回忆与总结 – 心态
    基于文本图形(ncurses)的文本搜索工具 ncgrep
  • 原文地址:https://www.cnblogs.com/d442130165/p/12765984.html
Copyright © 2011-2022 走看看