zoukankan      html  css  js  c++  java
  • Task4.用PyTorch实现多层网络

    1.引入模块,读取数据 

    2.构建计算图(构建网络模型)

    3.损失函数与优化器

    4.开始训练模型

    5.对训练的模型预测结果进行评估

     1 import torch.nn.functional as F
     2 import torch.nn.init as init
     3 import torch
     4 from torch.autograd import Variable
     5 import matplotlib.pyplot as  plt
     6 import numpy as np
     7 import math
     8 %matplotlib inline
     9 #%matplotlib inline 可以在Ipython编译器里直接使用
    10 #功能是可以内嵌绘图,并且可以省略掉plt.show()这一步。
    11 
    12 xy=np.loadtxt('./data/diabetes.csv.gz',delimiter=',',dtype=np.float32)
    13 x_data=torch.from_numpy(xy[:,0:-1])#取除了最后一列的数据
    14 y_data=torch.from_numpy(xy[:,[-1]])#取最后一列的数据,[-1]加中括号是为了keepdim
    15 
    16 print(x_data.size(),y_data.size())
    17 #print(x_data.shape,y_data.shape)
    18 
    19 #建立网络模型
    20 class Model(torch.nn.Module):
    21     
    22     def __init__(self):
    23         super(Model,self).__init__()
    24         self.l1=torch.nn.Linear(8,6)
    25         self.l2=torch.nn.Linear(6,4)
    26         self.l3=torch.nn.Linear(4,1)
    27         
    28     def forward(self,x):
    29         out1=F.relu(self.l1(x))
    30         out2=F.dropout(out1,p=0.5)
    31         out3=F.relu(self.l2(out2))
    32         out4=F.dropout(out3,p=0.5)
    33         y_pred=F.sigmoid(self.l3(out3))
    34         return y_pred
    35     
    36 def weights_init(m):
    37     classname=m.__class__.__name__
    38     if classname.find('Linear')!=-1:
    39         m.weight.data=torch.randn(m.weight.data.size()[0],m.weight.data.size()[1])
    40         m.bias.data=torch.randn(m.bias.data.size()[0])
    41         
    42 #our model
    43 model=Model()
    44 model.apply(weights_init)
    45 criterion=torch.nn.BCELoss()
    46 optimizer=torch.optim.SGD(model.parameters(),lr=0.1)
    47 
    48 #training loop
    49 Loss=[]
    50 for epoch in range(2000):
    51     y_pred=model(x_data)
    52     loss=criterion(y_pred,y_data)
    53     if epoch%100 == 0:
    54         print("epoch = ",epoch," loss = ",loss.data)
    55         Loss.append(loss.data)
    56         optimizer.zero_grad()
    57         loss.backward()
    58         optimizer.step()
    59         
    60 hour_var = Variable(torch.randn(1,8))
    61 print("predict",model(hour_var).data[0]>0.5)
    62 plt.plot(Loss)

    这里说明一下,这个dataset不是自带的,需要大家自己去下载,我找的时候费了不少功夫,这里提供一个网址给大家下载https://github.com/LianHaiMiao/pytorch-lesson-zh/blob/master/dataSet/diabetes.csv.gz
    参考:https://blog.csdn.net/qq_35547281/article/details/89285980

  • 相关阅读:
    内存不足报错
    curl Command Download File
    How to POST JSON data with Curl from Terminal/Commandline to Test Spring REST?
    IOS常用手势详解
    OC中的NSNumber、NSArray、NSString的常用方法
    如何利用autolayout动态计算UITableViewCell的高度
    对AFN和ASI各自使用方法及区别的总结
    转:你真的懂iOS的autorelease吗?
    文件管理(续)
    IOS文件管理
  • 原文地址:https://www.cnblogs.com/NPC-assange/p/11348338.html
Copyright © 2011-2022 走看看