zoukankan      html  css  js  c++  java
  • 机器学习-利用pickle加载cifar文件

    首先这里有百度云的数据集供大家下载:(官网太慢了)

    链接:https://pan.baidu.com/s/1G0MxZIGSK_DyZTcuNbxraQ
    提取码:ui51
    复制这段内容后打开百度网盘手机App,操作更方便哦

    然后奉上代码:

    def load_CIFAR10(ROOT, num_batches=5):
        """Load the CIFAR-10 training batches and the test batch.

        Args:
            ROOT: Directory containing the files ``data_batch_1`` ..
                ``data_batch_<num_batches>`` and ``test_batch``.
            num_batches: How many training batch files to read, starting at
                ``data_batch_1``. The default of 5 gives the 50000 training
                samples this loader is meant to return; pass a smaller
                value to reduce memory use.

        Returns:
            A tuple ``(Xtr, Ytr, Xte, Yte)``: the concatenated training
            images and labels, followed by the test images and labels, in
            the formats produced by ``load_CIFAR_batch``.

        Raises:
            ValueError: If ``num_batches`` is less than 1.
        """
        if num_batches < 1:
            raise ValueError('num_batches must be at least 1')
        xs = []
        ys = []
        for b in range(1, num_batches + 1):
            f = os.path.join(ROOT, 'data_batch_%d' % (b,))
            X, Y = load_CIFAR_batch(f)
            xs.append(X)
            ys.append(Y)
        # Stack the per-batch arrays along the first (sample) axis.
        Xtr = np.concatenate(xs)
        Ytr = np.concatenate(ys)
        Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
        return Xtr, Ytr, Xte, Yte

    找到cifar文件夹下面的二进制文件:

    然后逐个读取每个batch文件(每个文件都是一个用pickle序列化的字典):

    def load_CIFAR_batch(filename):
        """Read a single pickled CIFAR batch file into NumPy arrays.

        Args:
            filename: Path to a batch file holding a pickled dict with a
                ``'data'`` array (one flattened 3x32x32 image per row) and a
                ``'labels'`` sequence.

        Returns:
            A tuple ``(X, Y)``. ``X`` is a float array of shape
            ``(N, 32, 32, 3)`` (channels moved to the last axis) and ``Y``
            is an array built from the labels.
        """
        # NOTE: unpickling can execute arbitrary code; only load batch files
        # that come from a trusted source.
        with open(filename, 'rb') as f:
            datadict = p.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        # -1 lets the sample count follow the file instead of assuming
        # exactly 10000 rows; transpose turns (N, C, H, W) into (N, H, W, C).
        X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y

    测试:

    import numpy as np
    
    # Load the CIFAR-10 dataset.
    # The directory separators were missing from this path; forward slashes
    # are used because they work on both Windows and POSIX systems.
    cifar10_dir = 'data/cifar10/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Print the shapes of the loaded arrays as a sanity check.
    print('训练数据的形状: ', X_train.shape)
    print('训练集标签的形状: ', y_train.shape)
    print('测试数据的形状: ', X_test.shape)
    # This line prints the test labels, so label it as such.
    print('测试集标签的形状: ', y_test.shape)
    import pickle as p
    import os
    
    
    def load_CIFAR_batch(filename):
        """Load one pickled CIFAR batch file and return images and labels.

        Args:
            filename: Path to a batch file containing a pickled dict with a
                ``'data'`` array (each row a flattened 3x32x32 image) and a
                ``'labels'`` sequence.

        Returns:
            A tuple ``(X, Y)``: ``X`` is a float array shaped
            ``(N, 32, 32, 3)`` and ``Y`` is an array built from the labels.
        """
        # NOTE: pickle may run arbitrary code while loading; use this only
        # on batch files obtained from a trusted source.
        with open(filename, 'rb') as f:
            datadict = p.load(f, encoding='latin1')
        X = datadict['data']
        Y = datadict['labels']
        # Infer the number of samples (-1) rather than hard-coding 10000,
        # then move the channel axis last: (N, C, H, W) -> (N, H, W, C).
        X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
        Y = np.array(Y)
        return X, Y
    
    
    def load_CIFAR10(ROOT, num_batches=5):
        """Load the CIFAR-10 training set and test set from a directory.

        Args:
            ROOT: Directory holding ``data_batch_1`` ..
                ``data_batch_<num_batches>`` and ``test_batch``.
            num_batches: Number of training batch files to read, beginning
                with ``data_batch_1``. Defaults to 5 so that the full
                50000-sample training set is returned; use a smaller value
                when memory is limited.

        Returns:
            A tuple ``(Xtr, Ytr, Xte, Yte)`` with the concatenated training
            images and labels followed by the test images and labels, as
            produced by ``load_CIFAR_batch``.

        Raises:
            ValueError: If ``num_batches`` is less than 1.
        """
        if num_batches < 1:
            raise ValueError('num_batches must be at least 1')
        xs = []
        ys = []
        for b in range(1, num_batches + 1):
            f = os.path.join(ROOT, 'data_batch_%d' % (b,))
            X, Y = load_CIFAR_batch(f)
            xs.append(X)
            ys.append(Y)
        # Join the batches along the first (sample) axis.
        Xtr = np.concatenate(xs)
        Ytr = np.concatenate(ys)
        Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
        return Xtr, Ytr, Xte, Yte
    
    if __name__ == '__main__':
        import numpy as np


        # Load the CIFAR-10 dataset.
        # The directory separators were missing from this path; forward
        # slashes are used because they work on both Windows and POSIX.
        cifar10_dir = 'data/cifar10/cifar-10-batches-py'
        X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

        # Print the shapes of the loaded arrays as a sanity check.
        print('Training data shape: ', X_train.shape)
        print('Training labels shape: ', y_train.shape)
        print('Test data shape: ', X_test.shape)
        print('Test labels shape: ', y_test.shape)
  • 相关阅读:
    c# influxDB
    ASP.NET Web简单开发
    vue3 最长递增子序列 diff优化
    【转】Android Kotlin协程 coroutines 理解
    基于混合模型的语音降噪实践
    语音降噪论文“A Hybrid Approach for Speech Enhancement Using MoG Model and Neural Network Phoneme Classifier”的研读
    基于sinc的音频重采样(二):实现
    基于sinc的音频重采样(一):原理
    深度学习中神经网络模型的量化
    嵌入式设备上卷积神经网络推理时memory的优化
  • 原文地址:https://www.cnblogs.com/TimVerion/p/11226189.html
Copyright © 2011-2022 走看看