zoukankan      html  css  js  c++  java
  • 机器学习笔记(2):线性回归-使用gluon

    代码来自:https://zh.gluon.ai/chapter_supervised-learning/linear-regression-gluon.html

     1 from mxnet import ndarray as nd
     2 from mxnet import autograd
     3 from mxnet import gluon
     4 
     5 num_inputs = 2
     6 num_examples = 1000
     7 
     8 true_w = [2, -3.4]
     9 true_b = 4.2
    10 
    11 X = nd.random_normal(shape=(num_examples, num_inputs)) #1000行,2列的数据集
    12 y = true_w[0] * X[:, 0] + true_w[1] * X[:, 1] + true_b #已知答案的结果
    13 y += .01 * nd.random_normal(shape=y.shape) #加入噪音
    14 
    15 #1 随机读取10行数据
    16 batch_size = 10
    17 dataset = gluon.data.ArrayDataset(X, y)
    18 data_iter = gluon.data.DataLoader(dataset, batch_size, shuffle=True)
    19 
    20 #2 定义回归模型
    21 net = gluon.nn.Sequential()
    22 net.add(gluon.nn.Dense(1))
    23 
    24 #3 参数初始化
    25 net.initialize()
    26 
    27 #4 损失函数
    28 square_loss = gluon.loss.L2Loss()
    29 
    30 #5 指定训练方法
    31 trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    32     
    33 #6 训练
    34 epochs = 5
    35 batch_size = 10
    36 for e in range(epochs):
    37     total_loss = 0
    38     for data, label in data_iter:
    39         with autograd.record():
    40             output = net(data)
    41             loss = square_loss(output, label)
    42         loss.backward()
    43         trainer.step(batch_size)
    44         total_loss += nd.sum(loss).asscalar()
    45     print("Epoch %d, average loss: %f" % (e, total_loss/num_examples))
    46 
    47 #7 输出结果
    48 dense = net[0]
    49 print(true_w)
    50 print(dense.weight.data())
    51 print(true_b)
    52 print(dense.bias.data())

    相对上一篇纯手动的处理方式,用gluon后代码明显更精简了。

  • 相关阅读:
    git log中文乱码问题
    局域网映射公网IP
    Android Studio 的一些配置
    Android Studio的安装
    adb的安装
    python的安装
    CentOS 7 上部署 java web 项目
    SQL——SQL语句总结(8)
    SQL——SQL语句总结(7)
    SQL——SQL语句总结(6)
  • 原文地址:https://www.cnblogs.com/yjmyzz/p/7774166.html
Copyright © 2011-2022 走看看