zoukankan      html  css  js  c++  java
  • PyTorch 实现异或XOR运算

    1. 异或运算

     2. 实现

    
    
     1 # 利用Pytorch解决XOR问题
     2 import torch
     3 import torch.nn as nn
     4 import torch.nn.functional as F
     5 import torch.optim as optim
     6 import numpy as np
     7 
     8 data = np.array([[1, 0, 1], [0, 1, 1],
     9                  [1, 1, 0], [0, 0, 0]], dtype='float32')
    10 x = data[:, :2]
    11 y = data[:, 2]
    12 
    13 
    14 # 初始化权重变量
    15 def weight_init_normal(m):
    16     classname = m.__class__.__name__ #是获取类名,得到的结果classname是一个字符串
    17     if classname.find('Linear') != -1:  #判断这个类名中,是否包含"Linear"这个字符串,字符串的find()方法,检索这个字符串中是否包含另一个字符串
    18         m.weight.data.normal_(0.0, 1.)
    19         m.bias.data.fill_(0.)
    20 
    21 
    22 class XOR(nn.Module):
    23     def __init__(self):
    24         super(XOR, self).__init__()
    25         self.fc1 = nn.Linear(2, 3)   # 隐藏层 3个神经元
    26         self.fc2 = nn.Linear(3, 4)   # 隐藏层 4个神经元
    27         self.fc3 = nn.Linear(4, 1)   # 输出层 1个神经元
    28 
    29     def forward(self, x):
    30         h1 = F.sigmoid(self.fc1(x))  # 之前也尝试过用ReLU作为激活函数, 太容易死亡ReLU了.
    31         h2 = F.sigmoid(self.fc2(h1))
    32         h3 = F.sigmoid(self.fc3(h2))
    33         return h3
    34 
    35 
    36 net = XOR()
    37 net.apply(weight_init_normal) #相当于net.weight_init_normal()
    38  #apply方式的调用是递归的,即net这个类和其子类(如果有),挨个调用一次weight_init_normal()方法。
    39 x = torch.Tensor(x.reshape(-1, 2))
    40 y = torch.Tensor(y.reshape(-1, 1))
    41 
    42 # 定义loss function
    43 criterion = nn.BCELoss()  # MSE
    44 # 定义优化器
    45 optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)  # SGD
    46 # 训练
    47 for epoch in range(500):
    48     optimizer.zero_grad()   # 清零梯度缓存区
    49     out = net(x)
    50     loss = criterion(out, y)
    51     print(loss)
    52     loss.backward()
    53     optimizer.step()  # 更新
    54 
    55 # 测试
    56 test = net(x)
    57 print("input is {}".format(x.detach().numpy()))
    58 print('out is {}'.format(test.detach().numpy()))

    来源:(1条消息) PyTorch——解决异或问题XOR_我是大黄同学呀的博客-CSDN博客_pytorch 异或

    稍微改了一下网络结构,添加少量注释,理解第16-17,37行。

    结构如下:

     A Neural Network Playground (tensorflow.org)

  • 相关阅读:
    把word文档转换成swf格式
    利用“审阅”批改作业
    注意:QQ空间加密并不安全
    MySQLDB 错误 InterfaceError(0,")
    Linux 文件大小 文件夹大小 磁盘大小
    JavaArrays类fill()方法详解
    构造函数
    ASP部署错误"未能加载类型..."
    试AJAX出错两则
    ASP.Net如何区分开发状态与实际应用状态
  • 原文地址:https://www.cnblogs.com/vvzhang/p/14063429.html
Copyright © 2011-2022 走看看