zoukankan      html  css  js  c++  java
  • Iris Classification on PyTorch

    Breast Cancer on PyTorch

    Code

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    class Net(nn.Module):
        """Two-layer perceptron classifier: Linear -> Sigmoid -> Linear.

        ``forward`` returns raw, unnormalised class scores (logits). That is the
        input ``nn.CrossEntropyLoss`` expects: it applies ``log_softmax`` itself,
        so feeding it softmax outputs (as an earlier version did) applies softmax
        twice and flattens the gradients. ``argmax`` over the logits selects the
        same class as ``argmax`` over probabilities. Use ``predict_proba`` when
        actual probabilities are needed.

        Args:
            in_features: number of input features (default 30, matching the
                sklearn breast-cancer data used by this script).
            hidden: width of the hidden layer (default 60).
            n_classes: number of output classes (default 2).
        """

        def __init__(self, in_features=30, hidden=60, n_classes=2):
            super(Net, self).__init__()
            self.l1 = nn.Linear(in_features, hidden)
            self.a1 = nn.Sigmoid()
            self.l2 = nn.Linear(hidden, n_classes)
            # Used only by predict_proba; deliberately NOT applied in forward.
            self.l3 = nn.Softmax(dim=1)

        def forward(self, x):
            """Map a (batch, in_features) float tensor to (batch, n_classes) logits."""
            # No ReLU on the output layer: clamping logits at 0 can zero every
            # score for a sample, giving a uniform prediction and no gradient.
            return self.l2(self.a1(self.l1(x)))

        def predict_proba(self, x):
            """Return per-class probabilities for ``x`` (each row sums to 1)."""
            return self.l3(self.forward(x))
    
    
    if __name__ == '__main__':
        # Fix the seed so weight initialisation (and hence the loss curve) is reproducible.
        torch.manual_seed(0)
        breast_cancer = load_breast_cancer()

        # Stratify so both splits keep the class ratio; fixed random_state for reproducibility.
        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target,
            test_size=0.25, stratify=breast_cancer.target, random_state=0,
        )
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        # load_breast_cancer returns raw, unscaled measurements; standardise them
        # so the sigmoid hidden layer does not saturate. Statistics come from the
        # TRAINING split only, to avoid leaking test information.
        mean, std = x_train.mean(dim=0), x_train.std(dim=0).clamp_min(1e-8)
        x_train, x_test = (x_train - mean) / std, (x_test - mean) / std

        net = Net()

        # CrossEntropyLoss takes raw class scores and applies log_softmax internally.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)  # tune together with the epoch count

        error = []  # training loss per epoch, plotted below

        net.train()
        for _ in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            error.append(loss.item())

        # Evaluate without building an autograd graph.
        net.eval()
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the training-loss curve; it is the main signal when tuning lr / epochs.
        plt.plot(np.arange(1, len(error) + 1), error)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(), target_names=breast_cancer.target_names))
    
    

    损失函数图像 (training-loss curve; the image was not preserved in this copy):

    nn.Sequential

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    if __name__ == '__main__':
        # Fix the seed so weight initialisation (and hence the loss curve) is reproducible.
        torch.manual_seed(0)
        breast_cancer = load_breast_cancer()

        # Stratify so both splits keep the class ratio; fixed random_state for reproducibility.
        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target,
            test_size=0.25, stratify=breast_cancer.target, random_state=0,
        )
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        # load_breast_cancer returns raw, unscaled measurements; standardise them
        # so the sigmoid hidden layer does not saturate. Statistics come from the
        # TRAINING split only, to avoid leaking test information.
        mean, std = x_train.mean(dim=0), x_train.std(dim=0).clamp_min(1e-8)
        x_train, x_test = (x_train - mean) / std, (x_test - mean) / std

        # The network ends at the last Linear layer and outputs raw logits.
        # No ReLU (it would clamp logits at 0 and kill their gradients) and no
        # Softmax (CrossEntropyLoss already applies log_softmax internally).
        net = nn.Sequential(
            nn.Linear(30, 60),
            nn.Sigmoid(),
            nn.Linear(60, 2),
        )

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)  # tune together with the epoch count

        error = []  # training loss per epoch, plotted below

        net.train()
        for _ in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            error.append(loss.item())

        # Evaluate without building an autograd graph; argmax over logits picks
        # the same class as argmax over softmax probabilities.
        net.eval()
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the training-loss curve; it is the main signal when tuning lr / epochs.
        plt.plot(np.arange(1, len(error) + 1), error)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(), target_names=breast_cancer.target_names))
    
    

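    模型性能 (model performance). Note: the report below lists the Iris classes setosa/versicolor/virginica with 50 test samples, so it was produced by the Iris version of this script. The breast-cancer code above would instead report the classes `malignant`/`benign` on about 143 test samples (25% of 569).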

                  precision    recall  f1-score   support
    
          setosa       1.00      1.00      1.00        14
      versicolor       1.00      1.00      1.00        16
       virginica       1.00      1.00      1.00        20
    
        accuracy                           1.00        50
       macro avg       1.00      1.00      1.00        50
    weighted avg       1.00      1.00      1.00        50
    
  • 相关阅读:
    JAVA中字符串比较equals()和equalsIgnoreCase()的区别
    JAVA字母的大小写转换
    对于java线程的理解
    JAVA实现文件导出Excel
    处理数据库中的null值问题
    POJO、JAVABean、Entity的区别
    Mybatis的choose标签使用
    redis详解
    Spring框架基础解析
    利用 BackgroundService 固定时间间隔执行某动作
  • 原文地址:https://www.cnblogs.com/fengyubo/p/9141130.html
Copyright © 2011-2022 走看看