  • Datasets (benchmarks) and parsing common datasets (CIFAR-10)

    What is the class of this image ?

    The following common datasets are mainly used to measure the classification accuracy of algorithms:

    • MNIST, CIFAR-10, CIFAR-100, STL-10
    • SVHN, ILSVRC2012 task 1

    1. cifar-10

    CIFAR-10 and CIFAR-100 datasets

    • cifar-10-batches-py (Python interface)

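      import os
      import pickle
      import numpy as np
      
      def load_CIFAR10_batch(filename):
          """Load one CIFAR-10 batch file; returns X (N, 32, 32, 3) float32 and y (N,)."""
          with open(filename, 'rb') as f:
              # The batches were pickled under Python 2; encoding='latin1' lets Python 3 read them
              data = pickle.load(f, encoding='latin1')
          X = data['data']      # uint8 array, shape (N, 3072)
          y = data['labels']    # list of N ints in 0..9
          # (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3), i.e. NHWC
          X = X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).astype(np.float32)
          y = np.array(y)
          return X, y
      
      
      def load_CIFAR10(root):
          """Load the five training batches and the test batch found in root."""
          xs, ys = [], []
          for n in range(1, 6):
              filename = os.path.join(root, 'data_batch_{}'.format(n))
              X, y = load_CIFAR10_batch(filename)
              xs.append(X)
              ys.append(y)
          Xtr = np.concatenate(xs)    # (50000, 32, 32, 3)
          Ytr = np.concatenate(ys)    # (50000,)
          Xte, Yte = load_CIFAR10_batch(os.path.join(root, 'test_batch'))
          return Xtr, Ytr, Xte, Yte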
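
      According to the CIFAR-10 dataset page, each row of data['data'] holds one image as 3072 uint8 values: 1024 red, then 1024 green, then 1024 blue, with each channel plane stored row-major. The sketch below uses synthetic random rows instead of the real files. It checks that reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1) places every value at the expected (row, column, channel) position. The helper pixel_from_flat_row is hypothetical and exists only for this check.

```python
import numpy as np

def pixel_from_flat_row(row, r, c, ch):
    """Hypothetical helper: read pixel (r, c) of channel ch from a flat 3072-value row.

    Layout assumed from the CIFAR-10 page: three 1024-value channel planes (R, G, B),
    each stored row-major with 32 values per image row.
    """
    return row[ch * 1024 + r * 32 + c]

# Two synthetic random images standing in for the uint8 `data` array of a batch file
rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(2, 3072), dtype=np.uint8)

# Same conversion as load_CIFAR10_batch: (N, 3072) -> (N, 3, 32, 32) -> (N, 32, 32, 3)
X = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
print(X.shape)  # (2, 32, 32, 3)

# Compare every element of the NHWC array with the flat layout
ok = all(
    X[n, r, c, ch] == pixel_from_flat_row(data[n], r, c, ch)
    for n in range(2) for r in range(32) for c in range(32) for ch in range(3)
)
print(ok)  # True
```

      The transpose to NHWC is a layout choice. Keeping the (N, 3, 32, 32) array without the transpose gives the NCHW layout that some frameworks expect.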

      The metadata file that describes the data (batches.meta) can also be loaded with pickle.load. The result is again a dict:

      with open('batches.meta', 'rb') as f:
          data = pickle.load(f, encoding='latin1')
      print(data)
      
      {'label_names': ['airplane',
        'automobile',
        'bird',
        'cat',
        'deer',
        'dog',
        'frog',
        'horse',
        'ship',
        'truck'],
       'num_cases_per_batch': 10000,
       'num_vis': 3072}
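
      label_names is indexed by the integer labels 0-9. That means the y arrays returned by load_CIFAR10 can be turned into class names with a single indexing step. The sketch below is minimal and makes some assumptions. It hard-codes a meta dict with the same contents as the printed output above, so it runs without the file, and it uses a made-up label array. The function labels_to_names is hypothetical.

```python
import numpy as np

# Stand-in for the dict loaded from batches.meta (same contents as printed above)
meta = {
    'label_names': ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck'],
    'num_cases_per_batch': 10000,
    'num_vis': 3072,
}

def labels_to_names(y, meta):
    """Hypothetical helper: map integer labels (0-9) to names via meta['label_names']."""
    names = np.array(meta['label_names'])
    return names[np.asarray(y)]

y = np.array([3, 8, 8, 0, 6])  # made-up labels for illustration
print(labels_to_names(y, meta).tolist())
# ['cat', 'ship', 'ship', 'airplane', 'frog']

# num_vis is the length of one image row: 32 * 32 pixels * 3 channels
assert meta['num_vis'] == 32 * 32 * 3
```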
    • cifar-10-batches-mat (MATLAB interface)

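      The most convenient option is MATLAB's ready-made helper, helperCIFAR10Data.download / helperCIFAR10Data.load. To view its implementation, run edit helperCIFAR10Data. The .mat batches can also be loaded by hand: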

      function [train_x, train_y, test_x, test_y] = load_cifar(filepath)
      
          train_x = []; train_y = [];
          for i = 1:5
              filename = fullfile(filepath, sprintf('data_batch_%d.mat', i));
              [batch_train, batch_labels] = load_batch_as_4d_tensor(filename, true);
              train_x = cat(4, train_x, batch_train);
              train_y = [train_y; batch_labels];
          end
          filename = fullfile(filepath, 'test_batch.mat');
          [test_x, test_y] = load_batch_as_4d_tensor(filename, true);
      end
      
      function [train_x, train_y] = load_batch_as_4d_tensor(filename, to_categorical)
      % train_x is a 4-D tensor of size 32 x 32 x 3 x num
          if ~exist('to_categorical', 'var') || isempty(to_categorical)
              to_categorical = false;
          end
          S = load(filename);                             % fields: data (num x 3072), labels (num x 1)
          train_x = reshape(S.data', 32, 32, 3, []);
          train_x = permute(train_x, [2, 1, 3, 4]);       % swap the first and second dimensions (rows/columns)
          train_y = S.labels;
          if to_categorical
              metafile = fullfile(fileparts(filename), 'batches.meta.mat');
              meta = load(metafile);                      % field: label_names (10 x 1 cell)
              train_y = categorical(train_y, 0:9, meta.label_names);
          end
      end
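
      MATLAB's reshape fills arrays in column-major order, while each CIFAR-10 row is stored row-major. That is why load_batch_as_4d_tensor needs permute(..., [2, 1, 3, 4]). The sketch below imitates the two MATLAB steps in NumPy, using order='F' for column-major order, on synthetic data. It then checks that the result matches the Python loader's layout rearranged to MATLAB's H x W x C x N order. It is an illustration only and does not read the .mat files.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
data = rng.integers(0, 256, size=(n, 3072), dtype=np.uint8)  # synthetic batch, n x 3072

# MATLAB: reshape(data', 32, 32, 3, []) fills in column-major (Fortran) order,
# so afterwards axis 0 indexes image columns and axis 1 indexes image rows
x_mat = np.reshape(data.T, (32, 32, 3, n), order='F')
# MATLAB: permute(x, [2, 1, 3, 4]) swaps the first two axes back to (row, column)
x_mat = x_mat.transpose(1, 0, 2, 3)

# Python loader layout (N, H, W, C), rearranged to MATLAB's (H, W, C, N)
x_py = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
print(np.array_equal(x_mat, x_py.transpose(1, 2, 3, 0)))  # True
```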
  • Original post: https://www.cnblogs.com/mtcnn/p/9422049.html