zoukankan      html  css  js  c++  java
  • pytorch-第一章基本操作-线性拟合(GPU版本)

    第一步:构造数据

    import numpy as np
    import os
    
    x_values = [i for i in range(11)]
    x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1)
    
    y_values = [i * 2 + 1 for i in x_values]
    y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1)

    第二步: 使用class LinearRegressionModel 

    class LinearRegressionModel(nn.Module):
        def __init__(self, input_dim, output_dim):
            super(LinearRegressionModel, self).__init__()
            self.linear = nn.Linear(input_dim, output_dim)
        def forward(self, x):
            out = self.linear(x)
            return out

    第三步: 实例化模型,初始化epochs, 学习率,定义SGD优化函数,以及定义mse优化损失函数,使用model.to(device) 将模型的参数更新放在GPU上 

    input_dim = 1
    output_dim = 1
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = LinearRegressionModel(input_dim, output_dim)
    model.to(device) epochs = 1000 learning_rate = 0.01 optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate) criterion = nn.MSELoss()

    第四步: 如果模型存在就使用model.load_state_dict(torch.load("model.pkl")) 加载模型 参数,进行模型的参数优化,每50次,使用torch.save(model.state_dict)保存模型 ,使用to(device) 将训练样本和测试样本放在GPU上 

    if os.path.exists("model.pkl"):
        model.load_state_dict(torch.load("model.pkl"))
    
    for epoch in range(epochs):
    
        inputs = torch.from_numpy(x_train).to(device)
        labels = torch.from_numpy(y_train).to(device)
    
        # 梯度每次清零
        optimizer.zero_grad()
    
        # 前向传播
        outputs = model(inputs)
    
        # 计算损失值
        loss = criterion(outputs, labels)
    
        #反向传播
        loss.backward()
    
        #更新权重参数
        optimizer.step()
    
        if epoch % 50 == 0:
            print("epoch:{},loss:{}".format(epoch, loss.item()))
            torch.save(model.state_dict(), "model.pkl")
  • 相关阅读:
    JS注意事项
    正则
    js闭包
    【转】chrome console用法
    JSON
    流式传输原理(一) 之通过Web服务器访问音频和视频
    流式传输原理(二) 之通过流式服务器访问音视频
    Equivalence Class Partitioning等价类划分黑盒测试
    【判断闰年】程序抛出异常的解决方案
    新学期😄😄😄
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/12650342.html
Copyright © 2011-2022 走看看