# Linear-regression demo: generate a synthetic dataset, fit it with
# nn.Linear + mini-batch SGD, then print the learned parameters.
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.utils.data as Data
from torch import nn
from torch.nn import init

# ---- Dataset generation ----
num_inputs = 2
num_outputs = 1
num_samples = 1000
# Standard deviation of the generated features (1.0 = standard normal).
feature_std = 1.0

weight_true = torch.tensor([[2], [-3.4]])
bias_true = -4.2
features = torch.tensor(
    np.random.normal(0, feature_std, (num_samples, num_inputs)),
    dtype=torch.float,
)
labels = torch.mm(features, weight_true) + bias_true
# Add small Gaussian observation noise (std 0.01) to the labels.
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)

# ---- Data loading ----
batch_size = 10
# Pair each feature row with its label.
dataset = Data.TensorDataset(features, labels)
# Read shuffled mini-batches each epoch.
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)


# ---- Model ----
class LineNet(nn.Module):
    """A single fully connected layer: y = x @ W^T + b.

    Args:
        n_feature: number of input features.
        n_output: number of outputs (defaults to 1).
    """

    def __init__(self, n_feature, n_output=1):
        super().__init__()
        self.linear = nn.Linear(n_feature, n_output)

    def forward(self, x):
        """Forward pass: apply the linear layer to ``x``."""
        return self.linear(x)


net = LineNet(num_inputs, num_outputs)
print(net)

# ---- Parameter initialisation ----
init.normal_(net.linear.weight, mean=0, std=0.01)
# Equivalent alternative: net.linear.bias.data.fill_(0)
# (net is not an nn.Sequential, so net[0] is not available here).
init.constant_(net.linear.bias, val=0)

# ---- Loss and optimiser ----
loss = nn.MSELoss()
# Why the fit fails when feature_std != 1 with a fixed lr: for MSE the
# curvature along the weight directions is about 2 * feature_std**2 (along the
# bias it is 2), and gradient descent only converges while lr * curvature < 2.
# With lr = 0.04 that breaks once feature_std**2 approaches 1 / lr = 25, and
# the loss blows up to inf/nan. Dividing lr by feature_std**2 keeps the step
# stable (and leaves lr = 0.04 unchanged when feature_std <= 1). When
# feature_std is very small, the weight gradients shrink by feature_std**2 and
# convergence slows instead. Standardising the features is the usual remedy
# for both cases.
base_lr = 0.04
lr = base_lr / max(feature_std ** 2, 1.0)
optimizer = torch.optim.SGD(net.parameters(), lr=lr)

# ---- Training ----
num_epochs = 100
for epoch in range(num_epochs):
    train_l = 0.0
    for X, y in data_iter:
        output = net(X)
        l = loss(output, y.view(-1, 1))
        optimizer.zero_grad()  # clear gradients; same as net.zero_grad() here
        l.backward()
        optimizer.step()
        # .item() yields a Python float, so each batch's autograd graph is
        # not chained into train_l.
        train_l += l.item()
    # Report the mean per-batch loss (an estimate of the training MSE).
    print('epoch %d ,loss %.4f' % (epoch + 1, train_l / len(data_iter)))
print(f'权重weight:{net.linear.weight},偏差bais{net.linear.bias}')
# Plot each feature column against the labels: column 0 in red, column 1 in green.
for col, colour in enumerate(('r', 'g')):
    plt.scatter(features[:, col].numpy(), labels.numpy(), 10, colour)
plt.show()
运行结果:
抛出问题：生成 features 数据集时使用的是标准正态分布，但如果 std != 1，就拟合不出结果。如果哪位高手知道原因，麻烦指导一下，谢谢。