zoukankan      html  css  js  c++  java
  • Breast Cancer on PyTorch

    Breast Cancer on PyTorch

    Code

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    class Net(nn.Module):
        """Two-layer MLP classifier: Linear -> Sigmoid -> Linear.

        ``forward`` returns raw (unnormalized) class scores. ``nn.CrossEntropyLoss``
        applies log-softmax itself, so the output layer deliberately has no
        activation. An extra Softmax there would apply softmax twice and flatten
        the gradients. A ReLU there would clip every negative score to zero.
        Use ``predict_proba`` when probabilities are needed.

        Args:
            in_features: number of input features (30 for sklearn's breast-cancer data).
            hidden: width of the hidden layer.
            n_classes: number of output classes.
        """

        def __init__(self, in_features=30, hidden=60, n_classes=2):
            super().__init__()
            # Attribute names l1/l2 are kept so existing state_dict keys still load.
            self.l1 = nn.Linear(in_features, hidden)
            self.a1 = nn.Sigmoid()
            self.l2 = nn.Linear(hidden, n_classes)

        def forward(self, x):
            """Return class logits of shape (batch, n_classes) for x of shape (batch, in_features)."""
            return self.l2(self.a1(self.l1(x)))

        def predict_proba(self, x):
            """Return softmax class probabilities; each row sums to 1."""
            return torch.softmax(self.forward(x), dim=1)
    
    
    def standardize_features(x_train, x_test, eps=1e-8):
        """Scale both splits by the training split's per-feature mean and std.

        The breast-cancer features are on very different scales. Large raw
        inputs push the sigmoid hidden layer into saturation. The statistics come
        from the training split only, so nothing leaks from the test set. The
        population std is used, and ``eps`` keeps constant columns from dividing
        by zero.

        Returns:
            (x_train_scaled, x_test_scaled) as new tensors.
        """
        mean = x_train.mean(dim=0, keepdim=True)
        std = (x_train - mean).pow(2).mean(dim=0, keepdim=True).sqrt().clamp_min(eps)
        return (x_train - mean) / std, (x_test - mean) / std


    if __name__ == '__main__':
        torch.manual_seed(0)  # reproducible weight initialisation
        breast_cancer = load_breast_cancer()

        # A fixed random_state makes the split (and therefore the printed report) reproducible.
        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target, test_size=0.25, random_state=0)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        x_train, x_test = standardize_features(x_train, x_test)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        net = Net()

        # nn.CrossEntropyLoss applies log-softmax internally; it should be fed raw logits.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)

        losses = []
        net.train()
        for epoch in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Evaluate in eval mode and without building an autograd graph.
        net.eval()
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the training loss curve; check it when tuning the learning rate and epoch count.
        plt.plot(np.arange(1, len(losses) + 1), losses)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(),
                                    target_names=breast_cancer.target_names))
    
    

    损失函数图像:

    nn.Sequential

    # encoding:utf8
    
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from matplotlib import pyplot as plt
    import numpy as np
    
    
    def standardize_features(x_train, x_test, eps=1e-8):
        """Scale both splits by the training split's per-feature mean and std.

        The breast-cancer features are on very different scales. Large raw
        inputs push the sigmoid hidden layer into saturation. The statistics come
        from the training split only, so nothing leaks from the test set. The
        population std is used, and ``eps`` keeps constant columns from dividing
        by zero.

        Returns:
            (x_train_scaled, x_test_scaled) as new tensors.
        """
        mean = x_train.mean(dim=0, keepdim=True)
        std = (x_train - mean).pow(2).mean(dim=0, keepdim=True).sqrt().clamp_min(eps)
        return (x_train - mean) / std, (x_test - mean) / std


    def build_net(in_features=30, hidden=60, n_classes=2):
        """Return a Linear -> Sigmoid -> Linear stack that outputs raw class logits.

        There is no activation after the last Linear. nn.CrossEntropyLoss applies
        log-softmax itself, so a trailing Softmax would apply it twice. A trailing
        ReLU would clip every negative logit to zero.
        """
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.Sigmoid(),
            nn.Linear(hidden, n_classes),
        )


    if __name__ == '__main__':
        torch.manual_seed(0)  # reproducible weight initialisation
        breast_cancer = load_breast_cancer()

        # A fixed random_state makes the split (and therefore the printed report) reproducible.
        x_train, x_test, y_train, y_test = train_test_split(
            breast_cancer.data, breast_cancer.target, test_size=0.25, random_state=0)
        x_train, x_test = torch.tensor(x_train, dtype=torch.float), torch.tensor(x_test, dtype=torch.float)
        x_train, x_test = standardize_features(x_train, x_test)
        y_train, y_test = torch.tensor(y_train, dtype=torch.long), torch.tensor(y_test, dtype=torch.long)

        net = build_net()

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.005)

        losses = []
        net.train()
        for epoch in range(250):
            optimizer.zero_grad()
            loss = criterion(net(x_train), y_train)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        # Evaluate in eval mode and without building an autograd graph.
        net.eval()
        with torch.no_grad():
            y_pred = torch.argmax(net(x_test), dim=1)

        # Plot the training loss curve; check it when tuning the learning rate and epoch count.
        plt.plot(np.arange(1, len(losses) + 1), losses)
        plt.xlabel('epoch')
        plt.ylabel('training loss')
        plt.show()

        print(classification_report(y_test.numpy(), y_pred.numpy(),
                                    target_names=breast_cancer.target_names))
    
    

    模型性能:

                  precision    recall  f1-score   support
    
       malignant       0.91      0.91      0.91        54
          benign       0.94      0.94      0.94        89
    
        accuracy                           0.93       143
       macro avg       0.93      0.93      0.93       143
    weighted avg       0.93      0.93      0.93       143
    
  • 相关阅读:
    Python 高阶函数
    Python 函数式编程
    Python 递归函数
    Python 函数的参数定义
    Linux学习-AWK使用
    Windows10 下 VirtualBox6 中 Centos8 无法安装"增强功能"
    Linux学习-Shell-系统启动过程与执行方式
    接口测试-工具介绍
    Linux学习-Sed 命令
    Linux学习-命令行参数、函数
  • 原文地址:https://www.cnblogs.com/fengyubo/p/12328267.html
Copyright © 2011-2022 走看看