zoukankan      html  css  js  c++  java
  • transformer代码笔记----pre_process.py

    import os
    import pickle
    from tqdm import tqdm
    from config import wav_folder, transcript_file, pickle_file
    from utils import ensure_folder
    
    
    def get_data(split):
        print('getting {} data...'.format(split)) #对获取的数据名打印
     
        global VOCAB   #定义全局变量
        with open(transcript_file, 'r', encoding='utf-8') as file: #打开文件transcript_file,仅对其读操作,重命名为file
            lines = file.readlines() #逐行读取文件内容
    
        tran_dict = dict() #创建空字典
        for line in lines:  #迭代file文件中的每一行
            tokens = line.split() #将一行的输入进行切分。str.split(str="", num=string.count(str)):str为分隔符,默认空格;num为切分次数,默认全切分
            key = tokens[0]
            trn = ''.join(tokens[1:]) #'_'.join(sequence):将sequence中的元素以'_'连接形成一个新的元素
            tran_dict[key] = trn   # tran_dict: {'BAC0009123': wav1.wav, ...}
    
        samples = []
    
        folder = os.path.join(wav_folder, split)    # data/data_aishell/wav/train  os.path.join():连接路径名,以/连接
        ensure_folder(folder)    # 确保floder是一个目录,如果不存在该路径下的目录就生成一个新的目录
        #os.listdir():以列表的形式提取路径下的文件。os.path.isdir():判断是否存在该文件。最终dirs中以列表形式存储folder路径下的所有文件
        dirs = [os.path.join(folder, d) for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d))]  # data/data_aishell/wav/train/S0003
        for dir in tqdm(dirs): 
        #Tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)。
            files = [f for f in os.listdir(dir) if f.endswith('.wav')]    # [wav1, wav2, .....]
            #endswith() 方法用于判断字符串是否以指定后缀结尾,如果以指定后缀结尾返回 True,否则返回 False。
    
            for f in files:
                wave = os.path.join(dir, f)  # data/data_aishell/wav/train/S0003/wav1.wav
    
                key = f.split('.')[0] #切分f,取第一个元素:wav1
                if key in tran_dict:
                    trn = tran_dict[key]
                    trn = list(trn.strip()) + ['<eos>'] #获取数据,并在每行数据后加结束标志
                    #strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列。
    
                    for token in trn:
                        build_vocab(token)
    
                    trn = [VOCAB[token] for token in trn]
    
                    samples.append({'trn': trn, 'wave': wave}) #append() 方法用于在列表末尾添加新的对象。
    
        print('split: {}, num_files: {}'.format(split, len(samples)))
        return samples
    
    
    def build_vocab(token):
        global VOCAB, IVOCAB
        if not token in VOCAB: #将token及index添加到IVOCAB和VOCAB中
            next_index = len(VOCAB)
            VOCAB[token] = next_index
            IVOCAB[next_index] = token
    
    
    if __name__ == "__main__":
        VOCAB = {'<sos>': 0, '<eos>': 1}
        IVOCAB = {0: '<sos>', 1: '<eos>'}
    
        data = dict()
        data['VOCAB'] = VOCAB
        data['IVOCAB'] = IVOCAB
        data['train'] = get_data('train')
        data['dev'] = get_data('dev')
        data['test'] = get_data('test')
    
        with open(pickle_file, 'wb') as file:
            pickle.dump(data, file)
    
        print('num_train: ' + str(len(data['train'])))
        print('num_dev: ' + str(len(data['dev'])))
        print('num_test: ' + str(len(data['test'])))
        print('vocab_size: ' + str(len(data['VOCAB'])))
  • 相关阅读:
    LOJ2565. 「SDOI2018」旧试题
    位运算
    Arrays.sort()原理
    LinkedList源码解析
    二维数组排序
    数据结构和算法-五大常用算法:贪心算法
    数据结构和算法-五大常用算法:分支限界法
    数据结构和算法-五大常用算法:分治算法
    数据结构和算法-二分查找
    Arrays.copyOf()&Arrays.copyOfRange()&System.arraycopy
  • 原文地址:https://www.cnblogs.com/Uriel-w/p/15426160.html
Copyright © 2011-2022 走看看