zoukankan      html  css  js  c++  java
  • Pytorch常用的对象类和继承机制(如果有)

      参考资料:

      Pytorch这个深度学习框架在设计的时候嵌入了非常丰富的继承机制。在通用的深度学习算法中使用到的组件其实都继承于某一个父类,比如:Dataset,DataLoader,Model等其实都蕴含了一个继承机制。这篇随笔打算梳理并剖析一下Pytorch里的这样一种继承现象。请注意,继承后的子类的构造方法第一行一定要调用super()方法哦。

      torch.nn.Module

    import torch.nn as nn
    import torch.nn.functional as F
    
    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)
    
        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

      平时我们在深度学习中提到的Model其实就是继承自torch.nn.Module。最重要且继承后必须重写的方法是forward,这个方法直接规定Model的前向运算方式。此外还有一些预定义的方法,比较重要的是:

    1. net.apply
    2. net.cuda
    3. net.train
    4. net.eval
    5. net.load_state_dict
    6. net.zero_grad

      torch.utils.data.Dataset

      官网并没有给这个类示例,可能是觉得这个类比较简单。正如描述中所说,torch.utils.data.Dataset是来handle键值对形式的数据格式的。我们必须实现两个函数,__getitem__和__len__。前者输入索引index返回对应的数据(和label),后者返回数据集总的大小(index的上限)。

      补充一句,在官网的Doc中torch.utils.data.Dataset下面就是torch.utils.data.IterableDataset,这个数据集格式和上面Dataset的区别在于它是来handle可迭代的数据集类型。其只需要重写一个__iter__函数,留待日后有需要的时候研究。

      torchvision.transforms

      这里跑题提一下torchvision里面经常用到的transforms,它本质也是nn.Module(不信看源码),其方便之处在于提供了丰富的内置处理图片的方法(transforms变换)。并且可以通过transforms.Compose方法把多个transform串序并到一起(类似nn.Sequential)。所以在继承一个torch.utils.data.Dataset的时候不妨多利用transforms哦(explicitly specify transform)。

      torch.utils.data.DataLoader

      从形式上来看,DataLoader是Dataset套的一层包装;从功能上来看,DataLoader才是最终提供给Model数据的人。这个组件基本不涉及继承机制(很少人去改写这个类),因此略过。

      torch.nn.modules.loss

      说完了Dataset和Model,不得不提的就是损失函数了,从torch.nn.modules.loss可以看出,所有的loss其实没啥特别的,说白了也是一个nn.Module。只不过它的forward方法比较特殊,Model的forward方法是给他一个data_tensor,而Loss的forward方法是给他一个target_tensor和一个(Model预测的)input_tensor,返回值一般是一个常数。

      torch.optim

      torch.optim这个包下预置了很多Optimizer比如SGD,Adam。其用法如下:

    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    optimizer = optim.Adam([var1, var2], lr=0.0001)
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    

      需要注意的是如果网络要在GPU上训练,则optimizer和model绑定应该在model转移到GPU上之后。

      torch.optim.lr_scheduler

      深度学习在训练时一个动态衰减的学习率是比较理想的。torch.optim.lr_scheduler提供了这样一个功能。其用法如下:

    model = [Parameter(torch.randn(2, 2, requires_grad=True))]
    optimizer = SGD(model, 0.1)
    scheduler = ExponentialLR(optimizer, gamma=0.9)
    
    for epoch in range(20):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()
    

      注意scheduler.step要在optimizer.step之后。

  • 相关阅读:
    sencha touch学习心得之FormPanel
    从零开始学习Sencha Touch MVC应用之十九
    sencha touch中datepicker的汉化
    从零开始学习Sencha Touch MVC应用之十九
    从零开始学习Sencha Touch MVC应用之十八
    sencha touch中datepicker的汉化
    sencha touch学习心得之FormPanel
    常用内置模块(二)——logging模块
    包的介绍
    常用内置模块(一)——time、os、sys、random、shutil、pickle、json
  • 原文地址:https://www.cnblogs.com/chester-cs/p/15471259.html
Copyright © 2011-2022 走看看