zoukankan      html  css  js  c++  java
  • PyTorch


    什么是 hub

    hub(modelzoo)主要用来调用其他人训练好的模型和参数
    Facebook官方博客表示,PyTorch Hub是一个简易API和工作流程,为复现研究提供了基本构建模块,包含预训练模型库。
    并且,PyTorch Hub还支持Colab,能与论文代码结合网站Papers With Code集成,用于更广泛的研究。

    github:https://github.com/pytorch/hub
    模型:https://pytorch.org/hub/research-models


    使用示例

    import torch
    model = torch.hub.load('pytorch/vision:v0.4.2', 'deeplabv3_resnet101', pretrained=True)
    model.eval()
    # 下载会显示下载数据和进度
    # Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/xx/.cache/torch/hub/v0.4.2.zip
    # Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/xx/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
    
    Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/shushu/.cache/torch/hub/v0.4.2.zip
    Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/shushu/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
    
    HBox(children=(IntProgress(value=0, max=178728960), HTML(value='')))
    
    torch.hub.list('pytorch/vision:v0.4.2')
    
    Using cache found in C:UsersAdministrator/.cache	orchhubpytorch_vision_v0.4.2
    
    ['alexnet',
     'deeplabv3_resnet101',
     'densenet121',
     'densenet161',
     'densenet169',
     'densenet201',
     'fcn_resnet101',
     'googlenet',
     'inception_v3',
     'mobilenet_v2',
     'resnet101',
     'resnet152',
     'resnet18',
     'resnet34',
     'resnet50',
     'resnext101_32x8d',
     'resnext50_32x4d',
     'shufflenet_v2_x0_5',
     'shufflenet_v2_x1_0',
     'squeezenet1_0',
     'squeezenet1_1',
     'vgg11',
     'vgg11_bn',
     'vgg13',
     'vgg13_bn',
     'vgg16',
     'vgg16_bn',
     'vgg19',
     'vgg19_bn',
     'wide_resnet101_2',
     'wide_resnet50_2']
    

    # Download an example image from the pytorch website
    import urllib
    url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
    try: urllib.URLopener().retrieve(url, filename)
    except: urllib.request.urlretrieve(url, filename)
    
    # sample execution (requires torchvision)
    from PIL import Image
    from torchvision import transforms
    
    input_image = Image.open(filename)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
    
    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')
    
    with torch.no_grad():
        output = model(input_batch)['out'][0]
    output_predictions = output.argmax(0)
    
    # create a color pallette, selecting a color for each class
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")
    
    # plot the semantic segmentation predictions of 21 classes in each color
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)
    
    import matplotlib.pyplot as plt
    plt.imshow(r)
    plt.show()
    


    1、查询可用的模型

    用户可以使用torch.hub.list()这个API列出repo中所有可用的入口点。比如你想知道PyTorch Hub中有哪些可用的计算机视觉模型:

    >>> torch.hub.list('pytorch/vision')
    >>>
    ['alexnet',
    'deeplabv3_resnet101',
    'densenet121',
    ...
    'vgg16',
    'vgg16_bn',
    'vgg19',
     'vgg19_bn']
    

    2、加载模型

    在上一步中能看到所有可用的计算机视觉模型,如果想调用其中的一个,也不必安装,只需一句话就能加载模型。

    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    

    至于如何获得此模型的详细帮助信息,可以使用下面的API:

    print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
    

    如果模型的发布者后续加入错误修复和性能改进,用户也可以非常简单地获取更新,确保自己用到的是最新版本:

    model = torch.hub.load(..., force_reload=True)
    对于另外一部分用户来说,稳定性更加重要,他们有时候需要调用特定分支的代码。例如pytorch_GAN_zoo的hub分支:
    
    model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)
    

    3、查看模型可用方法

    从PyTorch Hub加载模型后,你可以用dir(model)查看模型的所有可用方法。以bertForMaskedLM模型为例:

    >>> dir(model)
    >>>
    ['forward'
    ...
    'to'
    'state_dict',
    ]
    

    forward

    如果你对forward方法感兴趣,使用help(model.forward) 了解运行运行该方法所需的参数。

    >>> help(model.forward)
    >>>
    Help on method forward in module pytorch_pretrained_bert.modeling:
    forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
    ...
    

    支持 Colab

    PyTorch Hub中提供的模型也支持Colab。

    进入每个模型的介绍页面后,你不仅可以看到GitHub代码页的入口,甚至可以一键进入Colab运行模型Demo。



    对于模型发布者

    如果你希望把自己的模型发布到PyTorch Hub上供所有用户使用,可以去PyTorch Hub的GitHub页发送拉取请求。若你的模型符合高质量、易重复、最有利的要求,Facebook官方将会与你合作。

    一旦拉取请求被接受,你的模型将很快出现在PyTorch Hub官方网页上,供所有用户浏览。

    目前该网站上已经有18个提交的模型,英伟达率先提供支持,他们在PyTorch Hub已经发布了Tacotron2和WaveGlow两个TTS模型。

    图片

    发布模型的方法也是比较简单的,开发者只需在自己的GitHub存储库中添加一个简单的hubconf.py文件,在其中枚举运行模型所需的依赖项列表即可。

    比如,torchvision中的hubconf.py文件是这样的:

    # Optional list of dependencies required by the package
    dependencies = ['torch']
    
    from torchvision.models.alexnet import alexnet
    from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
    from torchvision.models.inception import inception_v3
    from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,
    resnext50_32x4d, resnext101_32x8d
    from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
    from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
    from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
    from torchvision.models.googlenet import googlenet
    from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
    from torchvision.models.mobilenet import mobilenet_v2
    

    Facebook官方向模型发布者提出了以下三点要求:

    1、每个模型文件都可以独立运行和执行
    2、不需要PyTorch以外的任何包
    3、不需要单独的入口点,让模型在创建时可以无缝地开箱即用

    Facebook还建议发布者最小化对包的依赖性,减少用户加载模型进行实验的阻力。


    更多资料

  • 相关阅读:
    ad_imh
    pc send instructor pc ad
    数据、模型、IT系统认知
    量化投资认知
    LinAlgError: Last 2 dimensions of the array must be square
    转:Hadoop大数据开发基础系列:七、Hive基础
    Run-Time Check Failure #2
    0x00007FFC8C5325E7 (ucrtbased.dll)处(位于 DataStructure.exe 中)引发的异常: 0xC0000005: 读取位置 0xFFFFFFFFFFFFFFFF 时发生访问冲突。
    栈与后缀表达式C实现
    Jupyter使用
  • 原文地址:https://www.cnblogs.com/fldev/p/14466499.html
Copyright © 2011-2022 走看看