zoukankan      html  css  js  c++  java
  • PyTorch


    什么是 hub

    hub(modelzoo)主要用来调用其他人训练好的模型和参数
    Facebook官方博客表示,PyTorch Hub是一个简易API和工作流程,为复现研究提供了基本构建模块,包含预训练模型库。
    并且,PyTorch Hub还支持Colab,能与论文代码结合网站Papers With Code集成,用于更广泛的研究。

    github:https://github.com/pytorch/hub
    模型:https://pytorch.org/hub/research-models


    使用示例

    import torch
    model = torch.hub.load('pytorch/vision:v0.4.2', 'deeplabv3_resnet101', pretrained=True)
    model.eval()
    # 下载会显示下载数据和进度
    # Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/xx/.cache/torch/hub/v0.4.2.zip
    # Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/xx/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
    
    Downloading: "https://github.com/pytorch/vision/archive/v0.4.2.zip" to /Users/shushu/.cache/torch/hub/v0.4.2.zip
    Downloading: "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth" to /Users/shushu/.cache/torch/hub/checkpoints/resnet101-5d3b4d8f.pth
    
    HBox(children=(IntProgress(value=0, max=178728960), HTML(value='')))
    
    torch.hub.list('pytorch/vision:v0.4.2')
    
    Using cache found in C:UsersAdministrator/.cache	orchhubpytorch_vision_v0.4.2
    
    ['alexnet',
     'deeplabv3_resnet101',
     'densenet121',
     'densenet161',
     'densenet169',
     'densenet201',
     'fcn_resnet101',
     'googlenet',
     'inception_v3',
     'mobilenet_v2',
     'resnet101',
     'resnet152',
     'resnet18',
     'resnet34',
     'resnet50',
     'resnext101_32x8d',
     'resnext50_32x4d',
     'shufflenet_v2_x0_5',
     'shufflenet_v2_x1_0',
     'squeezenet1_0',
     'squeezenet1_1',
     'vgg11',
     'vgg11_bn',
     'vgg13',
     'vgg13_bn',
     'vgg16',
     'vgg16_bn',
     'vgg19',
     'vgg19_bn',
     'wide_resnet101_2',
     'wide_resnet50_2']
    

    # Download an example image from the pytorch website
    import urllib
    url, filename = ("https://github.com/pytorch/hub/raw/master/dog.jpg", "dog.jpg")
    try: urllib.URLopener().retrieve(url, filename)
    except: urllib.request.urlretrieve(url, filename)
    
    # sample execution (requires torchvision)
    from PIL import Image
    from torchvision import transforms
    
    input_image = Image.open(filename)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
    
    # move the input and model to GPU for speed if available
    if torch.cuda.is_available():
        input_batch = input_batch.to('cuda')
        model.to('cuda')
    
    with torch.no_grad():
        output = model(input_batch)['out'][0]
    output_predictions = output.argmax(0)
    
    # create a color pallette, selecting a color for each class
    palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
    colors = torch.as_tensor([i for i in range(21)])[:, None] * palette
    colors = (colors % 255).numpy().astype("uint8")
    
    # plot the semantic segmentation predictions of 21 classes in each color
    r = Image.fromarray(output_predictions.byte().cpu().numpy()).resize(input_image.size)
    r.putpalette(colors)
    
    import matplotlib.pyplot as plt
    plt.imshow(r)
    plt.show()
    


    1、查询可用的模型

    用户可以使用torch.hub.list()这个API列出repo中所有可用的入口点。比如你想知道PyTorch Hub中有哪些可用的计算机视觉模型:

    >>> torch.hub.list('pytorch/vision')
    >>>
    ['alexnet',
    'deeplabv3_resnet101',
    'densenet121',
    ...
    'vgg16',
    'vgg16_bn',
    'vgg19',
     'vgg19_bn']
    

    2、加载模型

    在上一步中能看到所有可用的计算机视觉模型,如果想调用其中的一个,也不必安装,只需一句话就能加载模型。

    model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)
    

    至于如何获得此模型的详细帮助信息,可以使用下面的API:

    print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
    

    如果模型的发布者后续加入错误修复和性能改进,用户也可以非常简单地获取更新,确保自己用到的是最新版本:

    model = torch.hub.load(..., force_reload=True)
    对于另外一部分用户来说,稳定性更加重要,他们有时候需要调用特定分支的代码。例如pytorch_GAN_zoo的hub分支:
    
    model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)
    

    3、查看模型可用方法

    从PyTorch Hub加载模型后,你可以用dir(model)查看模型的所有可用方法。以bertForMaskedLM模型为例:

    >>> dir(model)
    >>>
    ['forward'
    ...
    'to'
    'state_dict',
    ]
    

    forward

    如果你对forward方法感兴趣,使用help(model.forward) 了解运行运行该方法所需的参数。

    >>> help(model.forward)
    >>>
    Help on method forward in module pytorch_pretrained_bert.modeling:
    forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
    ...
    

    支持 Colab

    PyTorch Hub中提供的模型也支持Colab。

    进入每个模型的介绍页面后,你不仅可以看到GitHub代码页的入口,甚至可以一键进入Colab运行模型Demo。



    对于模型发布者

    如果你希望把自己的模型发布到PyTorch Hub上供所有用户使用,可以去PyTorch Hub的GitHub页发送拉取请求。若你的模型符合高质量、易重复、最有利的要求,Facebook官方将会与你合作。

    一旦拉取请求被接受,你的模型将很快出现在PyTorch Hub官方网页上,供所有用户浏览。

    目前该网站上已经有18个提交的模型,英伟达率先提供支持,他们在PyTorch Hub已经发布了Tacotron2和WaveGlow两个TTS模型。

    图片

    发布模型的方法也是比较简单的,开发者只需在自己的GitHub存储库中添加一个简单的hubconf.py文件,在其中枚举运行模型所需的依赖项列表即可。

    比如,torchvision中的hubconf.py文件是这样的:

    # Optional list of dependencies required by the package
    dependencies = ['torch']
    
    from torchvision.models.alexnet import alexnet
    from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
    from torchvision.models.inception import inception_v3
    from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,
    resnext50_32x4d, resnext101_32x8d
    from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
    from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
    from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
    from torchvision.models.googlenet import googlenet
    from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
    from torchvision.models.mobilenet import mobilenet_v2
    

    Facebook官方向模型发布者提出了以下三点要求:

    1、每个模型文件都可以独立运行和执行
    2、不需要PyTorch以外的任何包
    3、不需要单独的入口点,让模型在创建时可以无缝地开箱即用

    Facebook还建议发布者最小化对包的依赖性,减少用户加载模型进行实验的阻力。


    更多资料

  • 相关阅读:
    7月的尾巴,你是XXX
    戏说Android view 工作流程《下》
    “燕子”
    Android开机动画bootanimation.zip
    戏说Android view 工作流程《上》
    ViewController里已连接的IBOutlet为什么会是nil
    My first App "Encrypt Wheel" is Ready to Download!
    iOS开发中角色Role所产生的悲剧(未完)
    UIScrollView实现不全屏分页的小技巧
    Apple misunderstood my app,now my app status changed to “In Review”
  • 原文地址:https://www.cnblogs.com/fldev/p/14466499.html
Copyright © 2011-2022 走看看