zoukankan      html  css  js  c++  java
  • PyTorch

    Data:

    from torch.utils.data import Dataset, DataLoader
    
    class MyDataset(Dataset):
        """Template map-style dataset backed by ``self.data``.

        ``__init__`` is a placeholder: replace ``...`` with code that reads
        *file* into an indexable container (list, tensor, array, ...).
        """

        def __init__(self, file):
            # Placeholder value; fill ``self.data`` from ``file`` here.
            self.data = ...

        def __len__(self):
            """Return how many samples the dataset holds."""
            return len(self.data)

        def __getitem__(self, index):
            """Return the sample stored at position *index*."""
            sample = self.data[index]
            return sample
    
    # Build the dataset, then let DataLoader batch and shuffle it each epoch.
    dataset = MyDataset(file)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    

    Model:

    import torch.nn as nn
    
    class MyModel(nn.Module):
        """Two-layer fully connected network: 10 input features -> 1 output."""

        def __init__(self):
            super().__init__()
            # Linear(10 -> 32) -> Sigmoid -> Linear(32 -> 1), kept in a single
            # Sequential named ``net`` so state_dict keys stay ``net.<i>.*``.
            layers = [
                nn.Linear(10, 32),
                nn.Sigmoid(),
                nn.Linear(32, 1),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x):
            """Map an input whose last dimension is 10 to one of size 1."""
            out = self.net(x)
            return out
    

    Train:

    # Data, model, loss, and optimizer setup.
    dataset = MyDataset(file)
    tr_set = DataLoader(dataset, batch_size=16, shuffle=True)
    model = MyModel().to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    for epoch in range(n_epochs):
        model.train()  # put layers such as dropout/batch-norm in training mode
        for x, y in tr_set:
            x, y = x.to(device), y.to(device)  # move the batch to the model's device
            optimizer.zero_grad()  # drop gradients left over from the previous step
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()  # compute gradients
            optimizer.step()  # apply the SGD update
    

    Evaluate (validation set):

    model.eval()  # put layers such as dropout/batch-norm in inference mode
    total_loss = 0
    for x, y in dv_set:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():  # no autograd graph is needed for evaluation
            pred = model(x)
            loss = criterion(pred, y)
        # Assumes criterion averages over the batch (nn.MSELoss default), so
        # weight by batch size to get a per-sample average over the whole set.
        total_loss += loss.cpu().item() * len(x)
    # Average once, after all batches have been accumulated.
    avg_loss = total_loss / len(dv_set.dataset)
    

    Evaluate (test set):

    model.eval()  # put layers such as dropout/batch-norm in inference mode
    preds = []
    # Presumably tt_set yields inputs only (no labels); confirm against its dataset.
    for x in tt_set:
        x = x.to(device)
        with torch.no_grad():  # inference only: skip building the autograd graph
            pred = model(x)
            preds.append(pred.cpu())  # collect predictions on the CPU
    

    Utils:

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Save only the learned parameters (state_dict), not the whole module.
    torch.save(model.state_dict(), path)

    # map_location moves tensors saved from a GPU onto ``device``, so the
    # checkpoint also loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt)

    # Wrap a NumPy array loaded from disk as a tensor (shares the array's memory).
    feature = torch.from_numpy(np.load(filename))
    

    https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/hw/Pytorch/Pytorch_Tutorial_1.pdf

  • 相关阅读:
    Linux部署golang程序(无数据库访问)
    MySQL备份数据库mysqldump
    Linux命令netstat
    SQL优化01(转载)
    springcloud之gateway点滴
    关于数据库错误:serverTimeZone
    代码重构的重要性
    关于集合的泛型
    python 视频下载神器(you-get)
    linux下ssh
  • 原文地址:https://www.cnblogs.com/holaworld/p/14603868.html
Copyright © 2011-2022 走看看