zoukankan      html  css  js  c++  java
  • tflearn实现vgg16模型

    vgg16构造模型图:http://ethereon.github.io/netscope/#/gist/dc5003de6943ea5a6b8b

    以下为代码:注释会不断添加。

    # -*- coding: utf-8 -*-
    from __future__ import division, print_function, absolute_import
    """
    Created on Sat Jul  2 14:58:30 2016
    
    @author: ubuntu
    """
    
    # -*- coding: utf-8 -*-
    
    """ Very Deep Convolutional Networks for Large-Scale Visual Recognition.
    Applying VGG 16-layers convolutional network to Oxford's 17 Category Flower
    Dataset classification task.
    References:
        Very Deep Convolutional Networks for Large-Scale Image Recognition.
        K. Simonyan, A. Zisserman. arXiv technical report, 2014.
    Links:
        http://arxiv.org/pdf/1409.1556
    """

    # Workaround kept from the original author: when run from an Ubuntu
    # terminal the script occasionally raised a PIL-related error; importing
    # PIL and opening an image before tflearn is imported avoided it.
    from PIL import Image

    img_path = '/home/ubuntu/pythonproject/tflearnproject/17flowers/jpg/0/image_0001.jpg'
    Image.open(img_path).close()

    import tflearn
    from tflearn.layers.core import input_data, dropout, fully_connected
    from tflearn.layers.conv import conv_2d, max_pool_2d
    from tflearn.layers.estimator import regression
    import numpy as np


    def load_image(in_image):
        """ Load an image, returns PIL.Image. """
        img = Image.open(in_image)
        return img


    def resize_image(in_image, new_width, new_height, out_image=None,
                     resize_mode=Image.LANCZOS):
        """ Resize an image.

        Arguments:
            in_image: `PIL.Image`. The image to resize.
            new_width: `int`. The image new width.
            new_height: `int`. The image new height.
            out_image: `str`. If specified, save the image to the given path.
            resize_mode: `PIL.Image.mode`. The resizing filter. Defaults to
                LANCZOS, the filter formerly exposed as `Image.ANTIALIAS`;
                that alias was removed in Pillow 10, and because defaults are
                evaluated at definition time it would fail on import.

        Returns:
            `PIL.Image`. The resized image.
        """
        img = in_image.resize((new_width, new_height), resize_mode)
        if out_image:
            img.save(out_image)
        return img


    def pil_to_nparray(pil_image):
        """ Convert a PIL.Image to a float32 numpy array. """
        pil_image.load()
        return np.asarray(pil_image, dtype="float32")


    # Image used for a quick prediction after training. Preprocess it the way
    # the training samples are built: RGB, 224x224, float32, scaled to [0, 1].
    # NOTE(review): the [0, 1] scaling assumes tflearn's oxflower17 loader
    # divides pixels by 255 (tflearn.data_utils.image_dirs_to_samples) --
    # confirm against the installed tflearn version.
    img = load_image(img_path).convert('RGB')
    img = resize_image(img, 224, 224)
    img = pil_to_nparray(img) / 255.
    print(u'用于测试的图片加载完成!')

    # Data loading and preprocessing
    import tflearn.datasets.oxflower17 as oxflower17
    print('------')
    print('666666666666')
    X, Y = oxflower17.load_data(one_hot=True)

    # Building 'VGG Network'. Each stage is (number of filters, number of conv
    # layers); every conv uses a 3x3 kernel, and each stage ends with a 2x2
    # max-pool with stride 2. Layers are created in the same order as the
    # unrolled VGG16 definition, so auto-generated layer names are unchanged.
    network = input_data(shape=[None, 224, 224, 3])
    for n_filters, n_convs in ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3)):
        for _ in range(n_convs):
            network = conv_2d(network, n_filters, 3, activation='relu')
        network = max_pool_2d(network, 2, strides=2)

    network = fully_connected(network, 4096, activation='relu')
    network = dropout(network, 0.5)
    network = fully_connected(network, 4096, activation='relu')
    network = dropout(network, 0.5)
    network = fully_connected(network, 17, activation='softmax')

    network = regression(network, optimizer='rmsprop',
                         loss='categorical_crossentropy',
                         learning_rate=0.001)

    # Training.
    # max_checkpoints: how many checkpoint files to keep (older ones are
    # presumably deleted automatically once the limit is exceeded).
    model = tflearn.DNN(network, checkpoint_path='model_vgg',
                        max_checkpoints=1, tensorboard_verbose=0)

    # snapshot_step: save a checkpoint every this many training steps.
    # n_epoch: number of passes over the dataset.
    # batch_size: images per step -- lower it if memory runs out.
    print(u'开始加载模型')
    # To resume from an earlier run instead of training from scratch, call
    # model.load(...) here with a checkpoint such as 'model_vgg-20', or with
    # 'vgg16.tflearn' to restore the weights saved below.
    model.fit(X, Y, n_epoch=1, shuffle=True,
              show_metric=True, batch_size=8, snapshot_step=10,
              snapshot_epoch=False, run_id='vgg_oxflowers17')
    model.save('vgg16.tflearn')

    print(u'开始预测')
    # The input layer is [None, 224, 224, 3], so predict() needs a batch:
    # add a leading axis to get shape (1, 224, 224, 3).
    probs = model.predict(img[np.newaxis, ...])
    print(probs)
    print('predicted class:', int(np.argmax(probs[0])))

    图片版模型图:

  • 相关阅读:
    WinForm里保存TreeView状态
    动态规划 回溯和较难题
    go 基本链表操作
    leetcode 42接雨水
    leetcode 旋转数组搜索
    leetcode 牛客编程 子序列 树 数组(积累)
    剑指offer(积累)
    go快排计算最小k个数和第k大的数
    leetcode 打家劫舍
    leetcode 字符串相关问题
  • 原文地址:https://www.cnblogs.com/SSSR/p/5644512.html
Copyright © 2011-2022 走看看