zoukankan      html  css  js  c++  java
  • caffe的python接口学习(8):caffemodel中的参数及特征的抽取

    如果用公式  y=f(wx+b)

    来表示整个运算过程的话,那么w和b就是我们需要训练的东西,w称为权值,在cnn中也可以叫做卷积核(filter),b是偏置项。f是激活函数,有sigmoid、relu等。x就是输入的数据。

    数据训练完成后,保存的caffemodel里面,实际上就是各层的w和b值。

    我们运行代码:

    deploy=root + 'mnist/deploy.prototxt'    #deploy文件
    caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #训练好的 caffemodel
    net = caffe.Net(net_file,caffe_model,caffe.TEST)   #加载model和network

    就把所有的参数和数据都加载到一个net变量里面了,但是net是一个很复杂的object, 想直接显示出来看是不行的。其中:

    net.params: 保存各层的参数值(w和b)

    net.blobs: 保存各层的数据值

    可用命令:

    [(k,v[0].data) for k,v in net.params.items()]

    查看各层的参数值,其中k表示层的名称,v[0].data就是各层的W值,而v[1].data是各层的b值。注意:并不是所有的层都有参数,只有卷积层和全连接层才有。

    也可以不查看具体值,只想看一下shape,可用命令

    [(k,v[0].data.shape) for k,v in net.params.items()]

    假设我们知道其中第一个卷积层的名字叫'Convolution1', 则我们可以提取这个层的参数:

    w1=net.params['Convolution1'][0].data
    b1=net.params['Convolution1'][1].data

    输入这些代码,实际查看一下,对你理解network非常有帮助。

    同理,除了查看参数,我们还可以查看数据,但是要注意的是,net里面刚开始是没有数据的,需要运行:

    net.forward()

    之后才会有数据。我们可以用代码:

    [(k,v.data.shape) for k,v in net.blobs.items()]

    [(k,v.data) for k,v in net.blobs.items()]

    来查看各层的数据。注意和上面查看参数的区别,一个是net.params, 一个是net.blobs.

    实际上数据刚输入的时候,我们叫图片数据,卷积之后我们就叫特征了。

    如果要抽取第一个全连接层的特征,则可用命令:

    fea=net.blobs['InnerProduct1'].data

    只要知道某个层的名称,就可以抽取这个层的特征。

    推荐大家在spyder中,运行一下上面的所有代码,深入理解模型各层。

    最后,总结一个代码:

    import caffe
    import numpy as np
    root='/home/xxx/'   #根目录
    deploy=root + 'mnist/deploy.prototxt'    #deploy文件
    caffe_model=root + 'mnist/lenet_iter_9380.caffemodel'   #训练好的 caffemodel
    net = caffe.Net(deploy,caffe_model,caffe.TEST)   #加载model和network
    [(k,v[0].data.shape) for k,v in net.params.items()]  #查看各层参数规模
    w1=net.params['Convolution1'][0].data  #提取参数w
    b1=net.params['Convolution1'][1].data  #提取参数b
    net.forward()   #运行测试

    [(k,v.data.shape) for k,v in net.blobs.items()]  #查看各层数据规模
    fea=net.blobs['InnerProduct1'].data   #提取某层数据(特征)
  • 相关阅读:
    RDD执行延迟执行原理
    spark应用运行机制解析1
    spark streaming job生成与运行
    spark的Task的序列化
    spark将计算结果写入到hdfs的两种方法
    spark的runJob函数2
    SVG---------SVG sprite 使用示例
    段落边框——paraBox.scss
    背景条纹——bgStripes.scss
    css3动画——基本准则
  • 原文地址:https://www.cnblogs.com/denny402/p/5686257.html
Copyright © 2011-2022 走看看