  • Transformer code notes: train.py

    import numpy as np
    import torch
    # from torch.utils.tensorboard import SummaryWriter
    import torch.nn as nn
    import argparse
    from tqdm import tqdm
    
    from config import device, print_freq, vocab_size, sos_id, eos_id
    from data_gen import AiShellDataset, pad_collate
    from transformer.decoder import Decoder
    from transformer.encoder import Encoder
    from transformer.loss import cal_performance
    from transformer.optimizer import TransformerOptimizer
    from transformer.transformer import Transformer
    from utils import parse_args, save_checkpoint, AverageMeter, get_logger
    
    
    def train_net(args):
        torch.manual_seed(7)  # fix the random seeds so runs are reproducible
        np.random.seed(7)
        checkpoint = args.checkpoint
        start_epoch = 0
        best_loss = float('inf')
        # writer = SummaryWriter()
        epochs_since_improvement = 0
    
        # Initialize / load checkpoint
        if checkpoint is None:  # fresh run: no checkpoint to resume from
            # model
            encoder = Encoder(args.d_input * args.LFR_m, args.n_layers_enc, args.n_head,
                              args.d_k, args.d_v, args.d_model, args.d_inner,
                              dropout=args.dropout, pe_maxlen=args.pe_maxlen)
            decoder = Decoder(sos_id, eos_id, vocab_size,
                              args.d_word_vec, args.n_layers_dec, args.n_head,
                              args.d_k, args.d_v, args.d_model, args.d_inner,
                              dropout=args.dropout,
                              tgt_emb_prj_weight_sharing=args.tgt_emb_prj_weight_sharing,
                              pe_maxlen=args.pe_maxlen)
            model = Transformer(encoder, decoder)
            # print(model)
            # model = nn.DataParallel(model)
    
            # optimizer
            optimizer = TransformerOptimizer(
                torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-09))
            # model.parameters(): the iterable of parameters to optimize
            # lr: the learning rate (Adam's own default is 1e-3; args.lr is passed here)
            # betas: coefficients for the running averages of the gradient and its square;
            #        (0.9, 0.98) follows the Transformer paper, not Adam's default (0.9, 0.999)
            # eps: term added to the denominator for numerical stability
            #      (1e-09 here, slightly smaller than Adam's default of 1e-08)
            # weight_decay (unset here): optional weight decay (L2 penalty)
    
        else:
            checkpoint = torch.load(checkpoint)
            start_epoch = checkpoint['epoch'] + 1
            epochs_since_improvement = checkpoint['epochs_since_improvement']
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
    
        logger = get_logger()  # set up logging
    
        # Move to GPU, if available
        model = model.to(device)
    
        # Custom dataloaders
        train_dataset = AiShellDataset(args, 'train')  # load the 'train' split and preprocess the wav data
        # Batch and pad the data. pin_memory=True keeps batches in page-locked host
        # memory (never swapped out), which speeds up host-to-GPU copies; shuffle=True
        # reshuffles every epoch; num_workers is the number of loader processes, and
        # more workers prepare batches faster at the cost of extra CPU load.
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=pad_collate,
                                                   pin_memory=True, shuffle=True, num_workers=args.num_workers)
        valid_dataset = AiShellDataset(args, 'dev')
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, collate_fn=pad_collate,
                                                   pin_memory=True, shuffle=False, num_workers=args.num_workers)
    
        # Epochs
        for epoch in range(start_epoch, args.epochs):
            # One epoch's training
            train_loss = train(train_loader=train_loader,
                               model=model,
                               optimizer=optimizer,
                               epoch=epoch,
                               logger=logger)
            # writer.add_scalar('model/train_loss', train_loss, epoch)
    
            lr = optimizer.lr  # current learning rate
            print('\nLearning rate: {}'.format(lr))
            # writer.add_scalar('model/learning_rate', lr, epoch)
            step_num = optimizer.step_num  # number of optimizer steps taken so far
            print('Step num: {}\n'.format(step_num))
    
            # One epoch's validation
            valid_loss = valid(valid_loader=valid_loader,
                               model=model,
                               logger=logger)  # validation needs no optimizer
            # writer.add_scalar('model/valid_loss', valid_loss, epoch)
    
            # Check if there was an improvement
            is_best = valid_loss < best_loss  # True when this epoch beats the best loss so far
            best_loss = min(valid_loss, best_loss)  # keep the lowest validation loss seen
            if not is_best:  # no improvement this epoch, so bump the counter
                epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
            else:
                epochs_since_improvement = 0
    
            # Save checkpoint
            # persist everything needed to resume; is_best flags the lowest-loss model
            save_checkpoint(epoch, epochs_since_improvement, model, optimizer, best_loss, is_best)
    
    
    def train(train_loader, model, optimizer, epoch, logger):  # run one training epoch
        model.train()  # train mode: dropout/batchnorm active, parameters get updated
    
        losses = AverageMeter()  # running average of the training loss
    
        # Batches
        for i, data in enumerate(train_loader):  # each batch: inputs, targets, input lengths
            # Move to GPU, if available
            padded_input, padded_target, input_lengths = data
            padded_input = padded_input.to(device)  # move the batch onto the device
            padded_target = padded_target.to(device)
            input_lengths = input_lengths.to(device)
    
            # Forward prop.
            pred, gold = model(padded_input, input_lengths, padded_target)  # forward pass: predictions and gold targets
            loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing)
            # returns the loss and the number of correctly predicted tokens; label
            # smoothing regularizes the targets to discourage overconfident outputs.
            # Note: args here is the module-level global set in main().
    
            # Back prop.
            optimizer.zero_grad()  # clear accumulated gradients
            loss.backward()  # backpropagate
    
            # Update weights
            optimizer.step()  # update the parameters (and the scheduled learning rate)
    
            # Keep track of metrics
            losses.update(loss.item())  # fold this batch's loss into the running average
    
            # Print status
            if i % print_freq == 0:  # with the default print_freq = 100, log every 100 batches
                logger.info('Epoch: [{0}][{1}/{2}]\t'
                            'Loss {loss.val:.5f} ({loss.avg:.5f})'.format(epoch, i, len(train_loader), loss=losses))
    
        return losses.avg  # mean loss over the epoch
    
    
    def valid(valid_loader, model, logger):  # one pass over the validation set
        model.eval()  # eval mode: dropout disabled, no parameter updates
    
        losses = AverageMeter() 
    
        # Batches
        for data in tqdm(valid_loader):
            # Move to GPU, if available
            padded_input, padded_target, input_lengths = data
            padded_input = padded_input.to(device)
            padded_target = padded_target.to(device)
            input_lengths = input_lengths.to(device)
    
            with torch.no_grad():
                # Forward prop.
                pred, gold = model(padded_input, input_lengths, padded_target)
                loss, n_correct = cal_performance(pred, gold, smoothing=args.label_smoothing)
    
            # Keep track of metrics
            losses.update(loss.item())
    
        # Print status
        logger.info('\nValidation Loss {loss.val:.5f} ({loss.avg:.5f})\n'.format(loss=losses))
    
        return losses.avg
    
    
    def main():
        global args
        args = parse_args()
        train_net(args)
    
    
    if __name__ == '__main__':
        main()
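
  • Notes: minimal sketches of the imported helpers

    The bookkeeping in train.py lives in helpers imported at the top: TransformerOptimizer, AverageMeter, pad_collate, cal_performance, and save_checkpoint. The sketches below are reconstructions inferred only from how train.py calls them; every name, default value, and file path inside them is an assumption, not the project's actual code.

    TransformerOptimizer: train_net() reads optimizer.lr and optimizer.step_num and calls zero_grad()/step(), which matches a Noam-style warmup wrapper around Adam as in "Attention Is All You Need". A minimal sketch, with illustrative d_model and warmup_steps values:

        class TransformerOptimizerSketch:
            """Wraps an Adam instance and rescales its learning rate on every step."""

            def __init__(self, optimizer, d_model=512, warmup_steps=4000, k=1.0):
                self.optimizer = optimizer
                self.d_model = d_model
                self.warmup_steps = warmup_steps
                self.k = k
                self.step_num = 0  # read by train_net() for logging
                self.lr = 0.0      # read by train_net() for logging

            def zero_grad(self):
                self.optimizer.zero_grad()

            def step(self):
                # lr = k * d_model^-0.5 * min(n^-0.5, n * warmup^-1.5):
                # linear warmup for warmup_steps steps, then inverse-sqrt decay
                self.step_num += 1
                self.lr = self.k * self.d_model ** -0.5 * min(
                    self.step_num ** -0.5,
                    self.step_num * self.warmup_steps ** -1.5)
                for group in self.optimizer.param_groups:
                    group['lr'] = self.lr
                self.optimizer.step()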
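
    AverageMeter: train() and valid() call losses.update(loss.item()) and read losses.val / losses.avg, the standard running-average meter:

        class AverageMeterSketch:
            """Tracks the latest value (.val) and the running mean (.avg)."""

            def __init__(self):
                self.val = 0.0
                self.sum = 0.0
                self.count = 0
                self.avg = 0.0

            def update(self, val, n=1):
                self.val = val
                self.sum += val * n
                self.count += n
                self.avg = self.sum / self.count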
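
    pad_collate: the DataLoader's collate_fn has to turn a list of variable-length (features, token ids) pairs into the (padded_input, padded_target, input_lengths) triple that both loops unpack. A sketch, assuming zero-padded features and targets padded with an IGNORE_ID sentinel:

        import torch

        IGNORE_ID = -1  # assumed target-padding id, shared with cal_performance below

        def pad_collate_sketch(batch):
            """batch: list of (features FloatTensor [T, D], tokens LongTensor [U])."""
            input_lengths = torch.tensor([f.size(0) for f, _ in batch], dtype=torch.long)
            max_t = int(input_lengths.max())
            max_u = max(t.size(0) for _, t in batch)
            dim = batch[0][0].size(1)
            padded_input = torch.zeros(len(batch), max_t, dim)
            padded_target = torch.full((len(batch), max_u), IGNORE_ID, dtype=torch.long)
            for i, (feat, tgt) in enumerate(batch):
                padded_input[i, :feat.size(0)] = feat  # zero-pad features on the right
                padded_target[i, :tgt.size(0)] = tgt   # pad targets with IGNORE_ID
            return padded_input, padded_target, input_lengths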
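
    cal_performance: returns the loss together with the number of correctly predicted tokens, with optional label smoothing. A sketch over flattened (N, vocab) logits and (N,) gold ids that skips padded positions:

        import torch
        import torch.nn.functional as F

        IGNORE_ID = -1  # same assumed padding id as in the pad_collate sketch

        def cal_performance_sketch(pred, gold, smoothing=0.1):
            """pred: (N, vocab) logits; gold: (N,) target ids; returns (loss, n_correct)."""
            non_pad = gold.ne(IGNORE_ID)
            if smoothing > 0.0:
                # spread `smoothing` probability mass over the wrong classes so the
                # model is never trained toward fully confident one-hot targets
                n_class = pred.size(1)
                target = torch.full_like(pred, smoothing / (n_class - 1))
                target.scatter_(1, gold.clamp(min=0).unsqueeze(1), 1.0 - smoothing)
                log_prob = F.log_softmax(pred, dim=1)
                loss = -(target * log_prob).sum(dim=1).masked_select(non_pad).sum()
            else:
                loss = F.cross_entropy(pred, gold, ignore_index=IGNORE_ID, reduction='sum')
            n_correct = pred.argmax(dim=1).eq(gold).masked_select(non_pad).sum().item()
            return loss, n_correct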
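
    save_checkpoint: the resume branch reads checkpoint['epoch'], ['epochs_since_improvement'], ['model'], and ['optimizer'] back with torch.load, so the saver must persist at least those keys. It evidently stores the whole model and optimizer objects rather than state_dicts, which is why the resume branch can use them directly; the file names here are illustrative:

        import torch

        def save_checkpoint_sketch(epoch, epochs_since_improvement, model, optimizer,
                                   best_loss, is_best):
            state = {'epoch': epoch,
                     'epochs_since_improvement': epochs_since_improvement,
                     'best_loss': best_loss,
                     'model': model,
                     'optimizer': optimizer}
            torch.save(state, 'checkpoint.tar')
            if is_best:
                torch.save(state, 'BEST_checkpoint.tar')  # keep the best model separately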
  • Original post: https://www.cnblogs.com/Uriel-w/p/15426144.html