zoukankan      html  css  js  c++  java
  • tensorflow学习(一)

    今天开始学习tensorflow框架,从极客学院下载了官方中文教程(15年翻译的),第一天开始学习第一章ng基本流程和原理,作为前奏。然后写了代码,验证一下,准确率确实非常高,非常好用。把代码上传,作为以后备用。

     1 import tensorflow as tf
     2 import numpy as np
     3 import math
     4 
     5 class Model:
     6     def __init__(self,w = np.empty(None),b = None):
     7         self.b = b
     8         self.w = w
     9 
    10     def predict(self, input):
    11         return np.dot(input,self.w) + self.b
    12 
    13 data = np.float32(np.random.rand(1000,2))
    14 label = np.dot(data,np.array([[0.100],[0.200]])) + 0.3
    15 m,n = data.shape
    16 
    17 num_train = int(m * 0.6)
    18 num_validation = int(m * 0.2)
    19 num_test = int(m * 0.2)
    20 
    21 data_train = data[:num_train,:]
    22 data_validation = data[num_train:(num_train+num_validation),:]
    23 data_test = data[(num_train+num_validation):,:]
    24 
    25 label_train = label[:num_train,:]
    26 label_validation = label[num_train:(num_train+num_validation),:]
    27 label_test = label[(num_train+num_validation):,:]
    28 
    29 w = tf.Variable(tf.random_uniform([2,1],-1.0,1.0))
    30 b = tf.Variable(tf.zeros([1]))
    31 
    32 y_train = tf.matmul(data_train,w) + b
    33 loss = tf.reduce_mean(tf.square(label_train - y_train))
    34 '''
    35 bestModel = Model()
    36 minRMSE = (1 << 31) -1
    37 alphas = [0.1,0.3,0.5];
    38 iters = [100,150,200,250];
    39 for iter in iters:
    40     for alpha in alphas:
    41         optimizer = tf.train.GradientDescentOptimizer(alpha)
    42         train = optimizer.minimize(loss)
    43         init_state = tf.global_variables_initializer()
    44         with tf.Session() as sess:
    45             sess.run(init_state)
    46             for step in range(0,iter):
    47                 sess.run(train)
    48             model = Model(sess.run(w),sess.run(b))
    49             p = model.predict(data_validation)
    50             rmse = np.sqrt(np.mean(np.square(label_validation - p)))
    51             if rmse < minRMSE:
    52                 minRMSE = rmse
    53                 bestModel = model
    54 np.save("E:\Python\models\weights.npy",bestModel.w)
    55 np.save("E:\Python\models\b.npy",bestModel.b)
    56 '''
    57 
    58 weights = np.load("E:\Python\models\weights.npy")
    59 b = np.load("E:\Python\models\b.npy")
    60 
    61 model = Model(weights,b)
    62 predicts = model.predict(data_test)
    63 print(predicts)
    64 print(label_test)
    65 print(label_test - predicts)
    View Code
  • 相关阅读:
    给vs2012轻松换肤
    几种软件常用授权方式总结
    Discuz X2多人斗地主[消耗论坛积分]小体积版本,仅25MB!
    关于Socket 设置 IPAddress.Any 情况下,出现服务器积极拒绝的问题
    以前看过一个压缩过的.exe,运行会播放长达半小时的动画,却只有60KB,个人认为其中的原理
    VisualSvn Server安装和使用
    socket短时间内重连需注意的问题
    PostgreSQL在何处处理 sql查询之十一
    PostgreSQL在何处处理 sql查询之十三
    PostgreSQL在何处处理 sql查询之十四
  • 原文地址:https://www.cnblogs.com/txq157/p/6708086.html
Copyright © 2011-2022 走看看