  • TensorFlow: Practical Google Deep Learning Framework, Section 4.2.2: source code for a custom loss function

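    # Note: this listing uses the TensorFlow 1.x graph API
    # (tf.placeholder, tf.Session, tf.train.AdamOptimizer).
    import tensorflow as tf
    from numpy.random import RandomState

    print("hello tensorflow 4.1")

    batch_size = 8

    # Two input features per example and one regression target.
    x = tf.placeholder(tf.float32, shape=(None, 2), name='x-input')
    y_ = tf.placeholder(tf.float32, shape=(None, 1), name='y-input')

    # Single-layer linear model: y = x * w1 (no hidden layer, no bias).
    w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
    y = tf.matmul(x, w1)

    # Unused two-layer variant, left commented out as in the original:
    #w2 = tf.Variable(tf.random_normal([3,1],stddev=1,seed=1))
    #a = tf.matmul(x,w1)
    #y = tf.matmul(a,w2)

    # Custom loss: each unit of over-prediction (y > y_) costs loss_more,
    # each unit of under-prediction costs loss_less.
    loss_less = 10
    loss_more = 1
    loss = tf.reduce_sum(
        tf.where(tf.greater(y, y_), (y - y_) * loss_more, (y_ - y) * loss_less))
    train_step = tf.train.AdamOptimizer(0.001).minimize(loss)

    # Synthetic data: target = x1 + x2 plus uniform noise in [-0.05, 0.05).
    rdm = RandomState(1)
    dataset_size = 128
    X = rdm.rand(dataset_size, 2)
    Y = [[x1 + x2 + rdm.rand() / 10.0 - 0.05] for (x1, x2) in X]

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)
        # Weights before training.
        print(sess.run(w1))
        STEPS = 5000
        for i in range(STEPS):
            # Cycle through the dataset one mini-batch at a time.
            start = (i * batch_size) % dataset_size
            end = min(start + batch_size, dataset_size)
            sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
            # Weights after this training step.
            print(sess.run(w1))

    print("end ")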
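
The loss in the listing charges under-prediction ten times as much as over-prediction. As a rough illustration, the same rule can be written in plain Python and checked by hand. This is only a sketch: the helper name `asymmetric_loss` is hypothetical and is not part of the original post or of TensorFlow.

```python
# Illustrative sketch (not from the original post): the asymmetric loss
# from the listing, in plain Python so the numbers can be checked by hand.

def asymmetric_loss(predictions, targets, loss_more=1.0, loss_less=10.0):
    """Sum the per-example costs.

    Each unit of over-prediction costs loss_more and each unit of
    under-prediction costs loss_less, mirroring the tf.where expression.
    """
    total = 0.0
    for y, y_true in zip(predictions, targets):
        if y > y_true:
            total += (y - y_true) * loss_more
        else:
            total += (y_true - y) * loss_less
    return total


over = asymmetric_loss([1.5], [1.0])   # over-predict by 0.5
under = asymmetric_loss([0.5], [1.0])  # under-predict by 0.5
print(over, under)
```

With these defaults, the same 0.5 error costs 0.5 when the prediction is too high and 5.0 when it is too low, so the optimizer would be expected to favor weights that lean toward predicting slightly above `x1 + x2`.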
    

      

  • Original article: https://www.cnblogs.com/a9999/p/9916545.html