zoukankan      html  css  js  c++  java
  • PyTorch学习笔记5--案例1: 使用多GPU进行数据拟合

    写在前面

    在本教程中,我们将学习:

    • 通过DataParallel使用多GPU训练模型.
    • 数据拟合.

    使用多GPU

    device = torch.device("cuda:0")
    model.to(device) #返回my_tensor的一个GPU上的备份, 而不是重写覆盖了`my_tensor`,这种写法是不正确的
    mytensor = my_tensor.to(device) # 需要assign给一个新的tensor:mytensor,在GPU上用这个才合适
    

    Pytorch默认只使用一个GPU。代码model = nn.DataParallel(model)将让model在多个GPU上运行。

    案例: 数据拟合

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    
    # Parameters and DataLoaders
    # y(data_size) = X(data_size,2)xW(2,1) + b(标量)
    # input_size = 2:代表设计矩阵 X 有2个特征
    # output_size = 1:代表标签y 只有1个数据
    input_size = 2
    output_size = 1
    batch_size = 200
    data_size = 200
    
    # 优先使用GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # Dummy DataSet
    # 制作一个随机数数据集。你只需要实现getitem函数
    # @paras:    length:数据集的长度(数据点的个数)
    #            n_features:数据的维度
    class RandomDataset(Dataset):
        def __init__(self, length,n_features):
            self.len = length
            self.weight = torch.tensor([[5.2],[9.6]])
            self.bias = torch.tensor(3.4)
            self.data = torch.randn(length, n_features)
            self.targets = torch.matmul(self.data,self.weight) + self.bias
        def __getitem__(self, index):
            return self.data[index],self.targets[index]
        def __len__(self):
            return self.len
    
    rand_loader = DataLoader(dataset=RandomDataset(data_size, 2),
                             batch_size=batch_size, shuffle=True)
    
    # `DataParallel`可以用在任何模型上。
    # 模型中的print语句将打印输入tensor和输出tensor的size.
    # 注意batch rank0会打印什么
    
    class Model(nn.Module):
        def __init__(self, input_size, output_size):
            super(Model, self).__init__()
            self.fc = nn.Linear(input_size, output_size)
    
        def forward(self, input):
            output = self.fc(input)
            print("In Model: input size", input.size(),"output size", output.size())
            return output
    
    # 生成一个model实例:检测是否有多个GPU
    
    model = Model(input_size, output_size)
    if torch.cuda.device_count() > 1:
      print("Let's use", torch.cuda.device_count(), "GPUs!")
      # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
      model = nn.DataParallel(model)
    
    model.to(device)
    
    # 3 loss 
    
    # 4 optimizer
    optimizer = optim.Adam(model.parameters(),lr=0.5)
    
    def Train():
        for epoch in range(200):
            for (data,labels) in rand_loader:
                # forward
                input = data.to(device)
                targets = labels.to(device)
                output = model(input)
                #loss = sum((output-targets)*(output-targets))/batch_size
                loss = F.mse_loss(output,targets)
                #if epoch%(20) == 0:
                print("Outside: input size", input.size(),"output_size", output.size(),'loss:{}'.format(loss.item()))
                # backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    
    Train()
    print(list(model.parameters()))
    

    打印结果为(可见,PyTorch将数据集X均分到多个GPU上计算后,合并为输出):

    # If you have 2 GPUs, you will see:
    #     # on 2 GPUs
    #     Let's use 2 GPUs!
    #         In Model: input size torch.Size([100, 2]) output size torch.Size([100, 1])
    #         In Model: input size torch.Size([100, 2]) output size torch.Size([100, 1])
    #     Outside: input size torch.Size([200, 2]) output_size torch.Size([200, 1])
    #         In Model: input size torch.Size([100, 2]) output size torch.Size([100, 1])
    #         In Model: input size torch.Size([100, 2]) output size torch.Size([100, 1])
    #     Outside: input size torch.Size([200, 2]) output_size torch.Size([200, 1])
    #     ...
    
  • 相关阅读:
    浅浅的分析下es6箭头函数
    css实现背景半透明文字不透明的效果
    五星评分,让我告诉你半颗星星怎么做
    微信小程序--成语猜猜看
    微信小程序开发中如何实现侧边栏的滑动效果?
    强力推荐微信小程序之简易计算器,很适合小白程序员
    swing _JFileChooser文件选择窗口
    file类简单操作
    序列化对象
    MessageBox_ swt
  • 原文地址:https://www.cnblogs.com/charleechan/p/12320744.html
Copyright © 2011-2022 走看看