zoukankan      html  css  js  c++  java
  • [tensorflow]异或门的实现

    一段小程序:待理解

    import tensorflow as tf
    import numpy as np
    #输入训练数据,这里是python的list, 也可以定义为numpy的ndarray
    x_data = [[1., 0.], [0., 1.], [0., 0.], [1., 1.]]
    #定义占位符,占位符在运行图的时候必须feed数据
    x = tf.placeholder(tf.float32, shape = [None, 2])
    #训练数据的标签,注意维度
    y_data = [[1], [1], [0], [0]]
    y = tf.placeholder(tf.float32, shape = [None, 1])
    #定义variables,在运行图的过程中会被按照优化目标改变和保存
    weights = {'w1': tf.Variable(tf.random_normal([2, 16])), 'w2': tf.Variable(tf.random_normal([16, 1]))}
    bias = {'b1': tf.Variable(tf.zeros([1])), 'b2': tf.Variable(tf.zeros([1]))}
    #定义对于节点的操作函数
    def nn(x, weights, bias):
        d1 = tf.matmul(x, weights['w1']) + bias['b1']
        d1 = tf.nn.relu(d1)
        d2 = tf.matmul(d1, weights['w2']) + bias['b2']
        d2 = tf.nn.sigmoid(d2)
        return d2
    #预测值
    pred = nn(x, weights, bias)
    #损失函数
    cost = tf.reduce_mean(tf.square(y - pred))
    #学习率
    learning_rate = 0.01
    #定义tf.train用来训练
    # train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)  ## max_step: 20000, loss: 0.002638
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(cost)  ## max_step: 2000, loss: 0.000014
    #初始化参数,图运行的一开始必须初始化所有变量
    init = tf.global_variables_initializer()
    # correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(pred, 1))
    # accuracy = tf.reduce_mean(tf.cast(correct_pred, 'float'))
    
    #运行图,with语句调用其后面函数的__enter()__函数,将返回值赋给as后面的参数,并在块的最后调用__exit()__函数,相当于
    #sess = tf.Sessions(),
    
    with tf.Session() as sess:
        sess.run(init)
        max_step = 2000
        for i in range(max_step + 1):
            sess.run(train_step, feed_dict = {x: x_data, y: y_data})
            loss = sess.run(cost, feed_dict = {x: x_data, y: y_data})
            # acc = sess.run(accuracy, feed_dict = {x: x_data, y: y_data})
            # 输出训练误差和测试数据的标签
            if i % 100 == 0:
                print('step: '+ str(i) + '  loss:' + "{:.6f}".format(loss)) #+ '    accuracy:' + "{:.6f}".format(acc))
                print(sess.run(pred, feed_dict = {x: x_data}))
        print('end')
    
    #sess.close()
    

      

  • 相关阅读:
    HDU1251 统计难题
    字典树模板
    HDU5536 Chip Factory(01字典树)
    函数的返回值
    函数的使用原则
    文件修改
    函数
    文件内指针移动
    文件操作模式
    字符编码
  • 原文地址:https://www.cnblogs.com/elpsycongroo/p/9153649.html
Copyright © 2011-2022 走看看