zoukankan      html  css  js  c++  java
  • 【7-4使用inception-v3做各种图像的识别】

    参考程序:

     1 import tensorflow as tf
     2 import os
     3 import numpy as np
     4 import re
     5 from PIL import Image
     6 import matplotlib.pyplot as plt
     7 
     8 class NodeLookup(object):
     9     def __init__(self):  
    10         label_lookup_path = 'inception_model/imagenet_2012_challenge_label_map_proto.pbtxt'   
    11         uid_lookup_path = 'inception_model/imagenet_synset_to_human_label_map.txt'
    12         self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
    13 
    14     def load(self, label_lookup_path, uid_lookup_path):
    15         # 加载分类字符串n********对应分类名称的文件
    16         proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()   
    17         uid_to_human = {}     
    18         #一行一行读取数据
    19         for line in proto_as_ascii_lines :
    20             #去掉换行符
    21             line=line.strip('
    ')
    22             #按照'	'分割
    23             parsed_items = line.split('	')
    24             #获取分类编号
    25             uid = parsed_items[0]
    26             #获取分类名称
    27             human_string = parsed_items[1]
    28             #保存编号字符串n********与分类名称映射关系
    29             uid_to_human[uid] = human_string
    30 
    31         # 加载分类字符串n********对应分类编号1-1000的文件
    32         proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
    33         node_id_to_uid = {}
    34         for line in proto_as_ascii:
    35             if line.startswith('  target_class:'):
    36                 #获取分类编号1-1000
    37                 target_class = int(line.split(': ')[1])
    38             if line.startswith('  target_class_string:'):
    39                 #获取编号字符串n********
    40                 target_class_string = line.split(': ')[1]
    41                 #保存分类编号1-1000与编号字符串n********映射关系
    42                 node_id_to_uid[target_class] = target_class_string[1:-2]  #要将两侧的双引号去掉
    43 
    44         #建立分类编号1-1000对应分类名称的映射关系
    45         node_id_to_name = {}
    46         for key, val in node_id_to_uid.items():
    47             #获取分类名称
    48             name = uid_to_human[val]
    49             #建立分类编号1-1000到分类名称的映射关系
    50             node_id_to_name[key] = name
    51         return node_id_to_name
    52 
    53     #传入分类编号1-1000返回分类名称
    54     def id_to_string(self, node_id):
    55         if node_id not in self.node_lookup:
    56             return ''
    57         return self.node_lookup[node_id]
    58 
    59 
    60 #创建一个图来存放google训练好的模型
    61 with tf.gfile.FastGFile('inception_model/classify_image_graph_def.pb', 'rb') as f:
    62     graph_def = tf.GraphDef()
    63     graph_def.ParseFromString(f.read())
    64     tf.import_graph_def(graph_def, name='')
    65 
    66 
    67 with tf.Session() as sess:
    68     softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    69     #遍历目录
    70     for root,dirs,files in os.walk('images/'):
    71         for file in files:
    72             #载入图片
    73             image_data = tf.gfile.FastGFile(os.path.join(root,file), 'rb').read()
    74             predictions = sess.run(softmax_tensor,{'DecodeJpeg/contents:0': image_data})#图片格式是jpg格式
    75             predictions = np.squeeze(predictions)#把结果转为1维数据
    76 
    77             #打印图片路径及名称
    78             image_path = os.path.join(root,file)
    79             print(image_path)
    80             #显示图片
    81             img=Image.open(image_path)
    82             plt.imshow(img)
    83             plt.axis('off')
    84             plt.show()
    85 
    86             #排序
    87             top_k = predictions.argsort()[-5:][::-1]
    88             node_lookup = NodeLookup()
    89             for node_id in top_k:     
    90                 #获取分类名称
    91                 human_string = node_lookup.id_to_string(node_id)
    92                 #获取该分类的置信度
    93                 score = predictions[node_id]
    94                 print('%s (score = %.5f)' % (human_string, score))
    95             print()

    在inception_model中有这2个文件:

    分别长这样:

    分类共有1000种结果以及所对应得字符串:

    上述字符串及所对应的描述:

    程序首先读入这2个文件:

            proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()   
            uid_to_human = {}     
            #一行一行读取数据
            for line in proto_as_ascii_lines :
                #去掉换行符
                line=line.strip('
    ')
                #按照'	'分割
                parsed_items = line.split('	')
                #获取分类编号
                uid = parsed_items[0]
                #获取分类名称
                human_string = parsed_items[1]
                #保存编号字符串n********与分类名称映射关系
                uid_to_human[uid] = human_string

    从 uid_lookup_path 中读取的结果存放在 proto_as_ascii_lines中,创建了一个空字典 uid_to_human 用来存储键值对,一行一行的读取数据,并将换行符去掉,以Tab键作为分割,

    parsed_items[0]代表分类字符串n********(Tab键之前的内容),parsed_items[1]对应分类的名称(Tab键之后的内容),作为键值对存入字典。


            # 加载分类字符串n********对应分类编号1-1000的文件
            proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
            node_id_to_uid = {}
            for line in proto_as_ascii:
                if line.startswith('  target_class:'):
                    #获取分类编号1-1000
                    target_class = int(line.split(': ')[1])
                if line.startswith('  target_class_string:'):
                    #获取编号字符串n********
                    target_class_string = line.split(': ')[1]
                    #保存分类编号1-1000与编号字符串n********映射关系
                    node_id_to_uid[target_class] = target_class_string[1:-2]

    程序同理,以 分割,获取分类编号以及编号字符串n********,作为键值对存入字典,target_class_string[1:-2]是去掉字符串中的2个双引号:

  • 相关阅读:
    #跟着教程学# 6、maya/python window命令
    element 中MessageBox的封装
    vue中XLSX解析excel文件
    git工作区-暂存区
    GIT相关
    弹窗-二维码生成与下载
    输入框限定100个汉字或200字符
    深浅拷贝(详细)
    日期选择器选取时间范围(非空以及初始时间不能在当天以前)
    多选
  • 原文地址:https://www.cnblogs.com/direwolf22/p/11055235.html
Copyright © 2011-2022 走看看