zoukankan      html  css  js  c++  java
  • 自然语言处理(三)——PTB数据的batching方法

    参考书

    《TensorFlow:实战Google深度学习框架》(第2版)

    从文本文件中读取数据,并按照下面介绍的方案将数据整理成batch。

    方法是:先将整个文档切分成若干连续段落,再让batch中的每一个位置负责其中一段。

    #!/usr/bin/env python
    # -*- coding: UTF-8 -*-
    # coding=utf-8 
    
    """
    @author: Li Tian
    @contact: 694317828@qq.com
    @software: pycharm
    @file: word_deal3.py
    @time: 2019/2/23 16:36
    @desc: 从文本文件中读取数据,并按照下面介绍的方案将数据整理成batch。
            方法是:先将整个文档切分成若干连续段落,再让batch中的每一个位置负责其中一段。
    """
    
    import numpy as np
    import tensorflow as tf
    
    
    # 使用单词编号表示的训练数据
    TRAIN_DATA = './ptb.train'
    TRAIN_BATCH_SIZE = 20
    TRAIN_NUM_STEP = 35
    
    
    # 从文件中读取数据,并返回包含单词编号的数组
    def read_data(file_path):
        with open(file_path, "r") as fin:
            # 将整个文档读进一个长字符串
            id_string = ' '.join([line.strip() for line in fin.readlines()])
        # 将读取的单词编号转为整数
        id_list = [int(w) for w in id_string.split()]
        return id_list
    
    
    def make_batches(id_list, batch_size, num_step):
        # batch_size: 一个batch中样本的数量
        # num_batches:batch的个数
        # num_step: 一个样本的序列长度
        # 计算总的batch数量。每个batch包含的单词数量是batch_size * num_step
        num_batches = (len(id_list) - 1) // (batch_size * num_step)
    
        # 将数据整理成一个维度为[batch_size, num_batches*num_step]的二维数组
        data = np.array(id_list[: num_batches * batch_size * num_step])
        data = np.reshape(data, [batch_size, num_batches * num_step])
    
        # 沿着第二个维度将数据切分成num_batches个batch,存入一个数组。
        data_batches = np.split(data, num_batches, axis=1)
    
        # 重复上述操作,但是每个位置向右移动一位,这里得到的是RNN每一步输出所需要预测的下一个单词
        label = np.array(id_list[1: num_batches * batch_size * num_step + 1])
        label = np.reshape(label, [batch_size, num_batches * num_step])
        label_batches = np.split(label, num_batches, axis=1)
        # 返回一个长度为num_batches的数组,其中每一项包括一个data矩阵和一个label矩阵
        print(len(id_list))
        print(num_batches * batch_size * num_step)
        return list(zip(data_batches, label_batches))
    
    
    def main():
        train_batches = make_batches(read_data(TRAIN_DATA), TRAIN_BATCH_SIZE, TRAIN_NUM_STEP)
        # 在这里插入模型训练的代码。训练代码将在后面提到。
        for i in train_batches:
            print(i)
    
    
    if __name__ == '__main__':
        main()

    运行结果:

  • 相关阅读:
    【Go语言入门系列】Go语言工作目录介绍及命令工具的使用
    【保姆级教程】手把手教你进行Go语言环境安装及相关VSCode配置
    【Go语言入门系列】(九)写这些就是为了搞懂怎么用接口
    【Go语言入门系列】(八)Go语言是不是面向对象语言?
    【Go语言入门系列】(七)如何使用Go的方法?
    趣解计算机网络(一)之入门概念介绍
    Go语言入门系列(六)之再探函数
    redis数据类型&操作命令
    从Linux源码看Socket(TCP)的accept
    从Linux源码看TIME_WAIT状态的持续时间
  • 原文地址:https://www.cnblogs.com/lyjun/p/10423536.html
Copyright © 2011-2022 走看看