zoukankan      html  css  js  c++  java
  • MindSpore计算框架如何发布训练好的模型到官方模型仓库MindSpore_Hub上

    相关官方资料:

    https://www.mindspore.cn/tutorial/training/zh-CN/r1.2/use/publish_model.html

     参考地址:

    https://gitee.com/mindspore/hub/blob/r1.2/mshub_res/README.md

    https://gitee.com/mindspore/mindspore/blob/r1.2/CONTRIBUTING.md

    ==========================================================

    1. 将你的预训练模型托管在可以访问的存储位置。

     这里假设我们训练的网络模型参数文件为 ckpt 文件,我们需要提前把参数文件保存到一个可以公开访问的地址,如:

    https://download.mindspore.cn/model_zoo/official/cv/googlenet/googlenet_ascend_0.7.0_cifar10_official_classification_20200922/googlenet.ckpt

    但是这个地址是不是一定是http服务的呢?这里也是搞不太清楚,具体讨论看问末尾。

     2.      参照模板,在你自己的代码仓中添加模型生成文件mindspore_hub_conf.py,文件放置的位置如下:

     

     这一步骤是说我们需要提供网络的定义文件,当然如果你可以提供网络的具体说明,训练代码,测试代码,等等吧,这是更好的,但是最低要求是需要提供一个网络定义的文件。而这个网络定义的文件需要满足两个条件:

    1).  以代码库的形式来体现,如:

    repo-link: https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/googlenet

    这里我们可以使用gitee 或 github 来存放自己构建的代码库,该代码库中存放网络定义的文件。

    2).  代码库中需要提供mindspore_hub可以调用的API接口文件。

    在你自己的代码仓中添加模型生成文件mindspore_hub_conf.py ,该文件存储位置为根目录,如上图所示。

     文件mindspore_hub_conf.py 编写格式参考:

    https://gitee.com/mindspore/mindspore/blob/r1.2/model_zoo/official/cv/googlenet/mindspore_hub_conf.py

    """hub config."""
    from src.googlenet import GoogleNet
    
    def googlenet(*args, **kwargs):
        return GoogleNet(*args, **kwargs)
    
    
    def create_network(name, *args, **kwargs):
        if name == "googlenet":
            return googlenet(*args, **kwargs)
        raise NotImplementedError(f"{name} is not implemented in the repo")

    也就是说 文件mindspore_hub_conf.py 中需要有 函数 def  create_network(name, *args, **kwargs):

    通过调用函数 create_network 我们可以获得返回的定义好的mindspore框架下的网络对象,即上面代码中的  return googlenet(*args, **kwargs)  。

     mindspore_hub中load网络时便会自动调用 create_network函数获得定义好后的网络模型,并通过访问模型参数托管的位置来加载网络参数。

     3.  编写说明文件,即 .md 文件,然后以提交PR的形式提交给mindspore_hub官方代码库。(这里假设已经从官方gitee地址mindspore上拉取了hub库)

     .md 文件的存放地址(在自己拉取的mindspore的官方hub库下的位置,因为我们最后是以提交PR的形式提交给官方的)

    参照模板,在hub/mshub_res/assets/mindspore/ascend/0.7文件夹下创建{model_name}_{model_version}_{dataset}.md文件,其中ascend为模型运行的硬件平台,0.7为MindSpore的版本号,hub/mshub_res的目录结构为:

     

    假设我们是用的mindspore1.3gpu版本,那么我们存放.md文件在自己拉取的hub代码库的路径为:

    hub/mshub_res/assets/mindspore/gpu/1.3

    而 .md 文件的命名格式:

    {model_name}_{model_version}_{dataset}.md

    如:https://gitee.com/mindspore/hub/blob/r1.2/mshub_res/assets/mindspore/ascend/0.7/googlenet_v1_cifar10.md

     中的    googlenet_v1_cifar10.md

     其中,googlenet  为我们训练的神经网络的名称, v1 为我们个人命名的版本号(这个可以自己自由随便起),cifar10 是我们用来进行训练的数据集名称。

    .md 文件的内容最少包括:

    注意,{model_name}_{model_version}_{dataset}.md文件中需要补充如下所示的file-formatasset-linkasset-sha256信息,它们分别表示模型文件格式、模型存储位置(步骤1所得)和模型哈希值。 

    即:

     file-format: ckpt
    asset-link: https://download.mindspore.cn/model_zoo/official/cv/googlenet/goolenet_ascend_0.2.0_cifar10_official_classification_20200713/googlenet.ckpt
    asset-sha256: 114e5acc31dad444fa8ed2aafa02ca34734419f602b9299f3b53013dfc71b0f7

    其中,模型存储位置  asset-link ,则是我们前文说的那个可以公网访问的地址。

    而  asset-sha256 字符串需要使用hub代码库中的代码进行计算,如下操作:

    cd /hub/mshub_res/tools

    python get_sha256.py --file ../googlenet.ckpt 

     

     获得hash码后填写回  .md  文件,完成  .md 文件的编写。

    验证  .md 文件的编写是否符合规范:

    使用hub/mshub_res/tools/md_validator.py在本地核对.md文件的格式,执行以下命令,输出结果为All Passed,表示.md文件的格式和内容均符合要求。

    如:

    python md_validator.py   --check_path    ../assets/mindspore/gpu/1.3/googlenet_v1_cifar10.md

     4.   完成个人拉取的hub代码库中  .md  文件的编写后提交PR给官方请求合并。

    ==========================================================

    相关问题:

    想发布模型到MindSpore_hub上所需asset-link的地址可以是百度网盘吗,除http服务可以访问以外ftp服务行吗

    官方需要我们提供已经训练好的模型的参数文件地址,该地址需要是可以访问的,那这个地址是不是一定要是http服务访问的呢?

    因为我们很有可能是没有公网IP下的http服务器的,而如果必须是可访问的http服务的地址可能就很难满足了,但是如果可以用百度云盘之类的存储方式就可以很好解决了,但是是否可以呢,我们可以关注下上面的帖子。

     

    本博客是博主个人学习时的一些记录,不保证是为原创,个别文章加入了转载的源地址还有个别文章是汇总网上多份资料所成,在这之中也必有疏漏未加标注者,如有侵权请与博主联系。
  • 相关阅读:
    体温填报APP--体温填报
    体温填报APP--主界面设计
    剑指Offer_#60_n个骰子的点数
    剑指Offer_#56-II_ 数组中数字出现的次数II
    剑指Offer_#56-I_数组中数字出现的次数
    剑指Offer_#55
    用Python从头开始构建神经网络
    使用RetinaNet构建的人脸口罩探测器
    如何利用PyTorch中的Moco-V2减少计算约束
    TF2目标检测API
  • 原文地址:https://www.cnblogs.com/devilmaycry812839668/p/15015226.html
Copyright © 2011-2022 走看看