zoukankan      html  css  js  c++  java
  • Keras猫狗大战三:加载模型,预测目录中图片,画混淆矩阵

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

     一、加载模型,预测测试集

    %matplotlib inline
    import matplotlib.pyplot as plt
    
    import os
    import itertools
    import cv2
    
    import numpy as np
    from sklearn.metrics import confusion_matrix
    
    from keras.preprocessing.image import ImageDataGenerator
    from keras.models import load_model
    
    # NOTE(review): the backslashes in these Windows paths look like they were
    # stripped when the post was published. On Windows 'D:xyz' is resolved
    # relative to the current directory of drive D, not the drive root.
    # Restore the real separators before running.
    dst_path = r'D:BaiduNetdiskDownloadsmall'
    model_file = r"D:fastaiprojectscats_and_dogs_small_1.h5"
    test_dir = os.path.join(dst_path, 'test')

    batch_size = 20

    # Load the trained classifier from the saved model file.
    model = load_model(model_file)

    # Evaluation only rescales pixel values to [0, 1]; no augmentation.
    test_datagen = ImageDataGenerator(rescale=1. / 255)

    # One sub-directory per class under test_dir; binary labels.
    test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')

    # Round up so a final partial batch is still evaluated when the number of
    # samples is not a multiple of batch_size; Keras expects an integer count.
    steps = int(np.ceil(test_generator.samples / batch_size))
    test_loss, test_acc = model.evaluate_generator(test_generator, steps=steps)
    # test_acc is a fraction in [0, 1]; scale it so the '%' suffix is correct.
    print('test acc: %.3f%%' % (test_acc * 100))
    Found 400 images belonging to 2 classes.
    test acc: 0.747%

    二、预测测试集,画混淆矩阵
    def get_input_xy(src=None, target_size=(150, 150)):
        """Load images from disk and build model inputs plus ground-truth labels.

        The label of each image comes from the first three characters of its
        file name: 'cat...' -> 0, 'dog...' -> 1. This assumes file names like
        'cat.123.jpg' / 'dog.456.jpg'.

        Args:
            src: iterable of image file paths. ``None`` (the default) means no
                images.
            target_size: ``(width, height)`` passed to ``cv2.resize``. It should
                match the input size the model expects (150x150 here).

        Returns:
            ``(pre_x, true_y)``. ``pre_x`` is a float numpy array holding the
            RGB images scaled to [0, 1], stacked along the first axis.
            ``true_y`` is a list of integer labels in the same order as ``src``.

        Raises:
            ValueError: if a file name does not start with a known class prefix,
                or if OpenCV cannot read an image.
        """
        if src is None:
            src = []

        class_indices = {'cat': 0, 'dog': 1}

        pre_x = []
        true_y = []

        for path in src:
            _, fn = os.path.split(path)
            # Validate the label before decoding the image, so a badly named
            # file fails fast. Previously a None label was stored silently and
            # only broke later, inside confusion_matrix.
            label = class_indices.get(fn[:3])
            if label is None:
                raise ValueError('cannot infer class label from file name: %r' % path)

            img = cv2.imread(path)
            # cv2.imread signals failure by returning None rather than raising.
            if img is None:
                raise ValueError('cannot read image: %r' % path)
            img = cv2.resize(img, target_size)
            # OpenCV decodes to BGR. Convert to RGB, presumably to match the
            # channel order of the Keras image pipeline -- confirm vs. training.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            pre_x.append(img)
            true_y.append(label)

        # Same 1/255 rescaling as the ImageDataGenerator used for evaluation.
        pre_x = np.array(pre_x) / 255.0

        return pre_x, true_y
    
    
    def plot_sonfusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
        """Draw a confusion matrix as a colour-coded image with per-cell values.

        Args:
            cm: square 2-D array as returned by
                ``sklearn.metrics.confusion_matrix``. Rows are true labels and
                columns are predicted labels.
            classes: tick labels for both axes, in class-index order.
            normalize: if True, divide each row by its total so the cells show
                the fraction of each true class. Rows with no samples stay 0.
            title: figure title.
            cmap: matplotlib colormap used for the image.
        """
        cm = np.asarray(cm)
        if normalize:
            # Normalize BEFORE drawing. The image colours, the text values and
            # the text-colour threshold must all describe the same matrix;
            # previously the image showed raw counts.
            row_sums = cm.sum(axis=1, keepdims=True)
            cm = np.divide(cm.astype('float'), row_sums,
                           out=np.zeros(cm.shape, dtype='float'),
                           where=row_sums != 0)

        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        # Use white text on dark cells and black text on light cells.
        thresh = cm.max() / 2.0
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            # Fractions get two decimals; raw counts are shown as-is.
            text = format(cm[i, j], '.2f') if normalize else str(cm[i, j])
            plt.text(j, i, text, horizontalalignment='center',
                     color='white' if cm[i, j] > thresh else 'black')

        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predict label')


    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    plot_confusion_matrix = plot_sonfusion_matrix
    
    
    # Per-class sub-directories of test_dir, sorted for a reproducible order.
    test = sorted(os.listdir(test_dir))

    images = []

    # Collect the path of every .jpg/.jpeg image into the list `images`.
    for testpath in test:
        class_path = os.path.join(test_dir, testpath)
        # Skip stray files next to the class folders; os.listdir on a file raises.
        if not os.path.isdir(class_path):
            continue
        for fn in sorted(os.listdir(class_path)):
            if fn.lower().endswith(('.jpg', '.jpeg')):
                images.append(os.path.join(class_path, fn))

    if not images:
        raise FileNotFoundError('no .jpg images found under %s' % test_dir)

    # Normalized images and their true labels (0 = cat, 1 = dog).
    pre_x, true_y = get_input_xy(images)

    # Predict. Sequential.predict_classes was removed in recent Keras/TensorFlow
    # releases; this reproduces its rule: argmax for multi-unit outputs,
    # otherwise threshold the single (sigmoid) probability at 0.5.
    proba = model.predict(pre_x)
    if proba.shape[-1] > 1:
        pred_y = proba.argmax(axis=-1)
    else:
        pred_y = (proba > 0.5).astype('int32').ravel()

    # Plot the confusion matrix, labelling the axes with the class names.
    confusion_mat = confusion_matrix(true_y, pred_y)
    plot_sonfusion_matrix(confusion_mat, classes=['cat', 'dog'])
    plt.show()

  • 相关阅读:
    获取DataGrid数据
    C# 分頁
    TCP 协议
    node fs对象
    ANSI转义码 改变输出的字体颜色
    异步流程控制模式
    node event对象
    js中的异常捕获 try{} catch{}(二)
    node require 文件查找的顺序
    node process全局对象
  • 原文地址:https://www.cnblogs.com/zhengbiqing/p/11070050.html
Copyright © 2011-2022 走看看