zoukankan      html  css  js  c++  java
  • 07 Mutilple Dimension Input

    Multiple Dimension Logistic Regression Model

    Logistic Regression Model

    [hat{y}=sigma(x*omega+b) (一维) ]

    [hat{y}^{(i)}=sigma(sum_{n=1}^N x_{n}^{(i)}*omega_{n}+b) (N维) ]

    from abc import ABC
    
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    
    xy = np.loadtxt('diabetes.csv.gz', delimiter=',', dtype=np.float32)  # 指定文件名, 分隔符, 数据类型 加载文件
    x_data = torch.from_numpy(xy[::, :-1])  # 选取xy的所有行, 除去最后一列后 所有列的数据
    y_data = torch.from_numpy(xy[:, [-1]])  # 选取xy所有行的最后一列数据, 这里-1需要加中括号使数据为矩阵形式, 否则将是一个向量
    
    
    class Model(torch.nn.Module, ABC):
        def __init__(self):
            super(Model, self).__init__()
            self.linear1 = torch.nn.Linear(8, 6)  # 维度 8 -> 6 -> 4 -> 1
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            self.sigmoid = torch.nn.Sigmoid()
    
        def forward(self, x):
            x = self.sigmoid(self.linear1(x))
            x = self.sigmoid(self.linear2(x))
            x = self.sigmoid(self.linear3(x))
            return x
    
    
    model = Model()
    criterion = torch.nn.BCELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    
    loss_list = []
    for epoch in range(100):
        y_pred = model(x_data)
        loss = criterion(y_pred, y_data)
        loss_list.append(loss.item())
        print(epoch, loss.item())
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    epoch_list = list(range(100))
    plt.plot(epoch_list, loss_list, c='b')
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.show()
    

    损失

    100轮迭代

    10000轮迭代

    Reference

    https://www.bilibili.com/video/BV1Y7411d7Ys?p=7

  • 相关阅读:
    第十一课:容器监控和Prometheus介绍
    第五课:单机编排利器:Docker Compose
    第四课:企业级镜像仓库Harbor
    第三课:快速部署LNMP平台
    负载均衡
    中间系统到中间系统IS-IS
    ansble 常识
    centos 的两种密码破解方式
    git在windos下使用
    git 本地仓库和远程仓库搭建
  • 原文地址:https://www.cnblogs.com/vict0r/p/13607702.html
Copyright © 2011-2022 走看看