zoukankan      html  css  js  c++  java
  • PaddlePaddle 极简入门实践二:简单的线性回归

    import paddle.fluid as fluid
    import numpy


    #定义数据
    train_data = [[0], [1], [2], [3], [4], [5], [10]]
    y_true = [[3], [13], [23], [33], [43], [53], [103]]

    #定义网络
    x = fluid.layers.data(name="x", shape=[1], dtype="float32")
    y = fluid.layers.data(name="y", shape=[1], dtype="float32")
    y_predict = fluid.layers.fc(input=x, size=1, act=None) # 定义x与其有关系

    #定义损失函数
    cost = fluid.layers.square_error_cost(input=y_predict,label=y)#平方差损失函数
    avg_cost = fluid.layers.mean(cost)#求平均损失

    #定义优化方法
    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)#定义SGD-随机梯度下降的学习率
    sgd_optimizer.minimize(avg_cost)

    cpu = fluid.CPUPlace()#此处使用CPU进行训练 GPU训练则移步之后更新的文章
    exe = fluid.Executor(cpu)#Executor是执行器
    prog=fluid.default_startup_program()#将刚刚定义的一堆堆赋值给prog这个变量名
    exe.run(prog)#准备开始!

    for i in range(500):
    for data_id in range(len(y_true)):
    # 需转换为numpy数组类型,这样可以传入训练
    data_x=numpy.array(train_data[data_id]).astype("float32").reshape((1,1))
    data_y = numpy.array(y_true[data_id]).astype("float32").reshape((1,1))
    outs = exe.run(
    feed={'x':data_x , 'y': data_y},
    fetch_list=[y_predict.name, avg_cost]) # feed为数据表 输入数据和标签数据
    print("正在训练第" + str(i + 1) + "次")
    # 观察结果
    print(outs)


    # 保存预测模型
    dirname = "./test01.inference.model/"
    fluid.io.save_inference_model(dirname, ['x'], [y_predict], exe)


    ######################################################################################


    import paddle.fluid as fluid
    import numpy

    import paddle.fluid as fluid
    import numpy

    dirname = "./test01.inference.model/"


    cpu = fluid.CPUPlace()
    exe = fluid.Executor(cpu)


    # 加载模型
    [inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname, exe)

    # 目标数据
    datatype = "float32"
    test_data = numpy.array([[input("请输入数值")]]).astype(datatype)


    # 启动预测
    results = exe.run(inference_program,
    feed={feed_target_names[0]: test_data},
    fetch_list=fetch_targets)


    print(results[0][0])



  • 相关阅读:
    python 八进制数
    python hmac加盐
    python contextlib
    python hashlib
    python struct
    python namedtuple
    python datetime timezone 时区转化
    Android核心基础(手机卫士的一个知识点总结)
    TabHost结合RadioButton实现主页的导航效果
    Android SDK更新失败最新解决方案
  • 原文地址:https://www.cnblogs.com/zhulimin/p/13194503.html
Copyright © 2011-2022 走看看