zoukankan      html  css  js  c++  java
  • AI

    分类与回归

    分类(Classification)与回归(Regression)的区别在于输出变量的类型
    通俗理解,定量输出称为回归,或者说是连续变量预测;定性输出称为分类,或者说是离散变量预测。

    回归问题的预测结果是连续的,通常是用来预测一个值,如预测房价、未来的天气情况等等。
    一个比较常见的回归算法是线性回归算法(LR,Linear Regression)。
    回归分析用在神经网络上,其最上层不需要加上softmax函数,而是直接对前一层累加即可。
    回归是对真实值的一种逼近预测。

    分类问题的预测结果是离散的,是用于将事物打上一个标签,通常结果为离散值。
    分类通常是建立在回归之上,分类的最后一层通常要使用softmax函数进行判断其所属类别。
    分类并没有逼近的概念,最终正确结果只有一个,错误的就是错误的,不会有相近的概念。
    最常见的分类方法是逻辑回归(Logistic Regression),或者叫逻辑分类。

    MNIST数据集

    MNIST(Mixed National Institute of Standards and Technology database)是一个计算机视觉数据集;

    • 官方下载地址:http://yann.lecun.com/exdb/mnist/
    • 包含70000张手写数字的灰度图片,其中60000张为训练图像和10000张为测试图像;
    • 每一张图片都是28*28个像素点大小的灰度图像;

    如果无法从网络下载MNIST数据集,可从官方下载,然后存放在当前脚本目录下的新建MNIST_data目录即可;

    •  MNIST_data rain-images-idx3-ubyte.gz
    • MNIST_data rain-labels-idx1-ubyte.gz
    • MNIST_data 10k-images-idx3-ubyte.gz
    • MNIST_data 10k-labels-idx1-ubyte.gz

    示例程序

     1 # coding=utf-8
     2 from __future__ import print_function
     3 import tensorflow as tf
     4 from tensorflow.examples.tutorials.mnist import input_data  # MNIST数据集
     5 import os
     6 
     7 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
     8 
     9 old_v = tf.logging.get_verbosity()
    10 tf.logging.set_verbosity(tf.logging.ERROR)
    11 
    12 mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # 准备数据(如果本地没有数据,将从网络下载)
    13 
    14 
    15 def add_layer(inputs, in_size, out_size, activation_function=None, ):
    16     Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    17     biases = tf.Variable(tf.zeros([1, out_size]) + 0.1, )
    18     Wx_plus_b = tf.matmul(inputs, Weights) + biases
    19     if activation_function is None:
    20         outputs = Wx_plus_b
    21     else:
    22         outputs = activation_function(Wx_plus_b, )
    23     return outputs
    24 
    25 
    26 def compute_accuracy(v_xs, v_ys):
    27     global prediction
    28     y_pre = sess.run(prediction, feed_dict={xs: v_xs})
    29     correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
    30     accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    31     result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
    32     return result
    33 
    34 
    35 xs = tf.placeholder(tf.float32, [None, 784])  # 输入数据是784(28*28)个特征
    36 ys = tf.placeholder(tf.float32, [None, 10])  # 输出数据是10个特征
    37 
    38 prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)  # 激励函数为softmax
    39 
    40 cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
    41                                               reduction_indices=[1]))  # loss函数(最优化目标函数)选用交叉熵函数
    42 
    43 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)  # train方法(最优化算法)采用梯度下降法
    44 
    45 sess = tf.Session()
    46 init = tf.global_variables_initializer()
    47 sess.run(init)
    48 
    49 for i in range(1000):
    50     batch_xs, batch_ys = mnist.train.next_batch(100)  # 每次只取100张图片,免得数据太多训练太慢
    51     sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
    52     if i % 50 == 0:  # 每训练50次输出预测精度
    53         print(compute_accuracy(
    54             mnist.test.images, mnist.test.labels))

    程序运行结果:

    Extracting MNIST_data	rain-images-idx3-ubyte.gz
    Extracting MNIST_data	rain-labels-idx1-ubyte.gz
    Extracting MNIST_data	10k-images-idx3-ubyte.gz
    Extracting MNIST_data	10k-labels-idx1-ubyte.gz
    0.146
    0.6316
    0.7347
    0.7815
    0.8095
    0.8198
    0.8306
    0.837
    0.8444
    0.8547
    0.8544
    0.8578
    0.8651
    0.8649
    0.8705
    0.8704
    0.8741
    0.8719
    0.8753
    0.8756

    问题处理

    问题现象

    执行程序提示“Please use tf.data to implement this functionality.”等信息

    WARNING:tensorflow:From D:/Anliven/Anliven-Code/PycharmProjects/TempTest/TempTest_2.py:13: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    WARNING:tensorflow:From C:UsersanlivenAppDataLocalcondacondaenvsmlcclibsite-packages	ensorflowcontriblearnpythonlearndatasetsmnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Extracting MNIST_data	rain-images-idx3-ubyte.gz
    Instructions for updating:
    Please write your own downloading logic.
    WARNING:tensorflow:From C:UsersanlivenAppDataLocalcondacondaenvsmlcclibsite-packages	ensorflowcontriblearnpythonlearndatasetsmnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    Extracting MNIST_data	rain-labels-idx1-ubyte.gz
    ......
    ......

    处理方法

    参考链接:https://stackoverflow.com/questions/49901806/tensorflow-importing-mnist-warnings

  • 相关阅读:
    css页面自适应 媒体查询
    微信小程序rich-text中的nodes属性
    解析base64数据流---加载pdf
    用伪元素完成箭头
    搭建vue --2.x
    搭建Vue项目 vue-cli vue1.x
    Chrome----TCP
    单进程VS多进程
    线程VS进程
    Chrome---network模块---Timing
  • 原文地址:https://www.cnblogs.com/anliven/p/10434433.html
Copyright © 2011-2022 走看看