zoukankan      html  css  js  c++  java
  • tensorflow基础模型之RandomForest(随机森林)算法

    随机森林算法原理请参照上篇:随机森林。数据依旧为MNIST数据集。

    代码如下:

    from __future__ import print_function

    # Ignore all GPUs, tf random forest does not benefit from it.
    import os

    import tensorflow as tf
    from tensorflow.contrib.tensor_forest.python import tensor_forest
    from tensorflow.python.ops import resources

    os.environ["CUDA_VISIBLE_DEVICES"] = ""

    # 导入 MNIST 数据
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets("./tmp/data/", one_hot=False)

    # 参数
    num_steps = 500 # Total steps to train
    batch_size = 1024 # 每批处理样本数
    num_classes = 10 # 10个数字=>10个分类
    num_features = 784 # 每张图片 28x28 像素 => 784特征
    num_trees = 10
    max_nodes = 1000

    # 输入数据
    X = tf.placeholder(tf.float32, shape=[None, num_features])
    # 用数字表示随机森林中的标签(类id)
    Y = tf.placeholder(tf.int32, shape=[None])

    # 随机森林参数
    hparams = tensor_forest.ForestHParams(num_classes=num_classes,
                                        num_features=num_features,
                                        num_trees=num_trees,
                                        max_nodes=max_nodes).fill()

    # 建立随机森林
    forest_graph = tensor_forest.RandomForestGraphs(hparams)
    # 获取训练图,计算损失率
    train_op = forest_graph.training_graph(X, Y)
    loss_op = forest_graph.training_loss(X, Y)

    # 计算准确率
    infer_op, _, _ = forest_graph.inference_graph(X)
    correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
    accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # 初始化变量和森林资源
    init_vars = tf.group(tf.global_variables_initializer(),
                        resources.initialize_resources(resources.shared_resources()))

    # 启动TensorFlow会话
    sess = tf.Session()

    # 初始化
    sess.run(init_vars)

    # 训练
    for i in range(1, num_steps + 1):
      # 准备数据
      # 获取一批图片数据
      batch_x, batch_y = mnist.train.next_batch(batch_size)
      _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
      if i % 50 == 0 or i == 1:
          acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
          print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))

    # 测试模型
    test_x, test_y = mnist.test.images, mnist.test.labels
    print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))
    ---------------------

  • 相关阅读:
    MVC HTTP 错误 403.14
    web.config connectionStrings 数据库连接字符串的解释(转载)
    bootstrap div 弹出与关闭
    jquery操作select(取值,设置选中)
    VS2013使用EF6与mysql数据库
    php中创建和调用webservice接口示例
    java script 确认框
    mysql中判断记录是否存在方法比较
    根据Unicode编码用C#语言把它转换成汉字的代码
    微软架构师解读Windows Server 2008 R2新特性
  • 原文地址:https://www.cnblogs.com/hyhy904/p/11182994.html
Copyright © 2011-2022 走看看