zoukankan      html  css  js  c++  java
  • PyTorch之数据处理

    使用TensorDataset和DataLoader来简化数据的分批、打乱与迭代加载

     
    from torch.utils.data import TensorDataset
    from torch.utils.data import DataLoader
    # Pair inputs with targets so that indexing a dataset yields (x, y).
    # x_train / y_train / x_valid / y_valid and the batch size `bs` are not
    # defined in this snippet -- presumably created earlier (TODO confirm).
    train_ds = TensorDataset(x_train, y_train)
    valid_ds = TensorDataset(x_valid, y_valid)
    # Only the training loader reshuffles each pass; the validation loader
    # uses batches twice as large, presumably because no gradients are kept
    # during validation.
    train_dl = DataLoader(train_ds, shuffle=True, batch_size=bs)
    valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
     
    def get_data(train_ds, valid_ds, bs):
        """Wrap the training and validation datasets in DataLoaders.

        The training loader shuffles and uses batch size ``bs``. The
        validation loader keeps dataset order and uses batch size ``bs * 2``.

        Returns:
            A ``(train_loader, valid_loader)`` tuple.
        """
        train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True)
        valid_loader = DataLoader(valid_ds, batch_size=bs * 2)
        return train_loader, valid_loader
     
     
     
    • 训练模型时一般先调用model.train(),这样Dropout会生效,Batch Normalization会使用当前batch的统计量(并更新running mean/var)
    • 验证或测试时一般调用model.eval(),这样Dropout会被关闭,Batch Normalization不再使用当前batch的统计量,而是改用训练期间累积的running mean/var(BN层本身仍然生效)
    import numpy as np
    def fit(steps, model, loss_func, opt, train_dl, valid_dl):
        """Train ``model`` for ``steps`` passes over ``train_dl``, validating after each.

        On each pass the model is put in training mode and ``loss_batch`` is
        called with ``opt`` once per training batch. The model is then switched
        to evaluation mode, and the batch-size-weighted mean of the validation
        losses is computed without gradient tracking and printed.

        Args:
            steps: number of full passes over ``train_dl``.
            model: module to train; toggled with ``train()`` / ``eval()``.
            loss_func: callable ``loss_func(predictions, targets)`` returning a
                scalar loss tensor.
            opt: optimizer passed to ``loss_batch`` for every training batch.
            train_dl: iterable of ``(xb, yb)`` training batches.
            valid_dl: iterable of ``(xb, yb)`` validation batches.

        Raises:
            ValueError: if ``valid_dl`` yields no batches, since no validation
                loss can then be computed.
        """
        # `from torch import optim` / `from torch.utils.data import ...` do not
        # bind the name `torch`, which `torch.no_grad()` below requires.
        import torch

        for step in range(steps):
            model.train()
            for xb, yb in train_dl:
                loss_batch(model, loss_func, xb, yb, opt)

            model.eval()
            with torch.no_grad():
                results = [loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            if not results:
                raise ValueError(
                    "valid_dl yielded no batches; cannot compute validation loss"
                )
            losses, nums = zip(*results)
            # Weight each batch loss by its batch size so that a short final
            # batch does not skew the average (assumes loss_func returns a
            # per-batch mean -- TODO confirm for custom losses).
            val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
            print('当前step:'+str(step), '验证集损失:'+str(val_loss))
     
     
    from torch import optim
    def get_model(lr=0.001):
        """Create a fresh ``Mnist_NN`` and an SGD optimizer over its parameters.

        Args:
            lr: SGD learning rate. The default of 0.001 matches the value that
                was previously hard-coded.

        Returns:
            A ``(model, optimizer)`` tuple.
        """
        # Mnist_NN is not defined in this snippet -- presumably declared
        # elsewhere on the page (TODO confirm).
        model = Mnist_NN()
        return model, optim.SGD(model.parameters(), lr=lr)
     
     
    def loss_batch(model, loss_func, xb, yb, opt=None):
        """Compute the loss on one batch, optionally taking an optimizer step.

        When ``opt`` is given, the loss is backpropagated, ``opt.step()`` is
        applied, and the gradients are cleared with ``opt.zero_grad()``.
        Without ``opt``, the loss is only evaluated.

        Returns:
            ``(loss_value, batch_size)``, where ``loss_value`` is a Python
            number and ``batch_size`` is ``len(xb)``.
        """
        predictions = model(xb)
        batch_loss = loss_func(predictions, yb)
        if opt is None:
            return batch_loss.item(), len(xb)
        batch_loss.backward()
        opt.step()
        opt.zero_grad()
        return batch_loss.item(), len(xb)
     
     
     

    三行搞定!

    # Build the loaders and a fresh model/optimizer pair, then run 25 passes
    # over the training data. `bs` and `loss_func` are not defined in this
    # snippet -- presumably set earlier (TODO confirm).
    train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
    model, opt = get_model()
    fit(
        steps=25,
        model=model,
        loss_func=loss_func,
        opt=opt,
        train_dl=train_dl,
        valid_dl=valid_dl,
    )
     
     
     
     
     
  • 相关阅读:
    二维数组和二维指针作为函数的参数
    我所理解的tensorflow
    新篇:A New Start
    3NF(Canonical Cover and Decomposition)
    SQL: group by + having
    hihoCoder挑战赛14
    KMP算法
    二分查找
    Cellular Network
    拓撲排序
  • 原文地址:https://www.cnblogs.com/BetterThanEver_Victor/p/13280586.html
Copyright © 2011-2022 走看看