zoukankan      html  css  js  c++  java
  • Iris Classification on PyTorch

    Breast Cancer on PyTorch

    Code

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    class Net(nn.Module):
        """Two-layer MLP classifier (sigmoid hidden layer, linear output).

        ``forward`` returns raw, unnormalised logits of shape
        ``(batch, n_classes)``. This is what ``nn.CrossEntropyLoss`` expects:
        it applies log-softmax to its input internally. Feeding it softmax
        output would apply softmax twice, which flattens the gradients and
        slows or stalls training. Use :meth:`predict_proba` when class
        probabilities are actually needed.

        Args:
            in_features: number of input features (30 for sklearn's
                breast-cancer data, the default used by this script).
            hidden: width of the sigmoid hidden layer.
            n_classes: number of output classes.
        """

        def __init__(self, in_features=30, hidden=60, n_classes=2):
            super().__init__()
            self.l1 = nn.Linear(in_features, hidden)
            self.a1 = nn.Sigmoid()
            self.l2 = nn.Linear(hidden, n_classes)
            # Kept as a module for predict_proba(); deliberately NOT applied in
            # forward(), because CrossEntropyLoss already normalises its input.
            self.l3 = nn.Softmax(dim=1)

        def forward(self, x):
            """Map a ``(batch, in_features)`` float tensor to ``(batch, n_classes)`` logits."""
            x = self.a1(self.l1(x))
            # No activation on the output layer: a ReLU here would clamp logits
            # at 0, and when all of them clamp to 0, argmax always picks class 0.
            return self.l2(x)

        def predict_proba(self, x):
            """Return class probabilities (each row sums to 1) for input ``x``."""
            return self.l3(self.forward(x))
    
    
    if __name__ == '__main__':
        # Seed both the weight init (torch) and the split (random_state) so the
        # loss curve and the reported scores are reproducible between runs.
        torch.manual_seed(0)
        breast_cancer = load_breast_cancer()

        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target, test_size=0.25, random_state=0)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        # The breast-cancer features are on very different numeric scales,
        # which saturates the sigmoid hidden layer. Standardise each feature
        # using statistics from the TRAINING split only, so no test data leaks
        # into training. The clamp guards against a zero-variance column.
        mean = x_train.mean(dim=0)
        std = x_train.std(dim=0).clamp_min(1e-12)
        x_train = (x_train - mean) / std
        x_test = (x_test - mean) / std

        net = Net()

        # nn.CrossEntropyLoss expects unnormalised logits and applies
        # log-softmax internally.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)

        losses = []

        # Full-batch training: every step uses the whole training split.
        net.train()
        for _ in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Inference only: switch to eval mode and skip building an autograd graph.
        net.eval()
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the training-loss curve; it is the quickest way to judge whether
        # the learning rate and epoch count need tuning.
        plt.plot(np.arange(1, len(losses) + 1), losses)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(), target_names=breast_cancer.target_names))
    
    

    损失函数图像:

    nn.Sequential

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    if __name__ == '__main__':
        # Seed both the weight init (torch) and the split (random_state) so the
        # loss curve and the reported scores are reproducible between runs.
        torch.manual_seed(0)
        breast_cancer = load_breast_cancer()

        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target, test_size=0.25, random_state=0)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        # The breast-cancer features are on very different numeric scales,
        # which saturates the sigmoid hidden layer. Standardise each feature
        # using statistics from the TRAINING split only, so no test data leaks
        # into training. The clamp guards against a zero-variance column.
        mean = x_train.mean(dim=0)
        std = x_train.std(dim=0).clamp_min(1e-12)
        x_train = (x_train - mean) / std
        x_test = (x_test - mean) / std

        net = nn.Sequential(
            nn.Linear(30, 60),
            nn.Sigmoid(),
            # Output layer is left linear (raw logits). CrossEntropyLoss applies
            # log-softmax itself, so an extra Softmax would be applied twice.
            # A ReLU would clamp logits at 0, and ties at 0 make argmax pick
            # class 0.
            nn.Linear(60, 2),
        )

        # nn.CrossEntropyLoss expects unnormalised logits and applies
        # log-softmax internally.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)

        losses = []

        # Full-batch training: every step uses the whole training split.
        net.train()
        for _ in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Inference only: switch to eval mode and skip building an autograd graph.
        net.eval()
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the training-loss curve; it is the quickest way to judge whether
        # the learning rate and epoch count need tuning.
        plt.plot(np.arange(1, len(losses) + 1), losses)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(), target_names=breast_cancer.target_names))
    
    

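    模型性能(注:下方报告中的 setosa / versicolor / virginica 来自 Iris 版本的脚本;上面的乳腺癌脚本输出的类别应为 malignant / benign,测试集共 143 个样本):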

                  precision    recall  f1-score   support
    
          setosa       1.00      1.00      1.00        14
      versicolor       1.00      1.00      1.00        16
       virginica       1.00      1.00      1.00        20
    
        accuracy                           1.00        50
       macro avg       1.00      1.00      1.00        50
    weighted avg       1.00      1.00      1.00        50
    
  • 相关阅读:
    130517Dev GridControl建立多行复杂表头(Banded View)时,统计列与对应列无法对齐的解决办法
    C&C++标准库
    Linux操作系统下的多线程编程详细解析
    Ubuntu12.04用户以root身份登录
    ubuntu永久修改主机名
    linux信号 linux signal
    淘宝api 登录验证
    淘宝开店 防骗 易赛加款诈骗|冲q币恶意差评
    面试..
    test
  • 原文地址:https://www.cnblogs.com/fengyubo/p/9141130.html
Copyright © 2011-2022 走看看