zoukankan      html  css  js  c++  java
  • 笔记3:逻辑回归(分批次训练)

    相关库导入

    import torch
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from torch import nn
    %matplotlib inline
    

    数据读入及预处理

    data = pd.read_csv('E:/datasets/dataset/credit-a.csv', header = None)
    X_data = data.iloc[:, :-1].values
    X = torch.from_numpy(X_data).type(torch.float32)
    Y_data = data.iloc[:, -1].replace(-1, 0).values.reshape(-1, 1)
    Y = torch.from_numpy(Y_data).type(torch.float32)
    

    数据格式:

    这里有几个关键点:

    • 数据没有表头,因此在读入的时候要设置 header = None
    • data.iloc[] 可以获得相应的数据。返回的是Series类型,用values可以获得数值数组
    • 类别是-1和1,二分类问题,因此可以用replace()方法将标签为-1的转换为0
    • 要注意转换数据的shape,以及数据的类型

    模型定义

    model = nn.Sequential(
        nn.Linear(15, 1),
        nn.Sigmoid()
    )
    loss_func = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
    

    关键点:

    • nn.Sequential() 定义一个模型序列
    • 损失函数使用交叉熵损失函数
    • 优化器使用Adam

    相关参数定义

    batch_size = 16
    num_batch = len(data) // batch_size
    epochs = 1000
    

    模型训练

    for epoch in range(epochs):
        for batch in range(num_batch):
            start = batch * batch_size
            end = start + batch_size
            x = X[start: end]
            y = Y[start: end]
            y_pred = model(x)
            loss = loss_func(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    

    这里使用的是手动切分数据

    训练结果

    model.state_dict() # 查看训练得到的参数
    

    ((model(X).data.numpy() > 0.5).astype('int') == Y.numpy()).mean()  # 查看正确率,输出结果为 0.8667687595712098
    

    注意:

    • 模型训练之后,model(X)已经不是单纯的数据了,而是包含data,grad,grad_fn
    • 与 0.5 比较之后,返回bool值,因此需要类型转换一下
  • 相关阅读:
    如何判断某个设备文件是否存在
    shell中export理解误区
    linux命令之tail
    国内较快的gnu镜像:北京交通大学镜像
    Cmake的交叉编译
    linux 命令之grep
    makefile之变量赋值
    makefile之VPATH和vpath的使用
    arm汇编进入C函数分析,C函数压栈,出栈,传参,返回值
    Jlink 软件断点和硬件断点
  • 原文地址:https://www.cnblogs.com/miraclepbc/p/14332084.html
Copyright © 2011-2022 走看看