@@ -147,6 +147,8 @@ def init_args():
147
147
148
148
parser .add_argument ("--show_log" , type = str2bool , default = True )
149
149
parser .add_argument ("--use_onnx" , type = str2bool , default = False )
150
+    parser.add_argument("--onnx_providers", nargs="+", type=str, default=None)
151
+    parser.add_argument("--onnx_sess_options", default=None)
150
152
151
153
# extended function
152
154
parser .add_argument (
@@ -193,7 +195,16 @@ def create_predictor(args, mode, logger):
193
195
model_file_path = model_dir
194
196
if not os .path .exists (model_file_path ):
195
197
raise ValueError ("not find model file path {}" .format (model_file_path ))
196
- if args .use_gpu :
198
+
199
+ sess_options = args .onnx_sess_options or None
200
+
201
+    if args.onnx_providers:
202
+ sess = ort .InferenceSession (
203
+ model_file_path ,
204
+ providers = args .onnx_providers ,
205
+ sess_options = sess_options ,
206
+ )
207
+ elif args .use_gpu :
197
208
sess = ort .InferenceSession (
198
209
model_file_path ,
199
210
providers = [
@@ -202,10 +213,13 @@ def create_predictor(args, mode, logger):
202
213
{"device_id" : args .gpu_id , "cudnn_conv_algo_search" : "DEFAULT" },
203
214
)
204
215
],
216
+ sess_options = sess_options ,
205
217
)
206
218
else :
207
219
sess = ort .InferenceSession (
208
- model_file_path , providers = ["CPUExecutionProvider" ]
220
+ model_file_path ,
221
+ providers = ["CPUExecutionProvider" ],
222
+ sess_options = sess_options ,
209
223
)
210
224
return sess , sess .get_inputs ()[0 ], None , None
211
225
0 commit comments