  • TensorFlow study notes: the SavedModel format explained, and loading/using models with TF Serving

    TensorFlow basic concepts (a series of articles): https://www.cnblogs.com/wanyu416/p/8954098.html

    Saving and loading a TensorFlow SavedModel: https://www.jianshu.com/p/83cfd9571158

    How to load an offline model in TensorFlow: https://www.zhihu.com/question/300914772

    Cross-platform deployment of TensorFlow models: https://zhuanlan.zhihu.com/p/40481765

    TensorFlow program structure: http://c.biancheng.net/view/1883.html

    The SavedModel format: https://www.tensorflow.org/guide/saved_model

    A SavedModel is a directory containing serialized signatures and the state needed to run them, including variable values and vocabularies.
    The directory layout is as follows:
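    The following sketch is illustrative: assets/ may be absent, and the variables shard names depend on how the model was exported.

    foo/
    ├── saved_model.pb
    ├── assets/
    └── variables/
        ├── variables.data-00000-of-00001
        └── variables.index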

    saved_model.pb stores the actual TensorFlow program or model, together with a set of named signatures, each of which identifies a function that accepts tensor inputs and produces tensor outputs.
    The variables directory contains a standard training checkpoint.
    Terminology:
    signatures: the signatures saved with the SavedModel. Applies only to the "tf" format; see tf.saved_model.save for details.
    checkpoint: saving a model is not limited to after training finishes; it is also needed during training, because a TensorFlow training run may be interrupted, and we naturally want to persist the parameters learned so far rather than start over. Saving the model in the middle of training is conventionally called saving a checkpoint.
    TensorFlow: adding checkpoints to a model: https://www.cnblogs.com/baby-lily/p/10930591.html

    Model loading in TF Serving (TFS):
    tensorflow/tensorflow/cc/saved_model/loader.cc

    /// Checks whether the provided directory could contain a SavedModel. Note that
    /// the method does not load any data by itself. If the method returns `false`,
    /// the export directory definitely does not contain a SavedModel. If the method
    /// returns `true`, the export directory may contain a SavedModel but provides
    /// no guarantee that it can be loaded.
    bool MaybeSavedModelDirectory(const string& export_dir);
    

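    As a sketch, this can serve as a cheap pre-flight check before attempting a real load (the model path here is illustrative):

    #include <tensorflow/cc/saved_model/loader.h>

    #include <iostream>
    #include <string>

    int main() {
        const std::string dir = "/models/foo/1";  // illustrative path
        // Structural check only: nothing is loaded, and `true` does not
        // guarantee that a later LoadSavedModel call will succeed.
        if (!tensorflow::MaybeSavedModelDirectory(dir)) {
            std::cerr << dir << " does not look like a SavedModel directory" << std::endl;
            return 1;
        }
        std::cout << "Directory may contain a SavedModel; try loading it." << std::endl;
        return 0;
    }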

    /// Loads a SavedModel from the specified export directory. The meta graph def
    /// to be loaded is identified by the supplied tags, corresponding exactly to
    /// the set of tags used at SavedModel build time. Returns a SavedModel bundle
    /// with a session and the requested meta graph def, if found.
    Status LoadSavedModel(const SessionOptions& session_options,
                          const RunOptions& run_options, const string& export_dir,
                          const std::unordered_set<string>& tags,
                          SavedModelBundle* const bundle);
    
    

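    A minimal loading sketch that checks the returned Status explicitly instead of aborting through TF_CHECK_OK; the export path is illustrative, and kSavedModelTagServe is the tag normally written when a model is exported for serving:

    #include <tensorflow/cc/saved_model/loader.h>
    #include <tensorflow/cc/saved_model/tag_constants.h>
    #include <tensorflow/core/public/session_options.h>

    #include <iostream>

    int main() {
        tensorflow::SavedModelBundle bundle;
        tensorflow::SessionOptions session_options;  // defaults: local runtime
        tensorflow::RunOptions run_options;          // defaults: no tracing

        // The tag set must exactly match the tags used at SavedModel build time.
        tensorflow::Status status = tensorflow::LoadSavedModel(
            session_options, run_options, "/models/foo/1",
            {tensorflow::kSavedModelTagServe}, &bundle);
        if (!status.ok()) {
            std::cerr << "LoadSavedModel failed: " << status.ToString() << std::endl;
            return 1;
        }
        // On success, bundle.session is ready for Run() calls and
        // bundle.meta_graph_def carries the named signatures.
        return 0;
    }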

    Example: https://gist.github.com/OneRaynyDay/c79346890dda095aecc6e9249a9ff3e1#file-run_model-py
    The key calls are:
    tensorflow::MaybeSavedModelDirectory
    tensorflow::LoadSavedModel
    bundle.session->Run

    The full example:
    #include <tensorflow/cc/saved_model/loader.h>
    #include <tensorflow/cc/saved_model/tag_constants.h>
    #include <tensorflow/core/public/session_options.h>
    #include <tensorflow/core/framework/tensor.h>
    
    #include <xtensor/xarray.hpp>
    #include <xtensor/xnpy.hpp>
    
    #include <string>
    #include <iostream>
    #include <vector>
    #include <cfloat>
    
    static const int IMG_SIZE = 784;
    static const int NUM_SAMPLES = 10000;
    
    tensorflow::Tensor load_npy_img(const std::string& filename) {
        auto data = xt::load_npy<float>(filename);
        tensorflow::Tensor t(tensorflow::DT_FLOAT, tensorflow::TensorShape({NUM_SAMPLES, IMG_SIZE}));
    
        for (int i = 0; i < NUM_SAMPLES; i++)
            for (int j = 0; j < IMG_SIZE; j++)
                t.tensor<float, 2>()(i,j) = data(i, j);
    
        return t;
    }
    
    std::vector<int> get_tensor_shape(const tensorflow::Tensor& tensor)
    {
        std::vector<int> shape;
        auto num_dimensions = tensor.shape().dims();
        for(int i=0; i < num_dimensions; i++) {
            shape.push_back(tensor.shape().dim_size(i));
        }
        return shape;
    }
    
    template <typename M>
    void print_keys(const M& sig_map) {
        for (auto const& p : sig_map) {
            std::cout << "key : " << p.first << std::endl;
        }
    }
    
    template <typename K, typename M>
    bool assert_in(const K& k, const M& m) {
        return !(m.find(k) == m.end());
    }
    
    std::string _input_name = "digits";
    std::string _output_name = "predictions";
    
    int main() {
        // This is passed into LoadSavedModel to be populated.
        tensorflow::SavedModelBundle bundle;
    
        // From docs: "If 'target' is empty or unspecified, the local TensorFlow runtime
        // implementation will be used.  Otherwise, the TensorFlow engine
        // defined by 'target' will be used to perform all computations."
        tensorflow::SessionOptions session_options;
    
        // Run option flags here: https://www.tensorflow.org/api_docs/python/tf/compat/v1/RunOptions
        // We don't need any of these yet.
        tensorflow::RunOptions run_options;
    
        // Fills in this from a session run call
        std::vector<tensorflow::Tensor> out;
    
        std::string dir = "pyfiles/foo";
        std::string npy_file = "pyfiles/data.npy";
        std::string prediction_npy_file = "pyfiles/predictions.npy";
    
        std::cout << "Found model: " << tensorflow::MaybeSavedModelDirectory(dir) << std::endl;
        // TF_CHECK_OK takes the status and checks whether it works.
        TF_CHECK_OK(tensorflow::LoadSavedModel(session_options,
                                               run_options,
                                               dir,
                                               // Refer to tag_constants. We just want to serve the model.
                                               {tensorflow::kSavedModelTagServe},
                                               &bundle));
    
        auto sig_map = bundle.meta_graph_def.signature_def();
    
    // Print the available signature keys; "serving_default" is the standard
    // key for the default serving signature.
        print_keys(sig_map);
        std::string sig_def = "serving_default";
        auto model_def = sig_map.at(sig_def);
        auto inputs = model_def.inputs().at(_input_name);
        auto input_name = inputs.name();
        auto outputs = model_def.outputs().at(_output_name);
        auto output_name = outputs.name();
    
        auto input = load_npy_img(npy_file);
    
        TF_CHECK_OK(bundle.session->Run({{input_name, input}},
                            {output_name},
                            {},
                            &out));
        std::cout << out[0].DebugString() << std::endl;
    
        auto res = out[0];
        auto shape = get_tensor_shape(res);
        // we only care about the first dimension of shape
        xt::xarray<float> predictions = xt::zeros<float>({shape[0]});
        for(int row = 0; row < shape[0]; row++) {
        float max = -FLT_MAX;  // note: FLT_MIN is the smallest positive float, not the most negative value
            int max_idx = -1;
            for(int col = 0; col < shape[1]; col++) {
                auto val = res.tensor<float, 2>()(row, col);
                if(max < val) {
                    max_idx = col;
                    max = val;
                }
            }
            predictions(row) = max_idx;
        }
        xt::dump_npy(prediction_npy_file, predictions);
    }
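    Note: building the example requires the TensorFlow C++ libraries (for instance, linking against the Bazel target //tensorflow/cc/saved_model:loader) plus xtensor for the .npy I/O; the pyfiles/foo model and pyfiles/data.npy input are assumed to have been exported beforehand on the Python side.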
    
  • Original post: https://www.cnblogs.com/gnivor/p/13374173.html