zoukankan      html  css  js  c++  java
  • pytorch-第一章基本操作-线性拟合(GPU版本)

    第一步:构造数据

    import numpy as np
    import os
    
    x_values = [i for i in range(11)]
    x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1)
    
    y_values = [i * 2 + 1 for i in x_values]
    y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1)

    第二步: 使用class LinearRegressionModel 

    class LinearRegressionModel(nn.Module):
        def __init__(self, input_dim, output_dim):
            super(LinearRegressionModel, self).__init__()
            self.linear = nn.Linear(input_dim, output_dim)
        def forward(self, x):
            out = self.linear(x)
            return out

    第三步: 实例化模型,初始化epochs, 学习率,定义SGD优化函数,以及定义mse优化损失函数,使用model.to(device) 将模型的参数更新放在GPU上 

    input_dim = 1
    output_dim = 1
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") model = LinearRegressionModel(input_dim, output_dim)
    model.to(device) epochs = 1000 learning_rate = 0.01 optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate) criterion = nn.MSELoss()

    第四步: 如果模型存在就使用model.load_state_dict(torch.load("model.pkl")) 加载模型 参数,进行模型的参数优化,每50次,使用torch.save(model.state_dict)保存模型 ,使用to(device) 将训练样本和测试样本放在GPU上 

    if os.path.exists("model.pkl"):
        model.load_state_dict(torch.load("model.pkl"))
    
    for epoch in range(epochs):
    
        inputs = torch.from_numpy(x_train).to(device)
        labels = torch.from_numpy(y_train).to(device)
    
        # 梯度每次清零
        optimizer.zero_grad()
    
        # 前向传播
        outputs = model(inputs)
    
        # 计算损失值
        loss = criterion(outputs, labels)
    
        #反向传播
        loss.backward()
    
        #更新权重参数
        optimizer.step()
    
        if epoch % 50 == 0:
            print("epoch:{},loss:{}".format(epoch, loss.item()))
            torch.save(model.state_dict(), "model.pkl")
  • 相关阅读:
    Unity3D性能优化--- 收集整理的一堆
    Unity5.3官方VR教程重磅登场-系列7 优化VR体验
    VR沉浸体验的要求
    Unity5中叹为观止的实时GI效果
    浅谈控制反转与依赖注入[转]
    unity 使用unityaction 需要注意的问题[转]
    c# orm框架 sqlsugar
    unity Instantiate设置父级物体bug
    宝塔面板 使用mongodb注意事项
    unity中gameObject.SetActive()方法的注意事项。
  • 原文地址:https://www.cnblogs.com/my-love-is-python/p/12650342.html
Copyright © 2011-2022 走看看