zoukankan      html  css  js  c++  java
  • PyTorch ImageNet 测试代码

    image_test.py

    import argparse
    import numpy as np
    import sys
    import os
    import csv
    from imagenet_test_base import TestKit
    import torch
    
    
    class TestTorch(TestKit):
        """Accuracy-test and model-dump driver for a converted PyTorch model.

        Builds on TestKit (imported from imagenet_test_base, not shown here),
        which supplies the parsed arguments (``self.args``), the reference
        results table (``self.truth``), the module exposing ``KitModel``
        (``self.MainModel`` -- presumably the file named by ``-n``), image
        preprocessing, and the top-1/top-5 scoring done in ``print_result``.
        """

        def __init__(self):
            super(TestTorch, self).__init__()

            # Reference top-5 outputs for inception_v3, keyed by source
            # framework; presumably (class index, score) pairs -- TODO confirm.
            self.truth['tensorflow']['inception_v3'] = [(22, 9.6691055), (24, 4.3524747), (25, 3.5957973), (132, 3.5657473), (23, 3.346283)]
            self.truth['keras']['inception_v3'] = [(21, 0.93430489), (23, 0.002883445), (131, 0.0014781791), (24, 0.0014518998), (22, 0.0014435351)]

            # Build the network from the weights file given by -w and switch it
            # to evaluation mode (dropout / batch-norm in inference behaviour).
            self.model = self.MainModel.KitModel(self.args.w)
            self.model.eval()

        def preprocess(self, image_path):
            """Load *image_path* via the base class and stage it in ``self.data``.

            The base-class array is transposed with axes (2, 0, 1) and given a
            leading batch axis -- presumably HWC -> NCHW; confirm against
            TestKit.preprocess.
            """
            x = super(TestTorch, self).preprocess(image_path)
            x = np.transpose(x, (2, 0, 1))
            # The transpose is only a strided view; copy into a fresh,
            # contiguous buffer before handing it to torch.
            x = np.expand_dims(x, 0).copy()
            # A plain tensor is enough: Variable has been merged into Tensor
            # since PyTorch 0.4, and requires_grad already defaults to False.
            self.data = torch.from_numpy(x)

        def print_result(self, image_name, top1, top5):
            """Run the model on ``self.data`` and let the base class score it.

            Returns whatever TestKit.print_result returns; ``inference`` unpacks
            it as the updated ``(top1, top5)`` counters.
            """
            # Evaluation needs no autograd graph; skipping it saves memory.
            with torch.no_grad():
                predict = self.model(self.data)
            return super(TestTorch, self).print_result(
                predict.detach().cpu().numpy(), image_name, top1, top5)

        def print_intermediate_result(self, layer_name, if_transpose=False):
            """Print the tensor the model stored on ``self.model.test``.

            NOTE(review): ``layer_name`` is accepted but unused -- the output is
            whatever the generated model assigned to its ``test`` attribute.
            """
            intermediate_output = self.model.test.detach().cpu().numpy()
            super(TestTorch, self).print_intermediate_result(intermediate_output, if_transpose)

        def inference(self, images, image_dir="../data/imagenet/small_imagenet/"):
            """Evaluate top-1 / top-5 accuracy over a list of images.

            Args:
                images: path to a text file with one image file name per line.
                    Blank lines are skipped.
                image_dir: directory the listed names are relative to; the
                    default is the previously hard-coded location.
            """
            top1 = 0.0
            top5 = 0.0
            image_count = 0
            # Read the list line by line. This replaces the csv.reader call,
            # whose '\n' delimiter literal had been broken across two lines.
            with open(images) as fp_images:
                for line in fp_images:
                    image_name = line.strip()
                    if not image_name:
                        # A blank line (e.g. a trailing newline) is not an image.
                        continue
                    image_count += 1
                    self.preprocess(os.path.join(image_dir, image_name))
                    top1, top5 = self.print_result(image_name, top1, top5)

            if image_count == 0:
                # Avoid dividing by zero when the list is empty.
                print("No images listed in [{}]; nothing to evaluate.".format(images))
                return

            print("top1's accuracy : %f" % (top1 / image_count))
            print("top5's accuracy : %f" % (top5 / image_count))
            # Optional debug hooks:
            #   self.print_intermediate_result(None, False)
            #   self.test_truth()

        def dump(self, path=None):
            """Save the whole model object with ``torch.save``.

            Args:
                path: destination file; defaults to the ``--dump`` argument.
            """
            if path is None:
                path = self.args.dump
            # Saving the module itself (not a state_dict) pickles the object, so
            # the generated model class must be importable when loading it back.
            torch.save(self.model, path)
            print('PyTorch model file is saved as [{}], generated by [{}.py] and [{}].'.format(
                  path, self.args.n, self.args.w))
    
    
    if __name__ == '__main__':
        # With --dump, save the converted model; otherwise measure accuracy over
        # the image list named by --image.
        torch_tester = TestTorch()
        cli_args = torch_tester.args
        if not cli_args.dump:
            torch_tester.inference(cli_args.image)
        else:
            torch_tester.dump()
    

    imagenet_test_base.py:

      基类 TestKit 的实现请见上传的代码,下载地址:https://files.cnblogs.com/files/jzcbest1016/imagenet_test_base.py.tar.gz

    运行 image_test.py 时,需要在终端通过命令行传入参数,其中只有 -w(权重文件)为必填项,例如:python image_test.py -w <权重文件>。参数定义如下:

     # Excerpt: command-line options for the test scripts. _text_type and
     # self.truth come from surrounding code that is not shown here.
     parser = argparse.ArgumentParser()

            parser.add_argument('-p', '--preprocess', type=_text_type, help='Model Preprocess Type')   # preprocessing type (original note: "the PyTorch test program, here image_test.py" -- TODO confirm expected value)

            parser.add_argument('-n', type=_text_type, default='kit_imagenet',  
                                help='Network structure file name.')   # network-structure (model code) file; defaults to kit_imagenet

            parser.add_argument('-s', type=_text_type, help='Source Framework Type',
                                choices=self.truth.keys())           # source framework (e.g. pytorch, tensorflow); must be a key of self.truth

            parser.add_argument('-w', type=_text_type, required=True,
                                help='Network weights file name')   # network weights file (image_test.py passes it to KitModel)

            parser.add_argument('--image', '-i',
                                type=_text_type, help='Test image path.',
                                default="../data/file_list.txt"     # text file listing the test image names, one per line
            )

            parser.add_argument('-l', '--label',
                                type=_text_type,
                                default='../data/val.txt',
                                help='Path of label.')   # ground-truth class labels for the test set

            parser.add_argument('--dump',
                type=_text_type,
                default=None,
                help='Target model path.')  # where to save the converted target model

            parser.add_argument('--detect',
                type=_text_type,
                default=None,
                help='Model detection result path.')

            # tensorflow dump tag
            parser.add_argument('--dump_tag',
                type=_text_type,
                default=None,
                help='Tensorflow model dump type',
                choices=['SERVING', 'TRAINING'])
    
  • 相关阅读:
    学生信息表
    水仙花数
    DirectAccess完整配置
    这些惹人嫌系统安装方法
    求解方程式
    AD DS的维护之备份还原
    简单的switch语句
    linux_常用命令
    小小问题
    frameset和frame
  • 原文地址:https://www.cnblogs.com/jzcbest1016/p/9780356.html
Copyright © 2011-2022 走看看