zoukankan      html  css  js  c++  java
  • 寒假学习进度14:对mnist数据集实现逻辑回归

    转自:https://blog.csdn.net/weixin_38859557/article/details/80795476
    #-*- coding:utf-8 _*- """ @author:bluesli @file: logistic_regression.py @time: 2018/06/23 """ ''' 逻辑回归是基于数字字符逻辑的,所以英文字母是logistic,而不是logic ''' ''' 1:通过input_data 获取数据 2:分别获取对应的训练样本(784个像素点)练标签,测试样本,测试标签 3:训练标签(train_label)shape中的10是通过0,1编码的,即数字0-9,如果是对应数据就显示1,不是就是0;通过train_label[0]看他的形式 4:初始化x,y变量(placeholder(float,[None,784])None表示无穷的意思,(不知道有多少样本就用他表示) 5"初始化:w,b (tf.zeros[784,10]分类个数是10,零值初始化,也可以高斯初始化;b:tf.zeros[10] 6:tf.nn.softmax(tf.matmul(x,w)+b) softmax传入的是一个分值; 这是模型,要做数据归一化; 7:定义损失函数loss:log(b)正确类别的概率值; cost:tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1)) 8:定义优化器 9:最小化损失函数 10:通过模型进行预测:tf.equal(tf.argmax(actv,1),tf.argmax(y,1)),对比预测值最大值,和label(真实)看一下是不是对的; 11:accr:(精度)tf.reduce(pre,'float') 转化成浮点型,然后加起来求均值,就可以知道预测的精度了; 12: ''' #argmax函数: #tf.rank(array).输出矩阵的维数 #tensorflow要想输出需要在后面加上eval() #shape返回矩阵是几行几列的; #tf.argmax(arr,0)eval()返回当前数组每一列最大值的索引(从零开始的) #tf.argmx(arr,1).eval()返回当前数组每一行最大值的索引 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import matplotlib.pyplot as plt #获取各个种类数据 mnist = input_data.read_data_sets('data/',one_hot=True) train_img = mnist.train.images #记录的是每一个特征值的大小(可以理解为得分) 也可以这样理解:相当于计算成绩时,有语文,数学,英语等科目,然后下面有对应的分数,然后乘以相应的权重值也就得到了总分数; train_label = mnist.train.labels test_img = mnist.test.images test_label = mnist.test.labels print(train_label[0]) print(len(train_img[0])) #定义变量 x = tf.placeholder(tf.float32,[None,784]) y = tf.placeholder(tf.float32,[None,10]) w = tf.Variable(tf.zeros([784,10])) #也可以用高斯分布进行初始化,需要有784个权重值,做的是10分类的任务,所有10就可以了, b = tf.Variable(tf.zeros([10])) #定义模型: actv = tf.nn.softmax(tf.matmul(x,w)+b) #逻辑回归解决的是一个二分类的问题,所以要将逻辑回归升级成多分类的任务;,softmax(多分类任务) #softmax的输入实际上是一个得分值:#记录的是每一个特征值的大小(可以理解为得分) 也可以这样理解:相当于计算成绩时,有语文,数学,英语等科目,然后下面有对应的分数,然后乘以相应的权重值也就得到了总分数; #x*w得到的是10个分值,也就是对应0,1,2,3...9所占的分数; cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1)) #逻辑回归函数损失值是-log(p)类型的;p是什么呢,模型得到10个值,然后做归一化操作,然后得到0-9的概率值;没有理解需要详细深入理解逻辑回归问题; #y*log得到的是真实类别的那个概率值,由于每一个样本对应的label中是这样的,[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.],相乘就只有属于那个分类的值了, #reduce_mean 平均值; optimizer = tf.train.GradientDescentOptimizer(0.3) train = optimizer.minimize(cost) #定义predict,arg_max(1)找到每一行的最大值返回他的下标; # a = tf.arg_max(actv,1) pred = tf.equal(tf.arg_max(actv,1),tf.arg_max(y,1)) accr = tf.reduce_mean(tf.cast(pred,'float')) #变量初始化 init_op = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init_op) train_epochs = 6 batch_size = 1000 train_step = 2 avg_cost = 0 for epoch in range(train_epochs): num_batch = int(mnist.train.num_examples/batch_size)+1 batch_example,batch_label = mnist.train.next_batch(batch_size) # plt.scatter(batch_example,batch_label,c='r') # plt.show() print(batch_label.shape,batch_example.shape) for i in range(num_batch): feed_seed = {x:batch_example,y:batch_label} sess.run(train,feed_dict=feed_seed) print('w=',sess.run(w,feed_seed),'b=',sess.run(b,feed_dict=feed_seed),'loss=',sess.run(cost,feed_dict=feed_seed)) avg_cost += sess.run(cost,feed_dict=feed_seed) avg_cost = avg_cost/num_batch if train_epochs % train_step ==0: # feed_seed2 = {x: batch_example, y: batch_label} # print(sess.run(cost,feed_dict=feed_seed2)) # print('------------') # print(sess.run(w,feed_dict=feed_seed),sess.run(b,feed_seed2)) # print('--------------') # print(cost) # print(cost/num_batch) feed_train = {x:batch_example,y:batch_label} feed_test = {x:test_img,y:test_label} print('train accr:%f'%sess.run(accr,feed_dict=feed_train)) print('test accr:%f'%sess.run(accr,feed_dict=feed_test)) print('avg_cost:%f'%avg_cost) # print(sess.run(a,feed_dict=feed_train))

      

  • 相关阅读:
    java项目数据库从oracle迁移到mysql 中 java部分的一些修改
    mysql表名等大小写敏感问题、字段类型timestamp、批量修改表名、oracle查询历史操作记录等
    navicat premium相关应用(将oracle数据库迁移到mysql等)
    Java byte 类型的取值范围是-128~127
    idea中debug:
    chrome里面模拟手机上打开网页的场景方法
    Dealloc weak nil
    用七牛sdk传递图片到七牛服务器
    iOS block 本质研究
    UIWebView JSContext相关问题
  • 原文地址:https://www.cnblogs.com/yangqqq/p/14460450.html
Copyright © 2011-2022 走看看