  • VGG16 projects in TensorFlow

    Please include this link when reposting: http://www.cnblogs.com/SSSR/p/5630534.html

    The tflearn example for training VGG16, https://github.com/tflearn/tflearn/blob/master/examples/images/vgg_network.py, has not yet been tested successfully.

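    The project below uses a model that someone else has already trained to make predictions, and it performs very well in testing.

    GitHub: https://github.com/ry/tensorflow-vgg16 This project has been tested successfully and gives very good results.

    If it fails when run from a terminal on Ubuntu, the script below can be used as a reference (it works around a problem with skimage reading images).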

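    # coding: utf-8

    # skimage is imported and used to read an image before TensorFlow is
    # imported; this order is kept exactly as in the original script.
    import skimage
    import skimage.io
    import skimage.transform
    a = skimage.io.imread('cat.jpg')  # requires cat.jpg in the working directory
    import PIL
    import numpy as np
    import tensorflow as tf

    # one class label per line; the line index is the class index
    with open('/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/synset.txt') as f:
      synset = [l.strip() for l in f]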
    
    def load_image(path):
      """Read an image, scale it to [0, 1], center-crop it to a square and resize it to 224 x 224."""
      # load image
      img = skimage.io.imread(path)
      # PIL-based alternative from the original, left disabled:
      # img = np.array(PIL.Image.open(path))
      img = img / 255.0
      assert (0 <= img).all() and (img <= 1.0).all()
      # crop the largest centered square
      short_edge = min(img.shape[:2])
      yy = (img.shape[0] - short_edge) // 2
      xx = (img.shape[1] - short_edge) // 2
      crop_img = img[yy : yy + short_edge, xx : xx + short_edge]
      # resize to 224 x 224
      resized_img = skimage.transform.resize(crop_img, (224, 224))
      return resized_img
      
    # returns the top-1 label string
    def print_prob(prob):
      print("prob shape", prob.shape)
      # class indices sorted by descending probability
      pred = np.argsort(prob)[::-1]
      # top-1 label
      top1 = synset[pred[0]]
      # top-5 labels (computed but not returned)
      top5 = [synset[pred[i]] for i in range(5)]
      return top1
    
    print(u'Loading the model file')
    with open("/home/ubuntu/pythonproject/tensorflow/tensorflow-vgg16/vgg16.tfmodel", mode='rb') as f:
      fileContent = f.read()

    print(u'Building the graph')
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(fileContent)

    # input placeholder: a batch of 224 x 224 RGB images
    images = tf.placeholder("float", [None, 224, 224, 3])

    # imported nodes are placed under the default "import/" name prefix
    tf.import_graph_def(graph_def, input_map={ "images": images })
    print("graph loaded from disk")

    graph = tf.get_default_graph()
    print(u'Loading the images')
    print(u'Entering the session')
    
    sess = tf.Session()
    # tf.initialize_all_variables() is the name used by the TensorFlow release
    # this script was written for; later 1.x releases renamed it to
    # tf.global_variables_initializer().
    init = tf.initialize_all_variables()
    sess.run(init)
    print("variables initialized")
    # class-probability output of the imported graph, looked up once
    prob_tensor = graph.get_tensor_by_name("import/prob:0")

    result = []
    for i in ['cat.jpg','airplane.jpg','zebra.jpg','pig.jpg','12.jpg','23.jpg']:
      img = load_image('pic/' + i)
      batch = img.reshape((1, 224, 224, 3))
      assert batch.shape == (1, 224, 224, 3)
      feed_dict = { images: batch }
      print(u'Running inference')
      prob = sess.run(prob_tensor, feed_dict=feed_dict)
      print(u'Collecting the result')
      result.append(print_prob(prob[0]))


    print(result)
    sess.close()
    
    
    '''
    # Earlier single-image version, kept for reference (expects cat = load_image(path))
    with tf.Session() as sess:
      init = tf.initialize_all_variables()
      sess.run(init)
      print("variables initialized")
      batch = cat.reshape((1, 224, 224, 3))
      assert batch.shape == (1, 224, 224, 3)
      feed_dict = { images: batch }
      print(u'Running inference')
      prob_tensor = graph.get_tensor_by_name("import/prob:0")
      prob = sess.run(prob_tensor, feed_dict=feed_dict)

    print(u'Printing the result')
    print_prob(prob[0])

    '''
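
    The center-crop step in load_image above can be looked at in isolation. The following is a minimal sketch and not part of the original script: center_crop_box and center_crop are hypothetical helper names, and the image is assumed to be a plain list of rows so that the example runs without skimage or numpy.

```python
def center_crop_box(height, width):
    """Return (top, left, side) of the largest centered square in an image."""
    side = min(height, width)
    top = (height - side) // 2
    left = (width - side) // 2
    return top, left, side


def center_crop(rows):
    """Crop a 2-D image, given as a list of equal-length rows, to its centered square."""
    height, width = len(rows), len(rows[0])
    top, left, side = center_crop_box(height, width)
    return [row[left:left + side] for row in rows[top:top + side]]


# A 480 x 640 (height x width) image keeps every row and the middle 480 columns.
print(center_crop_box(480, 640))  # (0, 80, 480)
```

    The offsets correspond to yy and xx in load_image; with a numpy array the same box would be applied as img[top:top + side, left:left + side].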
    

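    print_prob above computes the top-5 labels but returns only the top-1. The following is a sketch of how the full ranking could be returned together with the probabilities, assuming the probabilities and labels are sequences of equal length. The function name top_k_labels and the four-class data are hypothetical and only serve as illustration.

```python
def top_k_labels(prob, labels, k=5):
    """Return the k (label, probability) pairs with the highest probability."""
    if len(prob) != len(labels):
        raise ValueError("prob and labels must have the same length")
    # indices sorted by descending probability, like np.argsort(prob)[::-1]
    order = sorted(range(len(prob)), key=lambda i: prob[i], reverse=True)
    return [(labels[i], prob[i]) for i in order[:k]]


# Hypothetical 4-class output, for illustration only
labels = ["cat", "airplane", "zebra", "pig"]
prob = [0.10, 0.05, 0.80, 0.05]
print(top_k_labels(prob, labels, k=2))  # [('zebra', 0.8), ('cat', 0.1)]
```

    In the script above the same call would take prob[0] and synset as arguments, so each image could be reported with its five most likely labels instead of a single one.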
      

  • Original post: https://www.cnblogs.com/SSSR/p/5630534.html