zoukankan      html  css  js  c++  java
  • 使用caffe提供的python接口训练mnist例子

    1 首先肯定是安装caffe,并且编译python接口,如果是在windows上,最好把编译出来的python文件夹的caffe文件夹拷贝到anaconda文件夹下面去,这样就有代码自动提示功能,如下:

    image

    本文中使用的ide为anaconda安装中自带的spyder,如图所示,将根目录设置为caffe的根目录。

    image

    import caffe
    caffe.set_mode_cpu()
    solver = caffe.SGDSolver('examples/mnist/lenet_solver.prototxt')
    solver.solve()

    以上为一次全部迭代,如果想自己控制,可使用如下代码:

    import caffe
    caffe.set_mode_cpu()
    solver = caffe.SGDSolver('examples/mnist/lenet_solver.prototxt')
    #solver.solve()
    
    iter = solver.iter
    while iter<10000:
        solver.step(1)
        iter = solver.iter
        input_data = solver.net.blobs['data'].data  
        loss = solver.net.blobs['loss'].data
        accuracy = solver.test_nets[0].blobs['accuracy'].data
        print 'iter:', iter, 'loss:', loss,'accuracy:',accuracy
    import caffe
    import matplotlib.pyplot as plt     
    import numpy as np
    
    def vis_square(data):
        """Take an array of shape (n, height, width) or (n, height, width, 3)
           and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""
        
        # normalize data for display
        data = (data - data.min()) / (data.max() - data.min())
        
        # force the number of filters to be square
        n = int(np.ceil(np.sqrt(data.shape[0])))
        padding = (((0, n ** 2 - data.shape[0]),
                   (0, 1), (0, 1))                 # add some space between filters
                   + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
        data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)
        
        # tile the filters into an image
        data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
        data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
        if data.shape[2] == 1:
            data = data[:,:,0]
        plt.imshow(data); plt.axis('off')
    
    if __name__ == '__main__':
        caffe.set_mode_cpu()
        solver = caffe.SGDSolver('examples/mnist/lenet_solver.prototxt')
        solver.step(1)
        input_data = solver.net.blobs['data'].data  
        plt.figure(0)
        vis_square(input_data.transpose(0, 2, 3, 1))  
        filters = solver.net.params['conv1'][0].data
        plt.figure(1)
        vis_square(filters.transpose(0, 2, 3, 1))

        特征图:
    image   
        权值图

    image

  • 相关阅读:
    Oracle表级约束和列级约束
    什么是SSL证书服务?
    什么是阿里云SCDN
    什么是阿里云CDN
    什么是弹性公网IP?
    什么是云解析DNS?
    什么是DataV数据可视化
    什么是大数据计算服务MaxCompute
    什么是文件存储NAS
    什么是云存储网关
  • 原文地址:https://www.cnblogs.com/linyuanzhou/p/6012231.html
Copyright © 2011-2022 走看看