一、Type Error: Type 'tensor(bool)' of input parameter (121) of operator (ScatterND) in node (ScatterND_128) is invalid
问题
模型转出成功后,用onnxruntime加载,出现不支持参数问题, 这里出现tensor(bool)
是因为代码中使用了bool类型的索引
解决措施
索引采用torch.where替代
...
mask = dist < distance
distance[mask] = dist[mask]
...
更改为
distance = torch.where(dist < distance, dist, distance)
二、FAIL : Load model from ./test.onnx failed:Fatal error: ATen is not a registered function/op
问题
模型转出成功后,用onnxruntime加载,出现没有注册的算子
解决措施
在torch.onnx.export
函数中设置opset_version=12
三、动态输入/输出
有时候输入和输出维度是变化的,这个时候在导出的时候可以添加dynamic_axes
参数,并指定哪些参数和维度是动态的。
结果
四、Removing initializer 'bn1.num_batches_tracked'. It is not used by any node and should be removed from the model.
问题
模型转出成功后,用onnxruntime运行出现以上警告
解决措施
对模型进行优化
import onnx
import onnxoptimizer # pip install onnxoptimizer
onnx_model = onnx.load(onnxfile)
passes = ["extract_constant_to_initializer", "eliminate_unused_initializer"]
optimized_model = onnxoptimizer.optimize(onnx_model, passes)
onnx.save(optimized_model, onnxfile)