zoukankan      html  css  js  c++  java
  • Iris Classification on PyTorch

    Breast Cancer on PyTorch

    Code

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    class Net(nn.Module):
        """Two-layer MLP classifier: Linear -> Sigmoid -> Linear.

        ``forward`` returns raw, unnormalised class scores (logits).
        ``nn.CrossEntropyLoss`` applies ``log_softmax`` internally and expects
        exactly such scores. For that reason no activation is applied to the
        output:

        * a ReLU there would zero every negative score (and its gradient);
        * a Softmax there would be a second softmax on top of the loss's own.

        ``argmax`` over the logits selects the same class as ``argmax`` over
        probabilities. Use ``torch.softmax(logits, dim=1)`` when probabilities
        are actually needed.

        Args:
            in_features: length of each input vector (default 30, the number of
                features the training script feeds in).
            hidden_features: width of the sigmoid hidden layer.
            out_features: number of classes.
        """

        def __init__(self, in_features=30, hidden_features=60, out_features=2):
            super(Net, self).__init__()
            self.l1 = nn.Linear(in_features, hidden_features)
            self.a1 = nn.Sigmoid()
            # Final affine layer: its outputs are the logits, nothing follows it.
            self.l2 = nn.Linear(hidden_features, out_features)

        def forward(self, x):
            """Map a ``(..., in_features)`` float tensor to ``(..., out_features)`` logits."""
            x = self.l1(x)
            x = self.a1(x)
            return self.l2(x)
    
    
    if __name__ == '__main__':
        # Seed both the train/test split and the weight initialisation so the
        # loss curve and the report are reproducible between runs.
        seed = 0
        torch.manual_seed(seed)

        breast_cancer = load_breast_cancer()

        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target, test_size=0.25, random_state=seed)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        # Standardise every feature using statistics from the training split only,
        # so nothing leaks from the test set. The raw features span very different
        # magnitudes, and unscaled inputs can push the sigmoid hidden layer into
        # saturation.
        mean = x_train.mean(dim=0)
        std = x_train.std(dim=0).clamp_min(1e-8)  # guard against a constant column
        x_train = (x_train - mean) / std
        x_test = (x_test - mean) / std

        net = Net()

        # CrossEntropyLoss applies log_softmax itself and expects raw class scores.
        criterion = nn.CrossEntropyLoss()
        # Small Adam step size; tune it together with the epoch count using the loss curve.
        optimizer = optim.Adam(net.parameters(), lr=0.005)

        losses = []
        for _ in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the per-epoch training loss; use it to check convergence while tuning.
        plt.plot(np.arange(1, len(losses) + 1), losses)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(), target_names=breast_cancer.target_names))
    
    

    损失函数图像（训练 loss 随 epoch 的变化曲线；原图未随转载保留，运行上面的代码即可生成）:

    nn.Sequential

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    if __name__ == '__main__':
        # Seed both the train/test split and the weight initialisation so the
        # loss curve and the report are reproducible between runs.
        seed = 0
        torch.manual_seed(seed)

        breast_cancer = load_breast_cancer()

        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target, test_size=0.25, random_state=seed)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        # Standardise every feature using statistics from the training split only,
        # so nothing leaks from the test set. The raw features span very different
        # magnitudes, and unscaled inputs can push the sigmoid hidden layer into
        # saturation.
        mean = x_train.mean(dim=0)
        std = x_train.std(dim=0).clamp_min(1e-8)  # guard against a constant column
        x_train = (x_train - mean) / std
        x_test = (x_test - mean) / std

        # The model outputs raw logits. No ReLU/Softmax after the last Linear:
        # - a ReLU would zero negative scores;
        # - a Softmax would double the log_softmax CrossEntropyLoss already applies.
        net = nn.Sequential(
            nn.Linear(30, 60),
            nn.Sigmoid(),
            nn.Linear(60, 2),
        )

        criterion = nn.CrossEntropyLoss()
        # Small Adam step size; tune it together with the epoch count using the loss curve.
        optimizer = optim.Adam(net.parameters(), lr=0.005)

        losses = []
        for _ in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Inference only: no autograd graph is needed. argmax over logits gives the class.
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the per-epoch training loss; use it to check convergence while tuning.
        plt.plot(np.arange(1, len(losses) + 1), losses)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(), target_names=breast_cancer.target_names))
    
    

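    模型性能（注：下表中的类别名 setosa/versicolor/virginica 和 50 个测试样本来自 Iris 数据集，与本文的二分类乳腺癌数据集（malignant/benign）不符，疑为误贴自 Iris 一文；请以运行上面代码得到的报告为准）: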

                  precision    recall  f1-score   support
    
          setosa       1.00      1.00      1.00        14
      versicolor       1.00      1.00      1.00        16
       virginica       1.00      1.00      1.00        20
    
        accuracy                           1.00        50
       macro avg       1.00      1.00      1.00        50
    weighted avg       1.00      1.00      1.00        50
    
  • 相关阅读:
    Centos服务器搭建(3)——安装maven
    Centos服务器搭建(2)——安装tomcat
    Centos服务器搭建(1)——安装jdk
    mysql主从复制
    Json中返回换行符处理
    github pages 绑定域名
    SharePoint学习笔记——子页面
    SharePoint学习笔记——母版页
    SSH+Oracle的整合(SSH与Oracle整合坑巨多)
    SSH整合做CRUD(大神老师整理)
  • 原文地址:https://www.cnblogs.com/fengyubo/p/9141130.html
Copyright © 2011-2022 走看看