zoukankan      html  css  js  c++  java
  • mxnet(gluon)—— 模型、数据集、损失函数、优化子等类、接口大全

    1. 数据集

    dataset_train = gluon.data.ArrayDataset(X_train, y_train)
    data_iter = gluon.data.DataLoader(dataset_train, batch_size, shuffle=True)
    
    for data, label in data_iter:
        ...

    2. 模型

    • gluon.nn:神经网络

      • gluon.nn.Sequential(),可添加:

        • gluon.nn.Flatten() ⇒ Flattens the input to two dimensional,将输入平坦为 2 维矩阵,是一种操作,而非添加进新的层

          net = gluon.nn.Sequencial()
          with net.name_scope():
              net.add(gluon.nn.Flatten())
        • gluon.nn.Dense:全连接

        • gluon.nn.Dropout(drop_prob1)
      
      # 序列化神经网络模型
      
      net = gluon.nn.Sequential()
      
      with net.name_scope():
          net.add(gluon.nn.Dense(1))
              # Dense(1):表示输出值的维度,
          # 一层的神经网络相当于线性回归
      
      # 参数初始化
      
      net.collect_params().initialize(mxnet.init.Normal(sigma=1))

    3. 训练器(Trainer)

    仅保存参数及超参,以及根据 batch size 进行参数更新:

    trainer = gluon.Trainer(net.collect_params(), optimizer='sgd',
            optimizer_params={'learning_rate': learning_rate, 'weight_decay': weight_decay})
    ....
    for data, label in data_iter:
        ...
        trainer.step(batch_size)
    

    4. 自动求导:autograd

    • autograd.is_training() ⇒ 训练过程还是测试预测过程:

      对于 dropout 型网络,训练过程因为 dropout 随机性的存在,模型是变化的,测试过程中节点全部参与,没有dropout;

  • 相关阅读:
    Socket基础一
    MyBatisPlus【目录】
    MyBatis(十一)扩展:自定义类型处理器
    MyBatis(十一)扩展:批量操作
    MyBatis(十一)扩展:存储过程
    MyBatis(十一)扩展:分页插件PageHelper
    MyBatis(十)插件 4
    09月07日总结
    09月06日总结
    09月03日总结
  • 原文地址:https://www.cnblogs.com/mtcnn/p/9421065.html
Copyright © 2011-2022 走看看