zoukankan      html  css  js  c++  java
  • All-in-one 的Serving分析

    # Export the trained model for this fold: `mission` is used as the signature
    # name (and export sub-directory), and the version number is fold + 1
    # (presumably so a 0-based fold index yields versions starting at 1 — confirm).
    export_func.export(model, sess, signature_name=mission, version=fold + 1)
    def export(model, sess, signature_name, export_path=None, version=1):
        """Export ``model`` from ``sess`` as a SavedModel for TensorFlow Serving.

        The model is written to ``<export_path>/<signature_name>/<version>`` with
        the SERVING tag and two signatures:

        * ``signature_name`` -- a predict signature with inputs ``input_plh``
          (``model.w``) and ``dropout_keep_prob_mlp``
          (``model.dropout_keep_prob_mlp``), and output ``scores`` (``model.y``).
        * the default serving key -- a classify signature mapping ``model.w``
          to ``model.y``.

        Args:
            model: object exposing the ``w``, ``dropout_keep_prob_mlp`` and ``y``
                tensors.
            sess: session holding the trained variables to export.
            signature_name: key of the predict signature; also used as the
                sub-directory name under ``export_path``.
            export_path: base export directory. ``None`` (the default) means
                ``root_path + '/all_in_one/demo/exported_models/'``, evaluated at
                call time rather than when this function is defined.
            version: version number; becomes the last path component.

        Returns:
            The directory the SavedModel was written to.

        Raises:
            ValueError: if ``signature_name`` equals the default serving key,
                which would otherwise silently overwrite the predict signature.
            SavedModelBuilder also refuses to reuse an existing non-empty export
            directory (TF 1.x raises ``AssertionError``).
        """
        default_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
        # Both signatures share one dict; a clash would drop one silently.
        if signature_name == default_key:
            raise ValueError(
                'signature_name must differ from {!r}'.format(default_key))

        # Resolve the default here so it reflects the current module-level
        # ``root_path`` instead of the value captured at definition time.
        if export_path is None:
            export_path = root_path + '/all_in_one/demo/exported_models/'
        export_dir = os.path.join(
            os.path.realpath(export_path), signature_name, str(version))
        print('Exporting trained model to {} ...'.format(export_dir))

        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

        build_tensor_info = tf.saved_model.utils.build_tensor_info
        inputs_info = build_tensor_info(model.w)
        dropout_info = build_tensor_info(model.dropout_keep_prob_mlp)
        scores_info = build_tensor_info(model.y)

        # Predict signature ('tensorflow/serving/predict') stored under the
        # caller-chosen key; clients must feed both inputs.
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={
                'input_plh': inputs_info,
                'dropout_keep_prob_mlp': dropout_info,
            },
            outputs={'scores': scores_info},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # Classify signature ('tensorflow/serving/classify') used as the default.
        # NOTE(review): TF Serving's Classify API feeds serialized tf.Example
        # strings into the classify input, but model.w appears to take token ids
        # — confirm. If so, Classify requests against this signature will fail.
        # It also omits dropout_keep_prob_mlp, so it only works if that tensor
        # has a default value — confirm.
        classification_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={tf.saved_model.signature_constants.CLASSIFY_INPUTS: inputs_info},
            outputs={
                tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES: scores_info,
            },
            method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)

        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                signature_name: prediction_signature,
                default_key: classification_signature,
            })
        builder.save()
        return export_dir

    在 signature_def_map 中定义了两个签名:一个以自定义的别名(signature_name)为键,是 predict 签名;另一个以默认键(DEFAULT_SERVING_SIGNATURE_DEF_KEY)为键,是 classify 签名。

    接下来定义一个客户端解析类,用于向 Serving 服务发送请求并解析返回结果。

    model_name 是启动 Serving 服务时指定的模型名称。

    signature_name 是导出时在 signature_def_map 中自定义的别名,用于选定对应的输入输出签名。

    def classify(self, sents, hostport='192.168.31.186:6000', timeout=60.0):
        """Classify ``sents`` through TensorFlow Serving and return a mapped label.

        ``sents`` is converted with ``self.sents2id`` (the result is also stored
        on ``self.sents``) and sent as a Predict request for ``self.model_name``
        / ``self.signature_name``. The index of the highest score is then mapped
        to a label: index 0 means "no label", and any other index ``i`` maps to
        ``self.label_dict[i - 1]``. Finally the label is translated according to
        ``self.encode``:

        * ``"part"``: ``self.part[label]``; ``"凌晨"`` ("early morning") when
          there is no label.
        * ``"type"``: ``self.type[label]``; ``"录像"`` ("video recording") when
          there is no label.
        * ``"door"``: ``self.gate[label]``; an empty label stays empty.
        * anything else: the ``label_dict`` value (or ``""``) unchanged.

        Args:
            sents: raw input accepted by ``self.sents2id``.
            hostport: ``"host:port"`` of the TensorFlow Serving gRPC endpoint.
                The default is the address previously hard-coded here.
            timeout: RPC deadline in seconds.

        Returns:
            The mapped label (``""`` when index 0 wins and no fallback applies).
        """
        self.sents = self.sents2id(sents)

        # rsplit keeps any ':' inside the host part intact.
        host, port = hostport.rsplit(':', 1)
        channel = implementations.insecure_channel(host, int(port))
        stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

        request = predict_pb2.PredictRequest()
        request.model_spec.name = self.model_name
        request.model_spec.signature_name = self.signature_name
        request.inputs['input_plh'].CopyFrom(
            tf.contrib.util.make_tensor_proto(self.sents, dtype=tf.int32))
        # Feed keep_prob = 1.0 so dropout is a no-op at inference time.
        request.inputs['dropout_keep_prob_mlp'].CopyFrom(
            tf.contrib.util.make_tensor_proto(1.0, dtype=tf.float32))
        response = stub.Predict(request, timeout)

        # NOTE(review): float_val is flat across the whole batch, so this argmax
        # is only meaningful for a single input row — confirm sents2id yields
        # exactly one row per call.
        scores = list(response.outputs['scores'].float_val)
        index, _ = max(enumerate(scores), key=operator.itemgetter(1))
        label = self.label_dict[index - 1] if index > 0 else ""

        # The encode modes are mutually exclusive.
        if self.encode == "part":
            label = self.part[label] if label else "凌晨"
        elif self.encode == "type":
            label = self.type[label] if label else "录像"
        elif self.encode == "door" and label:
            label = self.gate[label]
        return label
  • 相关阅读:
    error:undefined reference to 'net_message_processor::net_message_processor()'
    android 网络检测
    eclipse 安装 ndk 组件
    eclipse下编译cocos2dx 3.0
    Cocos2dx3.0 TextField 输入中文的问题
    记录与骗子进行的一次交锋. 与技术无关
    关于继承的设计
    kubernetes1.5.2--部署dashboard服务
    kubernetes1.5.2--部署DNS服务
    kubernetes1.5.2集群部署过程--安全模式
  • 原文地址:https://www.cnblogs.com/qniguoym/p/7920712.html
Copyright © 2011-2022 走看看