使用的criterion不是MSE而是交叉熵。
numpy.shape,tensor.size(),正确遍历变量。
另外,CrossEntropyLoss 的参数真是有够绕人的:输入应为未归一化的 logits(形状为 (N, C)),target 应为 LongTensor 类型的类别索引(形状为 (N,))。
import torch
from torch import nn, optim


class Classifier(nn.Module):
    """Single linear layer that maps input features to per-class logits."""

    def __init__(self, input_feature, output_size):
        """Create the linear layer.

        Args:
            input_feature: Number of features per input sample.
            output_size: Number of output classes (one logit per class).
        """
        super(Classifier, self).__init__()
        self.linear = nn.Linear(input_feature, output_size)

    def forward(self, x):
        """Return the raw (unnormalized) class logits for ``x``."""
        # No sigmoid here: nn.CrossEntropyLoss applies log-softmax itself and
        # expects unnormalized logits. Squashing them into (0, 1) first would
        # cap how confident the model can become and slow down learning.
        return self.linear(x)

    def train(self, inp=True, target=None, criterion=None, optimizer=None,
              epoches=None):
        """Run a full-batch training loop, or toggle training mode.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` calls internally. To keep ``model.eval()`` and
        ``model.train()`` working, a boolean first argument is forwarded to
        the base-class implementation.

        Args:
            inp: Input tensor for training, or a bool to set training mode.
            target: Target tensor handed to ``criterion``.
            criterion: Loss callable invoked as ``criterion(output, target)``.
            optimizer: Optimizer that updates this model's parameters.
            epoches: Number of full-batch update steps to perform.

        Returns:
            ``(self, loss)`` where ``loss`` is the loss computed in the last
            epoch, or ``None`` when ``epoches`` is 0. When called with a
            bool, returns whatever ``nn.Module.train`` returns.

        Raises:
            TypeError: If a training call omits any required argument.
        """
        if isinstance(inp, bool):
            return super(Classifier, self).train(inp)

        if target is None or criterion is None or optimizer is None \
                or epoches is None:
            raise TypeError(
                "train() needs inp, target, criterion, optimizer and epoches"
            )

        loss = None  # stays None when no epoch is run
        for _ in range(epoches):
            output = self(inp)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return self, loss


def main():
    """Build two Gaussian clusters, plot them, and fit the classifier."""
    # Imported here so the plotting dependency is only needed when the
    # script is actually run, not when Classifier is imported.
    import matplotlib.pyplot as plt

    cluster = torch.ones(100, 2)
    data0 = torch.normal(cluster, 1)    # class 0: centred at (1, 1)
    data1 = torch.normal(-cluster, 1)   # class 1: centred at (-1, -1)
    target0 = torch.zeros(100, 1)
    target1 = torch.ones(100, 1)
    inputs = torch.cat((data0, data1), dim=0)
    target = torch.cat((target0, target1), dim=0)
    print(target.size())
    # CrossEntropyLoss wants class indices of shape (N,), not (N, 1).
    target = torch.squeeze(target)
    print(inputs.size())
    print(target.size())

    points = inputs.numpy()
    # target is 1-D after the squeeze, so it is used directly as the colour
    # array (indexing it with [:, 0] would raise IndexError).
    plt.scatter(points[:, 0], points[:, 1], c=target.numpy(), s=10,
                cmap='RdYlGn')
    plt.show()

    model = Classifier(2, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    # Class-index targets must be integer (long) tensors.
    new_model, loss = model.train(inputs, target.long(), criterion,
                                  optimizer, 100)
    print(loss)


if __name__ == "__main__":
    main()