zoukankan      html  css  js  c++  java
  • Iris Classification on PyTorch

    Breast Cancer on PyTorch

    Code

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    class Net(nn.Module):
        """Two-layer MLP classifier (sigmoid hidden layer).

        ``forward`` returns raw, unnormalised class scores (logits). This is
        what ``nn.CrossEntropyLoss`` expects: that loss applies ``log_softmax``
        internally, so the network must not apply its own softmax (doing so
        squashes the gradients), and it must not clamp the logits with a ReLU
        (negative logits would get zero gradient). Use ``predict_proba`` when
        class probabilities are needed.

        Args:
            in_features: number of input features (30 for
                ``sklearn.datasets.load_breast_cancer``).
            hidden: width of the hidden layer.
            num_classes: number of output classes.
        """

        def __init__(self, in_features=30, hidden=60, num_classes=2):
            super().__init__()
            self.l1 = nn.Linear(in_features, hidden)
            self.a1 = nn.Sigmoid()
            self.l2 = nn.Linear(hidden, num_classes)
            # Used only by predict_proba(); deliberately NOT applied in forward().
            self.l3 = nn.Softmax(dim=1)

        def forward(self, x):
            """Return logits of shape (batch, num_classes).

            Args:
                x: float tensor of shape (batch, in_features).
            """
            x = self.l1(x)
            x = self.a1(x)
            return self.l2(x)

        def predict_proba(self, x):
            """Return per-class probabilities (each row sums to 1) for ``x``."""
            return self.l3(self.forward(x))
    
    
    if __name__ == '__main__':
        breast_cancer = load_breast_cancer()
    
        x_train, x_test, y_train, y_test = train_test_split(breast_cancer.data, breast_cancer.target, test_size=0.25)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)
    
        net = Net()
    
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)  # PyTorch suit to tiny learning rate
    
        error = list()
    
        for epoch in range(250):
            optimizer.zero_grad()
            y_pred = net(x_train)
            loss = criterion(y_pred, y_train)
            loss.backward()
            optimizer.step()
            error.append(loss.item())
    
        y_pred = net(x_test)
        y_pred = torch.argmax(y_pred, dim=1)
    
        # it is necessary that drawing the loss plot when we fine tuning the model
        plt.plot(np.arange(1, len(error)+1), error)
        plt.show()
    
        print(classification_report(y_test, y_pred, target_names=breast_cancer.target_names))
    
    

    损失函数图像:

    nn.Sequential

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    if __name__ == '__main__':
        breast_cancer = load_breast_cancer()
    
        x_train, x_test, y_train, y_test = train_test_split(breast_cancer.data, breast_cancer.target, test_size=0.25)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)
    
        net = nn.Sequential(
            nn.Linear(30, 60),
            nn.Sigmoid(),
            nn.Linear(60, 2),
            nn.ReLU(),
            nn.Softmax(dim=1)
        )
    
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)  # PyTorch suit to tiny learning rate
    
        error = list()
    
        for epoch in range(250):
            optimizer.zero_grad()
            y_pred = net(x_train)
            loss = criterion(y_pred, y_train)
            loss.backward()
            optimizer.step()
            error.append(loss.item())
    
        y_pred = net(x_test)
        y_pred = torch.argmax(y_pred, dim=1)
    
        # it is necessary that drawing the loss plot when we fine tuning the model
        plt.plot(np.arange(1, len(error)+1), error)
        plt.show()
    
        print(classification_report(y_test, y_pred, target_names=breast_cancer.target_names))
    
    

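    模型性能（注：下表中的类别 setosa / versicolor / virginica 及 50 个测试样本来自 Iris 分类示例，并非本乳腺癌脚本的输出；运行本脚本得到的报告应列出 malignant 与 benign 两个类别）: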

                  precision    recall  f1-score   support
    
          setosa       1.00      1.00      1.00        14
      versicolor       1.00      1.00      1.00        16
       virginica       1.00      1.00      1.00        20
    
        accuracy                           1.00        50
       macro avg       1.00      1.00      1.00        50
    weighted avg       1.00      1.00      1.00        50
    
  • 相关阅读:
    如何让一个对话框全屏对话框
    学习网络请求返回json对应的model
    学习网络请求返回json对应的model
    android获取未安装APK签名信息及MD5指纹
    android获取未安装APK签名信息及MD5指纹
    Android stadio 模板 liveTemplate不管用
    Android stadio 模板 liveTemplate不管用
    android 事件传递机制
    android systemtrace 报错
    我今天的收获,必备stadio 插件
  • 原文地址:https://www.cnblogs.com/fengyubo/p/9141130.html
Copyright © 2011-2022 走看看