tfserving模型部署见:https://www.cnblogs.com/bincoding/p/13266685.html
demo代码:https://github.com/haibincoder/tf_tools
对应的 RESTful 请求体(POST http://localhost:8501/v1/models/ner:predict):
{
"inputs": {
"input": [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]],
"drop_out": 1,
"sequence_length": [26]
},
"signature_name":"predict"
}
python代码:
import grpc
import tensorflow as tf
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

# Minimal client for the TF Serving gRPC Predict API (gRPC on port 8500; REST is on 8501).
# Uses the public grpc.insecure_channel API rather than the deprecated grpc.beta
# `implementations` module and its private `_channel` attribute.

# Model spec: name and signature name must match the deployed SavedModel.
request = predict_pb2.PredictRequest()
request.model_spec.name = 'ner'
# The latest loaded version is served by default. To pin a version, set the
# Int64Value wrapper (model_spec.version is not a string field), e.g.:
# request.model_spec.version.value = 1
request.model_spec.signature_name = 'predict'

# Build the inputs; each dtype must match the model signature
# (see http://localhost:8501/v1/models/ner/metadata).
x_data = [[13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712, 1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10]]
drop_out = 1.0
# One length per sequence in the batch (26 here).
sequence_length = [len(x_data[0])]
request.inputs['input'].CopyFrom(tf.make_tensor_proto(x_data, dtype=tf.int32))
request.inputs['sequence_length'].CopyFrom(tf.make_tensor_proto(sequence_length, dtype=tf.int32))
request.inputs['drop_out'].CopyFrom(tf.make_tensor_proto(drop_out, dtype=tf.float32))

# The context manager closes the channel once the call returns.
with grpc.insecure_channel('localhost:8500') as channel:
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    # Returns the CRF outputs: the emission score matrix and the state-transition matrix.
    result = stub.Predict(request, timeout=10.0)  # 10 secs timeout
print(result)
java pom:
<!-- Dependencies for the Java TF Serving gRPC client shown below. -->
<dependencies>
<!-- Presumably supplies the generated TensorFlow / TF Serving protobuf and gRPC stub classes
     the Java code uses (PredictionServiceGrpc, Predict, Model, TensorProto, ...) - confirm. -->
<dependency>
<groupId>com.yesup.oss</groupId>
<artifactId>tensorflow-client</artifactId>
<version>1.4-2</version>
</dependency>
<!-- gRPC transport, protobuf marshalling and stub support.
     Keep all three io.grpc artifacts on the same version. -->
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-netty-shaded</artifactId>
<version>1.14.0</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
<version>1.14.0</version>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
<version>1.14.0</version>
</dependency>
</dependencies>
java代码:
public static void main(String[] args) {
    // Plaintext channel to TF Serving's gRPC port (8500).
    // usePlaintext() replaces the deprecated usePlaintext(boolean) overload.
    ManagedChannel channel = ManagedChannelBuilder.forAddress("localhost", 8500).usePlaintext().build();
    try {
        // Blocking (synchronous) stub.
        PredictionServiceGrpc.PredictionServiceBlockingStub stub = PredictionServiceGrpc.newBlockingStub(channel);

        // Model name and signature name. No version is set, so TF Serving uses the latest loaded
        // version. To pin one, set it on the ModelSpec (not on a TensorProto), e.g.
        // modelSpecBuilder.setVersion(com.google.protobuf.Int64Value.newBuilder().setValue(n)).
        Model.ModelSpec.Builder modelSpecBuilder = Model.ModelSpec.newBuilder();
        modelSpecBuilder.setName("ner");
        modelSpecBuilder.setSignatureName("predict");

        Predict.PredictRequest.Builder predictRequestBuilder = Predict.PredictRequest.newBuilder();
        predictRequestBuilder.setModelSpec(modelSpecBuilder);

        // Inputs mirror the Python and REST examples:
        //   input           - DT_INT32, shape [1, n]
        //   drop_out        - DT_FLOAT, scalar
        //   sequence_length - DT_INT32, shape [1]
        // Values must be written to the repeated field matching the dtype (int_val for DT_INT32,
        // float_val for DT_FLOAT); a mismatched field is not read by TensorFlow.
        List<Integer> input = Arrays.asList(13, 45, 13, 13, 49, 1, 49, 196, 594, 905, 48, 231, 318, 712,
                1003, 477, 259, 291, 287, 161, 65, 62, 82, 68, 2, 10);
        float dropout = 1f;
        List<Integer> seqLength = Collections.singletonList(input.size());

        predictRequestBuilder.putInputs("input", int32Tensor(input, 1, input.size()));
        predictRequestBuilder.putInputs("drop_out", floatScalar(dropout));
        predictRequestBuilder.putInputs("sequence_length", int32Tensor(seqLength, seqLength.size()));

        // Call with a 3-second deadline.
        Predict.PredictResponse predictResponse = stub.withDeadlineAfter(3, TimeUnit.SECONDS)
                .predict(predictRequestBuilder.build());
        // CRF model results: emission score matrix and state-transition matrix.
        Map<String, TensorProto> result = predictResponse.getOutputsMap();
        System.out.println("预测值是:" + result.toString());
    } finally {
        // Release the channel's connections and threads once the call is done.
        channel.shutdown();
    }
}

/** Builds a DT_INT32 TensorProto holding {@code values} with the given shape. */
private static TensorProto int32Tensor(List<Integer> values, long... dims) {
    return TensorProto.newBuilder()
            .setDtype(DataType.DT_INT32)
            .addAllIntVal(values)
            .setTensorShape(shape(dims))
            .build();
}

/** Builds a rank-0 (scalar) DT_FLOAT TensorProto. */
private static TensorProto floatScalar(float value) {
    return TensorProto.newBuilder()
            .setDtype(DataType.DT_FLOAT)
            .addFloatVal(value)
            .setTensorShape(shape())
            .build();
}

/** Builds a TensorShapeProto from the given dimension sizes; no sizes yields a scalar shape. */
private static TensorShapeProto shape(long... dims) {
    TensorShapeProto.Builder builder = TensorShapeProto.newBuilder();
    for (long size : dims) {
        builder.addDim(TensorShapeProto.Dim.newBuilder().setSize(size));
    }
    return builder.build();
}
注意事项:
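- 请求中每个输入的数据类型(dtype)必须与模型签名中定义的类型保持一致,否则会报错:Expects arg[0] to be float but int32 is provided
  模型签名(输入名称、dtype、shape)可以通过 tfserving 的 RESTful 元数据接口查看:http://localhost:8501/v1/models/ner/metadata
- Java 中构造 TensorProto 时,数值必须写入与 dtype 对应的字段(DT_INT32 用 addIntVal/addAllIntVal,DT_FLOAT 用 addFloatVal),写入不匹配的字段不会被读取。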
tfserving部署方法见:https://www.cnblogs.com/bincoding/p/13266685.html