zoukankan      html  css  js  c++  java
  • pytorch加载语音类自定义数据集

      pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合

    • torch.utils.data.Dataset:所有继承他的子类都应该重写  __len()__  , __getitem()__ 这两个方法
      •  __len()__ :返回数据集中数据的数量
      •   __getitem()__ :返回支持下标索引方式获取的一个数据
    • torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....

    第一步

      自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:

    • __len()__:读取数据,返回数据和标签
    • __getitem()__:返回数据集的长度
    from torch.utils.data import Dataset
    
    
    class AudioDataset(Dataset):
        def __init__(self, ...):
            """类的初始化"""
            pass
    
        def __getitem__(self, item):
            """每次怎么读数据,返回数据和标签"""
            return data, label
    
        def __len__(self):
            """返回整个数据集的长度"""
            return total

    注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本

    案例:

      文件目录结构

    • p225
      • ***.wav
      • ***.wav
      • ***.wav
      • ...
    • dataset.py

    目的:读取p225文件夹中的音频数据

     1 class AudioDataset(Dataset):
     2     def __init__(self, data_folder, sr=16000, dimension=8192):
     3         self.data_folder = data_folder
     4         self.sr = sr
     5         self.dim = dimension
     6 
     7         # 获取音频名列表
     8         self.wav_list = []
     9         for root, dirnames, filenames in os.walk(data_folder):
    10             for filename in fnmatch.filter(filenames, "*.wav"):  # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
    11                 self.wav_list.append(os.path.join(root, filename))
    12 
    13     def __getitem__(self, item):
    14         # 读取一个音频文件,返回每个音频数据
    15         filename = self.wav_list[item]
    16         wb_wav, _ = librosa.load(filename, sr=self.sr)
    17 
    18         # 取 帧
    19         if len(wb_wav) >= self.dim:
    20             max_audio_start = len(wb_wav) - self.dim
    21             audio_start = np.random.randint(0, max_audio_start)
    22             wb_wav = wb_wav[audio_start: audio_start + self.dim]
    23         else:
    24             wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
    25 
    26         return wb_wav, filename
    27 
    28     def __len__(self):
    29         # 音频文件的总数
    30         return len(self.wav_list)

    注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,

    第二步

      实例化 Dataset 对象

    Dataset= AudioDataset("./p225", sr=16000)

    如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作

    # 实例化AudioDataset对象
    train_set = AudioDataset("./p225", sr=16000)
    
    for i, data in enumerate(train_set):
        wb_wav, filname = data
        print(i, wb_wav.shape, filname)
    
        if i == 3:
            break
        # 0 (8192,) ./p225p225_001.wav
        # 1 (8192,) ./p225p225_002.wav
        # 2 (8192,) ./p225p225_003.wav
        # 3 (8192,) ./p225p225_004.wav

    第三步

      如果想要通过batch读取数据,需要使用DataLoader进行包装

    为何要使用DataLoader?

    1. 深度学习的输入是mini_batch形式
    2. 样本加载时候可能需要随机打乱顺序,shuffle操作
    3. 样本加载需要采用多线程

      pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。

    DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)

    参数

    • dataset:加载的数据集(Dataset对象)
    • batch_size每个批次要加载多少个样本(默认值:1)
    • shuffle:每个epoch是否将数据打乱
    • sampler定义从数据集中抽取样本的策略如果指定,则不能指定洗牌。
    • batch_sampler类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
    • num_workers:使用多进程加载的进程数,0代表不使用多线程
    • collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
    • pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
    • drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

    返回:数据加载器

    案例:

    # 实例化AudioDataset对象
    train_set = AudioDataset("./p225", sr=16000)
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
    
    for (i, data) in enumerate(train_loader):
        wav_data, wav_name = data
        print(wav_data.shape)   # torch.Size([8, 8192])
        print(i, wav_name)
        # ('./p225\p225_293.wav', './p225\p225_156.wav', './p225\p225_277.wav', './p225\p225_210.wav',
        # './p225\p225_126.wav', './p225\p225_021.wav', './p225\p225_257.wav', './p225\p225_192.wav')

    我们来吃几个栗子消化一下:

    栗子1

      这个例子就是本文一直举例的,栗子1只是合并了一下而已

      文件目录结构

    • p225
      • ***.wav
      • ***.wav
      • ***.wav
      • ...
    • dataset.py

    目的:读取p225文件夹中的音频数据

     1 import fnmatch
     2 import os
     3 import librosa
     4 import numpy as np
     5 from torch.utils.data import Dataset
     6 from torch.utils.data import DataLoader
     7 
     8 
     9 class Aduio_DataLoader(Dataset):
    10     def __init__(self, data_folder, sr=16000, dimension=8192):
    11         self.data_folder = data_folder
    12         self.sr = sr
    13         self.dim = dimension
    14 
    15         # 获取音频名列表
    16         self.wav_list = []
    17         for root, dirnames, filenames in os.walk(data_folder):
    18             for filename in fnmatch.filter(filenames, "*.wav"):  # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
    19                 self.wav_list.append(os.path.join(root, filename))
    20 
    21     def __getitem__(self, item):
    22         # 读取一个音频文件,返回每个音频数据
    23         filename = self.wav_list[item]
    24         print(filename)
    25         wb_wav, _ = librosa.load(filename, sr=self.sr)
    26 
    27         # 取 帧
    28         if len(wb_wav) >= self.dim:
    29             max_audio_start = len(wb_wav) - self.dim
    30             audio_start = np.random.randint(0, max_audio_start)
    31             wb_wav = wb_wav[audio_start: audio_start + self.dim]
    32         else:
    33             wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
    34 
    35         return wb_wav, filename
    36 
    37     def __len__(self):
    38         # 音频文件的总数
    39         return len(self.wav_list)
    40 
    41 
    42 train_set = Aduio_DataLoader("./p225", sr=16000)
    43 train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
    44 
    45 
    46 for (i, data) in enumerate(train_loader):
    47     wav_data, wav_name = data
    48     print(wav_data.shape)   # torch.Size([8, 8192])
    49     print(i, wav_name)
    50     # ('./p225\p225_293.wav', './p225\p225_156.wav', './p225\p225_277.wav', './p225\p225_210.wav',
    51     # './p225\p225_126.wav', './p225\p225_021.wav', './p225\p225_257.wav', './p225\p225_192.wav')

    注意事项

    1. 27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
    2. 48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意

    栗子2

      相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。

      我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。

    具体实现代码:

      第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:

      第二步:通过Dataset从h5格式文件中读取数据

    import numpy as np
    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader
    import h5py
    
    def load_h5(h5_path):
        # load training data
        with h5py.File(h5_path, 'r') as hf:
            print('List of arrays in input file:', hf.keys())
            X = np.array(hf.get('data'), dtype=np.float32)
            Y = np.array(hf.get('label'), dtype=np.float32)
        return X, Y
    
    
    class AudioDataset(Dataset):
        """数据加载器"""
        def __init__(self, data_folder):
            self.data_folder = data_folder
            self.X, self.Y = load_h5(data_folder)   # (3392, 8192, 1)
    
        def __getitem__(self, item):
            # 返回一个音频数据
            X = self.X[item]
            Y = self.Y[item]
    
            return X, Y
    
        def __len__(self):
            return len(self.X)
    
    
    train_set = AudioDataset("./speaker225_resample_train.h5")
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)
    
    
    for (i, wav_data) in enumerate(train_loader):
        X, Y = wav_data
        print(i, X.shape)
        # 0 torch.Size([64, 8192, 1])
        # 1 torch.Size([64, 8192, 1])
        # ...

    我尝试在__init__中生成h5文件,但是会导致内存爆炸,就很奇怪,因此我只好分开了,

    参考

    pytorch学习(四)—自定义数据集(讲的比较详细)

     

  • 相关阅读:
    使用ACEGI搭建权限系统:第三部分
    分支在版本树中的应用(使用subversion)
    acegi安全框架使用:第二部分
    错误数据导致java.lang.IllegalArgumentException:Unsupported configuration attributes
    移动中间件和wap网关的比较
    3年后,又回到了.net阵营
    android中listView的几点总结
    oracle相关分布式数据解决方案
    ajax实现用户名存在校验
    使用template method模式简化android列表页面
  • 原文地址:https://www.cnblogs.com/LXP-Never/p/13816254.html
Copyright © 2011-2022 走看看