zoukankan      html  css  js  c++  java
  • tensorflow和keras混用

    在tensorflow中可以调用keras,有时候让模型的建立更加简单。如下这种是官方写法:

    import tensorflow as tf
    from keras import backend as K
    from keras.layers import Dense
    from keras.objectives import categorical_crossentropy
    from keras.metrics import categorical_accuracy as accuracy
    from tensorflow.examples.tutorials.mnist import input_data
    # create a tf session,and register with keras。
    sess = tf.Session()
    K.set_session(sess)
    
    # this place holder is the same with input layer in keras
    img = tf.placeholder(tf.float32, shape=(None, 784))
    # keras layers can be called on tensorflow tensors
    x = Dense(128, activation='relu')(img)
    x = Dense(128, activation='relu')(x)
    preds = Dense(10, activation='softmax')(x)
    # label
    labels = tf.placeholder(tf.float32, shape=(None, 10))
    # loss function
    loss = tf.reduce_mean(categorical_crossentropy(labels, preds))
    
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
    
    mnist_data = input_data.read_data_sets('MNIST_data', one_hot=True)
    
    # initialize all variables
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    with sess.as_default():
        for i in range(1000):
            batch = mnist_data.train.next_batch(50)
            train_step.run(feed_dict={img:batch[0],
                                      labels:batch[1]})
    
    acc_value = accuracy(labels, preds)
    with sess.as_default():
        print(acc_value.eval(feed_dict={img:mnist_data.test.images,
                                        labels:mnist_data.test.labels}))

    上述代码中,在训练阶段直接采用了tf的方式,甚至都没有定义keras的model!官网说 最重要的一步就是这里:

    K.set_session(sess)

    创建一个TensorFlow会话并且注册Keras。这意味着Keras将使用我们注册的会话来初始化它在内部创建的所有变量。 
    keras的层和模型都充分兼容tensorflow的各种scope, 例如name scope,device scope和graph scope。

    经过测试,下面这种不需要k.set_session()也是可以的。

    
    
    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    # build module

    img = tf.placeholder(tf.float32, shape=(None, 784))
    labels = tf.placeholder(tf.float32, shape=(None, 10))

    x = tf.keras.layers.Dense(128, activation='relu')(img)
    x = tf.keras.layers.Dense(128, activation='relu')(x)
    prediction = tf.keras.layers.Dense(10, activation='softmax')(x)

    loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=prediction, labels=labels))

    train_optim = tf.train.AdamOptimizer().minimize(loss)
    path="/home/vv/PycharmProject/Cnnsvm/MNIST_data"
    mnist_data = input_data.read_data_sets(path, one_hot=True)

    with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)

    for _ in range(1000):
    batch_x, batch_y = mnist_data.train.next_batch(50)
    sess.run(train_optim, feed_dict={img: batch_x, labels: batch_y})

    acc_pred = tf.keras.metrics.categorical_accuracy(labels, prediction)
    pred = sess.run(acc_pred, feed_dict={labels: mnist_data.test.labels, img: mnist_data.test.images})

    print('accuracy: %.3f' % (sum(pred) / len(mnist_data.test.labels)))
    print(pred)

    如果在下载导入mnist数据出错,可以在网站上下好,本地导入。

    mnist_data = input_data.read_data_sets(path, one_hot=True)
    x1 = tf.layers.conv2d(img2,64,2)
    x2 = tf.keras.layers.Conv2D(img2,64,2)
    x3 = tf.keras.layers.Conv2D(64,2)(img2)

    x1和x3卷积效果相同

  • 相关阅读:
    React组件二
    React组件一
    React新接触
    清除浮动的方法
    div section article aside的理解
    html引入外部的jswenjian
    绘制扇形,空心文字,实心文字,颜色线性 放射性渐变
    绘制扇形空心 实心文字 ,颜色线性渐变,颜色放射性渐变
    绘制圆弧的几种简单方法
    求两个有序数组的中位数
  • 原文地址:https://www.cnblogs.com/a-little-v/p/9772836.html
Copyright © 2011-2022 走看看