zoukankan      html  css  js  c++  java
  • tensorflow线性回归预测鲍鱼数据

    代码如下:

    import tensorflow as tf
    import csv
    import numpy as np
    import matplotlib.pyplot as plt
    # 设置学习率
    learning_rate = 0.01
    # 设置训练次数
    train_steps = 1000
    #数据地址:http://archive.ics.uci.edu/ml/datasets/Abalone
    with open('./data/abalone.data') as file:
        reader = csv.reader(file)
        a, b = [], []
        for item in reader:
            b.append(item[8])
            del(item[8])
            a.append(item)
        file.close()
    x_data = np.array(a)
    new_x_data = []
    for i in x_data[:,0]:
        if i == 'M':
            i = 1
        elif i == 'F':
            i = 2
        elif i == 'I':
            i = 3
        new_x_data.append(i)
    new_data = np.array(new_x_data)
    x_data = np.delete(x_data,0,axis=1)
    print(x_data.shape)
    print(new_data.shape)
    x_data = np.column_stack((new_data,x_data)) #添加一列,将new_data添加到x_data中
    print(x_data)
    print(x_data[:,0])
    y_data = np.array(b)
    for i in range(len(x_data)):
        y_data[i] = float(y_data[i])
        for j in range(len(x_data[i])):
            x_data[i][j] = float(x_data[i][j])
    # 定义各影响因子的权重
    weights = tf.Variable(np.ones([8,1]),dtype = tf.float32)
    x_data_ = tf.placeholder(tf.float32, [None, 8])
    y_data_ = tf.placeholder(tf.float32, [None, 1])
    bias = tf.Variable(1.0, dtype = tf.float32)#定义偏差值
    # 构建模型为:y_model = w1X1 + w2X2 + w3X3 + w4X4 + w5X5 + w6X6 + w7X7 + w8X8 + bias
    y_model = tf.add(tf.matmul(x_data_ , weights), bias)
    # 定义损失函数
    loss = tf.reduce_mean(tf.pow((y_model - y_data_), 2))
    #训练目标为损失值最小,学习率为0.01
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print("Start training!")
        lo = []
        sample = np.arange(train_steps)
        for i in range(train_steps):
            for (x,y) in zip(x_data, y_data):
                z1 = x.reshape(1,8)
                z2 = y.reshape(1,1)
                sess.run(train_op, feed_dict = {x_data_ : z1, y_data_ : z2})
            l = sess.run(loss, feed_dict = {x_data_ : z1, y_data_ : z2})
            lo.append(l)
        print(weights.eval(sess))
        print(bias.eval(sess))
        # 绘制训练损失变化图
        plt.plot(sample, lo, marker="*", linewidth=1, linestyle="--", color="red")
        plt.title("The variation of the loss")
        plt.xlabel("Sampling Point")
        plt.ylabel("Loss")
        plt.grid(True)
        plt.show()
  • 相关阅读:
    GridView小知识1
    ASP 中 GridView 的粗浅入门
    SQL连接
    Microsoft Visual Studio 2010 Express for Windows Phone 新建文件 设置启动
    转载一个应届计算机毕业生2012求职之路
    百度之星平衡负载(3.23)
    查找字符串中首个非重复字符
    CreateMutex函数
    关于“Visual Studio 遇到了异常,可能是由于某个扩展导致的”的解决
    无法打开预编译头文件:“Debug\****.pch”: No such file or directory 的解决办法
  • 原文地址:https://www.cnblogs.com/ywjfx/p/12303360.html
Copyright © 2011-2022 走看看