import库,加载mnist数据集。
设置学习率,迭代次数,batch并行计算数量,以及log显示。
这里设置了占位符,输入是batch * 784的矩阵,由于是并行计算,所以None实际上代表并行数。输出是10类,因为mnist数据集是手写数字0-9,所以分成10类是很正常的。
W和b是变量。
第一行代码建立了一个softmax模型,意思是,将10类最后的输出结果再通过softmax函数换算一下,softmax函数如下:
,其实就是做了一次转换,让各个输出变成了概率,且概率之和等于1.
也要了解到,softmax 回归是 logistic 回归的一般形式。
cost函数定义为方差,这是可以的,但是更常用的算法应该是交叉熵。
optimizer定义为梯度下降法,学习率已经在最前面被定义完毕。
最后初始化所有的变量。
这里写了训练过程,只是被精简了很多,直接被函数替代了。
大概意思是,在每一次迭代中,又对整个batch进行迭代,这里是以一个batch为单位的。
之后sess.run(),将损失存储起来,之后进行平均损失的计算。
当然这个平均损失在每一次迭代(外层循环)后,会逐渐变小。