zoukankan      html  css  js  c++  java
  • pytorch深度学习:一般分类器

    使用的criterion不是MSE而是交叉熵。

    注意区分 numpy 数组的 .shape 和 tensor 的 .size()，弄清形状才能正确遍历变量。

    另外，CrossEntropyLoss 的参数真是有够绕人的。

     1 import torch
     2 from torch import nn,optim
     3 import matplotlib.pyplot as plt
     4 
     class Classifier(nn.Module):
         """Single linear layer followed by an element-wise sigmoid.

         Args:
             input_feature: Size of each input sample (``in_features`` of the
                 linear layer).
             output_size: Number of output units (``out_features`` of the
                 linear layer).
         """

         def __init__(self, input_feature, output_size):
             super(Classifier, self).__init__()
             self.linear = nn.Linear(input_feature, output_size)

         def forward(self, x):
             """Return ``sigmoid(linear(x))``.

             NOTE(review): the outputs are squashed into (0, 1) rather than
             being raw logits. Criteria that normalise their input themselves
             (nn.CrossEntropyLoss is documented to expect unnormalised logits)
             will then operate on a compressed range -- confirm this is
             intended before changing it, since it alters the model output.
             """
             return torch.sigmoid(self.linear(x))

         def train(self, inp=True, target=None, criterion=None, optimizer=None,
                   epoches=0):
             """Run a full-batch training loop, or toggle training mode.

             This method shadows ``nn.Module.train(mode)``, which
             ``nn.Module.eval()`` invokes as ``self.train(False)``. To keep
             ``model.train()``, ``model.train(False)`` and ``model.eval()``
             working, a boolean first argument is forwarded to the base class.

             Args:
                 inp: Input batch for the training loop, or a bool selecting
                     training/evaluation mode.
                 target: Targets passed to ``criterion`` together with the
                     model output.
                 criterion: Callable ``criterion(output, target)`` returning a
                     loss that supports ``backward()``.
                 optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
                 epoches: Number of optimisation steps over the whole batch.

             Returns:
                 ``self`` when used as a mode switch; otherwise the tuple
                 ``(self, loss)`` where ``loss`` is the loss computed in the
                 last epoch, or ``None`` if ``epoches`` is 0.

             Raises:
                 TypeError: If a training run is requested without ``target``,
                     ``criterion`` or ``optimizer``.
             """
             if isinstance(inp, bool):
                 # Mode switch requested (e.g. by nn.Module.eval()).
                 return super(Classifier, self).train(inp)

             missing = [
                 name
                 for name, value in (
                     ("target", target),
                     ("criterion", criterion),
                     ("optimizer", optimizer),
                 )
                 if value is None
             ]
             if missing:
                 raise TypeError(
                     "train() missing required argument(s): " + ", ".join(missing)
                 )

             # Stays None when epoches == 0 instead of being unbound at return.
             loss = None
             for _ in range(epoches):
                 # Call the module rather than forward() so registered hooks run.
                 output = self(inp)
                 loss = criterion(output, target)
                 optimizer.zero_grad()
                 loss.backward()
                 optimizer.step()
             return self, loss
    31 
    # Build a toy two-class data set: 100 points drawn around (1, 1) labelled 0
    # and 100 points drawn around (-1, -1) labelled 1, both with unit std.
    cluster = torch.ones(100, 2)
    data0 = torch.normal(cluster, 1)
    data1 = torch.normal(-cluster, 1)
    target0 = torch.zeros(100, 1)
    target1 = torch.ones(100, 1)
    inputs = torch.cat((data0, data1), dim=0)
    target = torch.cat((target0, target1), dim=0)
    print(target.size())
    # Drop the trailing dimension so the labels form a 1-D vector of length 200.
    target = torch.squeeze(target)
    print(inputs.size())
    print(target.size())

    # target is 1-D after the squeeze, so it is used directly for colouring;
    # indexing it with [:, 0] would raise IndexError.
    plt.scatter(inputs.numpy()[:, 0], inputs.numpy()[:, 1], c=target.numpy(), s=10, cmap='RdYlGn')
    plt.show()

    model = Classifier(2, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    # The labels are cast to LongTensor because they are passed to
    # CrossEntropyLoss as class indices.
    new_model, loss = model.train(inputs, target.type(torch.LongTensor), criterion, optimizer, 100)
    print(loss)
  • 相关阅读:
    操作excel语法
    MySQL exists的用法介
    vim 快捷键
    mysql中datetime比较大小问题
    MySQL CAST与CONVERT 函数的用法
    tbxvUZIAJH
    springBoot相关
    springCloud
    Spring Boot使用JavaMailSender发送邮件
    RabbitMq 消息队列
  • 原文地址:https://www.cnblogs.com/St-Lovaer/p/13696443.html
Copyright © 2011-2022 走看看