zoukankan      html  css  js  c++  java
  • 训练神经网络解决而分类问题

    #导入库
    import tensorflow as tf
    from numpy.random import RandomState

    #定义训练数据batch的大小
    batch_size = 8

    #定义神经网络的参数
    w1 = tf.Variable(tf.random_normal([2,3],stddev=1,seed=1))
    w2 = tf.Variable(tf.random_normal([3,1],stddev=1,seed=1))

    #在shape的一个纬度上使用None可以方便使用不大的batch大小。在训练时需要把数据分成比较小的batch,但是在测试时,可以一次性
    #使用全部的数据。当书记比较小时这样比较方便测试,但数据比较大时,将大量数据放入一个batch可能会导致内存溢出。
    x = tf.placeholder(tf.float32,shape=(None,2),name='x-input')
    y_ = tf.placeholder(tf.float32,shape=(None,1),name='y-input')

    #定义神经网络的前向传播的过程。
    a = tf.matmul(x,w1)
    y = tf.matmul(a,w2)

    #定义损失函数和反向传播的算法。
    cross_entropy = -tf.reduce_mean(y_*tf.log(tf.clip_by_value(y,1e-10,1.0)))
    train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

    #随机生成一个数据模拟集。
    rdm = RandomState(1)
    dataset_size = 128
    X = rdm.rand(dataset_size,2)
    #定义规则来给出样本的标签。在这里所有的x1+x2<1的样例被认为是正样本,而定义其他为负样本。和TensorFlow游乐场中的表示法不大一样的地方是,
    #在这里使用0来表示负样本,1表示正样本。大部分解决分类的神经网络都会采用0和1的表示方法
    Y = [[int(x1+x2<1)] for (x1,x2) in X]

    #创建一个会话来运行TensorFlow程序 。
    with tf.Session() as sess:
    #初始化变量,在这个版本中,以前的initialize——all_variables()函数被取消了,取而代之的是global——variables——initializer()
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    print(sess.run(w1))
    print(sess.run(w2))
    #设定训练的轮数。
    STEPS = 5000
    for i in range(STEPS):
    #每次选取batch——size个样本进行训练。
    start = (i*batch_size)%dataset_size
    end = min(start+batch_size,dataset_size)
    #通过选取的样本训练神经网路并更新参数。
    sess.run(train_step,feed_dict={x:X[start:end],y_:Y[start:end]})
    if i % 1000 == 0:
    #每隔一段时间计算在所有数据上的交叉熵并输出。
    total_cross_entropy = sess.run(cross_entropy,feed_dict={x:X,y_:Y})
    print("After %d trainig step(s),cross entropy on all data is %g" % (i,total_cross_entropy))

    print(sess.run(w1))
    print(sess.run(w2))
  • 相关阅读:
    日志/异常处理(nnlog+traceback)
    Excel操作
    商品管理系统
    大乐透作业
    随机生成密码作业
    时间相关的模块
    os模块
    sys模块
    Pytho中dict(或对象)与json之间的互相转化
    Python三元表达式和列表生成式
  • 原文地址:https://www.cnblogs.com/hmy-blog/p/6573610.html
Copyright © 2011-2022 走看看