zoukankan      html  css  js  c++  java
  • 通过grpc调用tfserving模型(python+java)

    tfserving模型部署见:https://www.cnblogs.com/bincoding/p/13266685.html
    demo代码:https://github.com/haibincoder/tf_tools

    对应的 RESTful 请求体(入参):

    {
        "inputs": {
            "input": [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]],
            "drop_out": 1,
            "sequence_length": [26]
            },
        "signature_name":"predict"
    }
    

    python代码:

    # Minimal gRPC client for a TF Serving model named "ner".
    # Uses the public grpc API (grpc.insecure_channel) instead of the deprecated
    # grpc.beta.implementations module and its private `channel._channel`.
    import grpc
    import tensorflow as tf
    
    from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc
    
    # Model spec: model name and signature name.
    request = predict_pb2.PredictRequest()
    request.model_spec.name = 'ner'
    # The latest version is used by default. `version` is a google.protobuf.Int64Value,
    # so a specific version is pinned via its `.value` field, e.g.:
    # request.model_spec.version.value = 1
    request.model_spec.signature_name = 'predict'
    
    # Build the inputs; each dtype must match the model signature
    # (see http://localhost:8501/v1/models/ner/metadata).
    x_data = [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]]
    drop_out = 1
    sequence_length = [26]
    request.inputs['input'].CopyFrom(tf.make_tensor_proto(x_data, dtype=tf.int32))
    request.inputs['sequence_length'].CopyFrom(tf.make_tensor_proto(sequence_length, dtype=tf.int32))
    request.inputs['drop_out'].CopyFrom(tf.make_tensor_proto(drop_out, dtype=tf.float32))
    
    # Plaintext channel to TF Serving's gRPC port; `with` closes it when done.
    with grpc.insecure_channel('localhost:8500') as channel:
        stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
        # Per the original author: returns the CRF outputs, i.e. the emission
        # probability matrix and the state-transition probability matrix.
        result = stub.Predict(request, timeout=10.0)  # 10 secs timeout
    
    print(result)
    

    java pom:

    <dependencies>
            <!-- TensorFlow Serving client API; presumably provides the PredictionServiceGrpc / Predict / TensorProto classes used by the Java example (confirm) -->
            <dependency>
                <groupId>com.yesup.oss</groupId>
                <artifactId>tensorflow-client</artifactId>
                <version>1.4-2</version>
            </dependency>
            <!-- gRPC-Java runtime artifacts: keep all io.grpc dependencies below on the same version -->
            <dependency>
                <groupId>io.grpc</groupId>
                <artifactId>grpc-netty-shaded</artifactId>
                <version>1.14.0</version>
            </dependency>
            <dependency>
                <groupId>io.grpc</groupId>
                <artifactId>grpc-protobuf</artifactId>
                <version>1.14.0</version>
            </dependency>
            <dependency>
                <groupId>io.grpc</groupId>
                <artifactId>grpc-stub</artifactId>
                <version>1.14.0</version>
            </dependency>
        </dependencies>
    

    java代码:

    /**
     * Calls the "ner" model's "predict" signature on a local TF Serving instance over gRPC
     * and prints the output tensors.
     *
     * Each TensorProto must store its values in the field matching its dtype:
     * DT_INT32 values go in int_val and DT_FLOAT values go in float_val.
     * Values placed in the wrong field are not read by the server.
     */
    public static void main(String[] args) {
        // Plaintext (non-TLS) channel to TF Serving's gRPC port.
        ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 8500).usePlaintext().build();
        try {
            // Blocking (synchronous) stub.
            PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);
            // Create the request.
            Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
            // Model name and signature name. The latest version is used by default; to pin a
            // version, call modelSpecBuilder.setVersion(...) with a google.protobuf.Int64Value.
            Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
            modelSpecBuilder.setName("ner");
            modelSpecBuilder.setSignatureName("predict");
            predictRequestBuilder.setModelSpec(modelSpecBuilder);

            // "input": int32 tensor of shape [1, 26]; int32 data must go into int_val.
            List<Integer> input = Arrays.asList(13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10);
            TensorProto.Builder inputTensorProto = TensorProto.newBuilder();
            inputTensorProto.setDtype(DataType.DT_INT32);
            inputTensorProto.addAllIntVal(input);
            TensorShapeProto.Builder inputShapeBuilder = TensorShapeProto.newBuilder();
            inputShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
            inputShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(input.size()));
            inputTensorProto.setTensorShape(inputShapeBuilder.build());

            // "drop_out": float scalar, matching the Python client and REST examples;
            // float data must go into float_val.
            float dropout = 1f;
            TensorProto.Builder dropoutTensorProto = TensorProto.newBuilder();
            dropoutTensorProto.setDtype(DataType.DT_FLOAT);
            dropoutTensorProto.addFloatVal(dropout);
            // An empty shape (no dims) denotes a scalar.
            dropoutTensorProto.setTensorShape(TensorShapeProto.newBuilder().build());

            // "sequence_length": int32 tensor of shape [1].
            List<Integer> seqLength = Collections.singletonList(26);
            TensorProto.Builder seqLengthTensorProto = TensorProto.newBuilder();
            seqLengthTensorProto.setDtype(DataType.DT_INT32);
            seqLengthTensorProto.addAllIntVal(seqLength);
            TensorShapeProto.Builder seqLengthShapeBuilder = TensorShapeProto.newBuilder();
            seqLengthShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
            seqLengthTensorProto.setTensorShape(seqLengthShapeBuilder.build());

            predictRequestBuilder.putInputs("input", inputTensorProto.build());
            predictRequestBuilder.putInputs("drop_out", dropoutTensorProto.build());
            predictRequestBuilder.putInputs("sequence_length", seqLengthTensorProto.build());

            // Send the request with a 3-second deadline and read the output tensors.
            Predict.PredictResponse predictResponse = stub.withDeadlineAfter(3, TimeUnit.SECONDS).predict(predictRequestBuilder.build());
            Map<String, TensorProto> result = predictResponse.getOutputsMap();
            // Per the original author: CRF model outputs, i.e. the emission probability
            // matrix and the state-transition probability matrix.
            System.out.println("预测值是:" + result.toString());
        } finally {
            // Release the channel's resources even if the call fails.
            channel.shutdown();
        }
    }
    

    注意事项:

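    1. 请求中各输入的类型(dtype)必须与模型签名中定义的类型保持一致,否则会报错:Expects arg[0] to be float but int32 is provided。
      可以通过 tfserving 的 RESTful 接口查看模型签名(输入名称与类型):http://localhost:8501/v1/models/ner/metadata
      此外,TensorProto 中的值必须放在与 dtype 对应的字段里:DT_INT32 用 int_val(Java 中为 addIntVal/addAllIntVal),DT_FLOAT 用 float_val(addFloatVal),放错字段时服务端读不到这些值。
      tfserving部署方法见:https://www.cnblogs.com/bincoding/p/13266685.html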
  • 相关阅读:
    SharePoint SSS(Security Store Service)服务-PowerShell
    SharePoint BDC(Business Data Connectivity)服务-PowerShell
    win32编辑控件字体
    创建选项卡控件
    利用VkKeyScanA判断大写字母
    使用powershell的remove
    x86和x64下指针的大小
    不使用C库函数(Sprintf)将void* 指针转换为十六进制字符串
    使用pycharm,配置环境
    使用python获得屏幕截图并保存为位图文件
  • 原文地址:https://www.cnblogs.com/bincoding/p/13274948.html
Copyright © 2011-2022 走看看