zoukankan      html  css  js  c++  java
  • pytorch解决鸢尾花分类

    半年前用numpy写了个鸢尾花分类200行。。每一步计算都是手写的  python构建bp神经网络_鸢尾花分类

    现在用pytorch简单写一遍,pytorch语法解释请看上一篇pytorch搭建简单网络

     1 import pandas as pd
     2 import torch.nn as nn
     3 import torch
     4 
     5 
     6 class MyNet(nn.Module):
     7     def __init__(self):
     8         super(MyNet, self).__init__()
     9         self.fc = nn.Sequential(
    10             nn.Linear(4, 3),
    11             nn.Sigmoid(),
    12             nn.Linear(3, 3),
    13             nn.Sigmoid(),
    14             nn.Linear(3, 1),
    15         )
    16         self.mls = nn.MSELoss()
    17         self.opt = torch.optim.Adam(params=self.parameters(), lr=0.001)
    18 
    19     def get_data(self):
    20         inputs = []
    21         labels = []
    22         with open('flower.csv') as file:
    23             df = pd.read_csv(file, header=None)
    24             x = df.iloc[:, 0:4].values
    25             y = df.iloc[:, 4].values
    26             for i in range(len(x)):
    27                 inputs.append(x[i])
    28             for j in range(len(y)):
    29                 a = []
    30                 a.append(y[j])
    31                 labels.append(a)
    32 
    33         return inputs, labels
    34 
    35     def forward(self, inputs):
    36         out = self.fc(inputs)
    37         return out
    38 
    39     def train(self, x, label):
    40         out = self.forward(x)
    41         loss = self.mls(out, label)
    42         self.opt.zero_grad()
    43         loss.backward()
    44         self.opt.step()
    45 
    46     def test(self, x):
    47         return self.fc(x)
    48 
    49 
    50 if __name__ == '__main__':
    51     net = MyNet()
    52     inputs, labels = net.get_data()
    53     for i in range(1000):
    54         for index, input in enumerate(inputs):
    55             # 这里不加.float()会报错,可能是数据格式的问题吧
    56             input = torch.from_numpy(input).float()
    57             label = torch.Tensor(labels[index])
    58             net.train(input, label)
    59     # 简单测试一下
    60     c = torch.Tensor([[5.6, 2.7, 4.2, 1.3]])
    61     print(net.test(c))

    运行结果趋近于0.5  正确,单纯练一下pytorch,就没有分训练集,测试集

    1 tensor([[0.5392]], grad_fn=<AddmmBackward>)

    不用手写反向传播和梯度下降 是多么幸福一件事~

  • 相关阅读:
    数据库隔离级别
    Mysql 命令详解
    Mysql 索引
    强化学习(四):蒙特卡洛方法
    强化学习(三):动态编程
    强化学习(二):马尔可夫决策过程
    强化学习(一): 引入
    自然语言处理(五)时下流行的生成模型
    论文选读三 QANet
    皮质学习 HTM 知多少
  • 原文地址:https://www.cnblogs.com/MC-Curry/p/10109138.html
Copyright © 2011-2022 走看看