zoukankan      html  css  js  c++  java
  • tensorflow实现最基本的神经网络 + 对比GD、SGD、batch-GD的训练方法

    参考博客:https://zhuanlan.zhihu.com/p/27853521

    该代码默认是梯度下降法,可自行从注释中选择其他训练方法

    在异或问题上,由于训练的样本数较少,神经网络简单,训练结果最好的是GD梯度下降法。

      1 # -*- coding:utf-8 -*-
      2 
      3 # 将tensorflow 引入并命名tf
      4 import tensorflow as tf
      5 # 矩阵操作库numpy,命名为np
      6 import numpy as np
      7 
      8 '''
      9 生成数据 
     10 用python使用tensorflow时,输入到网络中的训练数据需要以np.array的类型
     11 存在。并且要限制dtype为32bit以下。变量后跟着“.astype('float32')”总可以满足要求
     12 '''
     13 # X和Y是4个数据的矩阵,X[i]和Y[i]的值始终对应
     14 X = [[0, 0], [0, 1], [1, 0], [1, 1]]
     15 Y = [[0], [1], [1], [0]]
     16 X = np.array(X).astype('int16')
     17 Y = np.array(Y).astype('int16')
     18 
     19 '''
     20 定义变量
     21 '''
     22 # 网络结构:2维输入--> 2维隐含层 -->1维输出
     23 # 学习速率(learing rate):0.0001
     24 
     25 D_input = 2
     26 D_hidden = 2
     27 D_label = 1
     28 lr = 0.0001
     29 '''
     30 容器
     31 '''
     32 # x为列向量 可变样本数*D_input; y为列向量 1*D_label 用GPU训练需要float32以下精度
     33 x = tf.placeholder(tf.float32, [None, D_input], name=None)
     34 t = tf.placeholder(tf.float32, [None, D_label], name=None)
     35 
     36 '''
     37 隐含层
     38 '''
     39 # 初始化权重W [D_input ,D_hidden ]
     40 # truncated_normal 正对数函数,返回随机截短的正态分布,默认均值为0,区间为[-2.0,2.0]
     41 W_h1 = tf.Variable(tf.truncated_normal([D_input, D_hidden], stddev=1.0), name="W_h")
     42 # 初始化b D_hidden 一维
     43 b_h1 = tf.Variable(tf.constant(0.1, shape=[D_hidden]), name="b_h")
     44 # 计算Wx+b  可变样本数*D_hidden
     45 pre_act_h1 = tf.matmul(x, W_h1) + b_h1
     46 # 计算a(Wx+b) a代表激活函数,有tf.nn.relu()、tf.nn.tanh()、tf.nn.sigmoid()
     47 act_h1 = tf.nn.relu(pre_act_h1, name=None)
     48 
     49 '''
     50 输出层
     51 '''
     52 W_o = tf.Variable(tf.truncated_normal([D_hidden, D_label],  stddev=1.0), name="W_o")
     53 b_o = tf.Variable(tf.constant(0.1, shape=[D_label]), name="b_o")
     54 pre_act_o = tf.matmul(act_h1, W_o) + b_o
     55 y = tf.nn.relu(pre_act_o, name=None)
     56 
     57 '''
     58 损失函数和更新方法
     59 '''
     60 loss = tf.reduce_mean((y - t)**2)
     61 train_step = tf.train.AdamOptimizer(lr).minimize(loss)
     62 '''
     63 训练
     64 sess = tf.InteractiveSession()是比较方便的创建方法。也有sess =
     65 tf.Session()方式,但该方式无法使用tensor.eval()快速取值等功能
     66 '''
     67 sess = tf.InteractiveSession()
     68 # 初始化权重
     69 # tf.tables_initializer(name="init_all_tables").run()调试时报错,可能是版本问题
    70 # Add the variable initializer Op. 71 init = tf.global_variables_initializer() 72 sess.run(init) 73 # 训练网络 74 ''' 75 GD(Gradient Descent):X和Y是4组不同的训练数据。上面将所有数据输入到网络, 76 算出平均梯度来更新一次网络的方法叫做GD。效率很低,也容易卡在局部极小值,但更新方向稳定 77 ''' 78 79 T = 100000 # 训练次数 80 for i in range(T): 81 sess.run(train_step, feed_dict={x: X, t: Y}) 82 83 84 ''' 85 SGD(Gradient Descent):一次只输入一个训练数据到网络,算出梯度来更新一次网络的方法叫做SGD。 86 效率高,适合大规模学习任务,容易挣脱局部极小值(或鞍点),但更新方向不稳定。代码如下 87 ''' 88 ''' 89 T = 100000 # 训练几epoch 90 for i in range(T): 91 for j in range(X.shape[0]): # X.shape[0]表示样本个数 X.shape[0] 报错 'Placeholder:0', which has shape '(?, 2) 92 sess.run(train_step, feed_dict={x: [X[j]], t: [Y[j]]}) 93 ''' 94 ''' 95 batch-GD:这是上面两个方法的折中方式。每次计算部分数据的平均梯度来更新权重。 96 部分数据的数量大小叫做batch_size,对训练效果有影响。一般10个以下的也叫mini-batch-GD。代码如下: 97 ''' 98 ''' 99 T = 10000 # 训练几epoch 100 b_idx = 0 # batch计数 101 b_size = 2 # batch大小 102 for i in range(T): 103 while b_idx <= X.shape[0]: 104 sess.run(train_step, feed_dict={x: X[b_idx:b_idx+b_size], t: Y[b_idx:b_idx+b_size]}) 105 b_idx += b_size # 更新batch计数 106 ''' 107 108 109 ''' 110 shuffle:SGD和batch-GD由于只用到了部分数据。若数据都以相同顺序进入网络会使得随后的epoch影响很小。 111 shuffle是用于打乱数据在矩阵中的排列顺序,提高后续epoch的训练效果。代码如下: 112 ''' 113 ''' 114 # shuffle 115 def shufflelists(lists): # 多个序列以相同顺序打乱 116 ri = np.random.permutation(len(lists[1])) 117 out = [] 118 for l in lists: 119 out.append(l[ri]) 120 return out 121 122 # 训练网络 123 T = 100000 # 训练几epoch 124 b_idx = 0 # batch计数 125 b_size = 2 # batch大小 126 for i in range(T): # 每次epoch都打乱顺 127 X, Y = shufflelists([X, Y]) 128 while b_idx <= X.shape[0]: 129 sess.run(train_step, feed_dict={x: X[b_idx:b_idx + b_size], t: Y[b_idx:b_idx + b_size]}) 130 b_idx += b_size # 更新batch计数 131 ''' 132 # 预测数据 133 print(sess.run(y, feed_dict={x: X})) 134 print(sess.run(act_h1, feed_dict={x: X}))
  • 相关阅读:
    新安装的CentOS 7不能上网
    修改机器名
    读书笔记-MySQL运维内参08-索引实现原理2
    读书笔记-MySQL运维内参08-索引实现原理1
    读书笔记-MySQL运维内参07-InnoDB数据存储结构
    MySQL 参数设置-持续更新
    读书笔记-Mycat权威指南-10-分片规则
    读书笔记-Mycat权威指南-09-全局序列号
    读书笔记-Mycat权威指南-08-Mycat中的Join
    读书笔记-Mycat权威指南-03-Mycat中的概念
  • 原文地址:https://www.cnblogs.com/yamin/p/7210255.html
Copyright © 2011-2022 走看看