zoukankan      html  css  js  c++  java
  • helper工具包——基于CIFAR-10数据集的CNN分类模型的模块

    import pickle
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.preprocessing import LabelBinarizer
    import os
    
    def _load_label_names():
        """
        Load the label names from file
        """
        return ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    
    
    def load_cfar10_batch(cifar10_dataset_folder_path, batch_id):
        """
        Read one CIFAR-10 training batch file named ``data_batch_<batch_id>``.

        Returns a tuple ``(features, labels)``. The stored pixel rows are reshaped
        to channel-first ``(N, 3, 32, 32)`` and then transposed to channel-last
        ``(N, 32, 32, 3)``. Labels are returned exactly as stored in the pickle.
        """
        batch_file = cifar10_dataset_folder_path + '/data_batch_' + str(batch_id)
        with open(batch_file, mode='rb') as handle:
            raw = pickle.load(handle, encoding='latin1')

        pixels = raw['data']
        n_images = len(pixels)
        # Reshape first, then transpose: [N, C, H, W] --> [N, H, W, C]
        features = pixels.reshape((n_images, 3, 32, 32)).transpose(0, 2, 3, 1)
        return features, raw['labels']
    
    
    def display_stats(cifar10_dataset_folder_path, batch_id, sample_id):
        """
        Print summary statistics for one CIFAR-10 training batch and plot one image.

        Args:
            cifar10_dataset_folder_path: folder containing the ``data_batch_<id>`` files.
            batch_id: batch number; must be one of 1..5.
            sample_id: index of the image within the batch to describe and plot.

        Returns:
            None. If ``batch_id`` or ``sample_id`` is out of range, a message is
            printed and the function returns early without loading or plotting.
        """
        batch_ids = list(range(1, 6))

        if batch_id not in batch_ids:
            print('Batch Id out of Range. Possible Batch Ids: {}'.format(batch_ids))
            return None

        features, labels = load_cfar10_batch(cifar10_dataset_folder_path, batch_id)

        if not (0 <= sample_id < len(features)):
            print('{} samples in batch {}.  {} is out of range.'.format(len(features), batch_id, sample_id))
            return None

        # The original literal had a raw line break inside single quotes, which is
        # a SyntaxError. The intent is a leading blank line, i.e. '\n'.
        print('\nStats of batch {}:'.format(batch_id))
        print('Samples: {}'.format(len(features)))
        # Convert NumPy scalars to plain ints so the dict prints as {0: 1005, ...}
        # rather than {np.int64(0): np.int64(1005), ...} under NumPy 2.
        unique_labels, counts = np.unique(labels, return_counts=True)
        label_counts = {int(k): int(v) for k, v in zip(unique_labels, counts)}
        print('Label Counts: {}'.format(label_counts))
        print('First 20 Labels: {}'.format(labels[:20]))

        sample_image = features[sample_id]
        sample_label = labels[sample_id]
        label_names = _load_label_names()

        print('\nExample of Image {}:'.format(sample_id))
        print('Image - Min Value: {} Max Value: {}'.format(sample_image.min(), sample_image.max()))
        print('Image - Shape: {}'.format(sample_image.shape))
        print('Label - Label Id: {} Name: {}'.format(sample_label, label_names[sample_label]))
        plt.axis('off')
        plt.imshow(sample_image)
        plt.show()
    
    
    def _preprocess_and_save(normalize, one_hot_encode, features, labels, filename):
        """
        Preprocess data and save it to file
        """
        # features shape =[9000, 32, 32, 3]
        features = normalize(features)
        labels = one_hot_encode(labels)
    
        pickle.dump((features, labels), open(filename, 'wb'))
    
    
    def preprocess_and_save_data(cifar10_dataset_folder_path, normalize, one_hot_encode):
        """
        Preprocess the CIFAR-10 training, validation and test data and pickle it.

        For each of the 5 training batches, the last 10% of samples (rounded down)
        are held out for validation, and the remainder is written to
        ``preprocess_batch_<i>.p``. The held-out samples from all batches are
        written to ``preprocess_validation.p``, and the ``test_batch`` file is
        written to ``preprocess_test.p``. Output filenames carry no directory,
        so they land in the current working directory.

        Args:
            cifar10_dataset_folder_path: folder containing ``data_batch_1`` ..
                ``data_batch_5`` and ``test_batch``.
            normalize: callable applied to each feature array before saving.
            one_hot_encode: callable applied to each label array before saving.
        """
        n_batches = 5
        valid_features = []
        valid_labels = []

        # Read the raw training data one batch file at a time (5 iterations).
        for batch_i in range(1, n_batches + 1):
            features, labels = load_cfar10_batch(cifar10_dataset_folder_path, batch_i)
            # Hold out 10% of this batch (truncated to an int) for validation.
            validation_count = int(len(features) * 0.1)
            # Slice at an explicit split index. With validation_count == 0, the
            # negative-index slices [:-0] / [-0:] would send every sample to
            # validation and none to training.
            split = len(features) - validation_count

            # Preprocess the training portion and write it to disk.
            _preprocess_and_save(
                normalize,
                one_hot_encode,
                features[:split],
                labels[:split],
                'preprocess_batch_' + str(batch_i) + '.p')

            # The remaining 10% of the batch becomes validation data.
            valid_features.extend(features[split:])
            valid_labels.extend(labels[split:])

        # Preprocess the accumulated validation data and write it to disk.
        _preprocess_and_save(
            normalize,
            one_hot_encode,
            np.array(valid_features),
            np.array(valid_labels),
            'preprocess_validation.p')

        # Now preprocess the test set.
        with open(cifar10_dataset_folder_path + '/test_batch', mode='rb') as file:
            batch = pickle.load(file, encoding='latin1')

        # Reshape to [None, 3, 32, 32], then transpose to [None, 32, 32, 3].
        test_features = batch['data'].reshape((len(batch['data']), 3, 32, 32)).transpose(0, 2, 3, 1)
        test_labels = batch['labels']

        # Preprocess and save all test data.
        _preprocess_and_save(
            normalize,
            one_hot_encode,
            np.array(test_features),
            np.array(test_labels),
            'preprocess_test.p')
    
    
    def batch_features_labels(features, labels, batch_size):
        """
        Lazily yield aligned ``(features, labels)`` slices of at most ``batch_size`` items.

        The last slice is shorter when the length is not a multiple of
        ``batch_size``. Because this is a generator, the length-equality assert
        runs when iteration starts, not at call time.
        """
        assert len(features) == len(labels)
        total = len(features)
        for offset in range(0, total, batch_size):
            window = slice(offset, min(offset + batch_size, total))
            yield features[window], labels[window]
    
    
    def load_preprocess_training_batch(batch_id, batch_size, filepath='../datas/cifar10'):
        """
        Load a preprocessed training batch and return it in mini-batches.

        Args:
            batch_id: number of the ``preprocess_batch_<batch_id>.p`` file to read.
            batch_size: maximum number of samples per yielded mini-batch.
            filepath: directory containing the preprocessed pickle files. Defaults
                to the previously hard-coded ``'../datas/cifar10'``.

        Returns:
            A generator of ``(features, labels)`` tuples of ``batch_size`` items
            or fewer.
        """
        filename = os.path.join(filepath, 'preprocess_batch_' + str(batch_id) + '.p')
        # Close the file deterministically instead of leaking the handle.
        # NOTE: pickle.load can execute arbitrary code; only load files this
        # project produced itself.
        with open(filename, mode='rb') as handle:
            features, labels = pickle.load(handle)

        return batch_features_labels(features, labels, batch_size)
    
    
    def display_image_predictions(features, labels, predictions):
        """
        Plot each image next to a horizontal bar chart of its top predicted classes.

        Args:
            features: sequence of images, each passed to ``imshow``.
            labels: one-hot encoded true labels, one row per image.
            predictions: object with parallel ``indices`` and ``values`` sequences
                (one row per image) holding predicted class ids and their scores.
                Presumably the output of a top-k op, ranked best-first; confirm
                against the caller.
        """
        label_names = _load_label_names()
        n_classes = len(label_names)
        label_binarizer = LabelBinarizer()
        label_binarizer.fit(range(n_classes))
        # Turn one-hot rows back into integer class ids.
        label_ids = label_binarizer.inverse_transform(np.array(labels))

        # Keep at least the original 4-row layout, but grow with the input so
        # more than 4 images no longer raises IndexError.
        n_rows = max(4, len(features))
        fig, axes = plt.subplots(nrows=n_rows, ncols=2)
        fig.tight_layout()
        fig.suptitle('Softmax Predictions', fontsize=20, y=1.1)

        margin = 0.05

        rows = zip(features, label_ids, predictions.indices, predictions.values)
        for image_i, (feature, label_id, pred_indices, pred_values) in enumerate(rows):
            pred_names = [label_names[pred_i] for pred_i in pred_indices]
            correct_name = label_names[label_id]

            # Size the bars by the number of predictions actually supplied rather
            # than a hard-coded 3, so any top-k works.
            n_predictions = len(pred_names)
            ind = np.arange(n_predictions)
            width = (1. - 2. * margin) / n_predictions

            image_ax, bar_ax = axes[image_i]
            image_ax.imshow(feature)
            image_ax.set_title(correct_name)
            image_ax.set_axis_off()

            # barh draws index 0 at the bottom. Reversing puts the first entry
            # (the best prediction, if ranked best-first) at the top.
            bar_ax.barh(ind + margin, pred_values[::-1], width)
            bar_ax.set_yticks(ind + margin)
            bar_ax.set_yticklabels(pred_names[::-1])
            bar_ax.set_xticks([0, 0.5, 1.0])
        plt.show()
  • 相关阅读:
    git push要输入密码问题
    excel换行
    React的diff算法
    https的通信过程
    一道面试题的分析
    Mac将应用拖入Finder工具栏
    React获取组件实例
    Warning: Received `false` for a non-boolean attribute `xxx`.
    warning: React does not recognize the xxx prop on a DOM element
    webpack开发模式和生产模式设置及不同环境脚本执行
  • 原文地址:https://www.cnblogs.com/qianchaomoon/p/12315984.html
Copyright © 2011-2022 走看看