Skip to content

Commit 4a189b0

Browse files
committed
Merge branch 'main' of github.com:PaddlePaddle/PaddleOCR
2 parents f8835d2 + 37e1775 commit 4a189b0

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

tools/infer/utility.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,8 @@ def init_args():
147147

148148
parser.add_argument("--show_log", type=str2bool, default=True)
149149
parser.add_argument("--use_onnx", type=str2bool, default=False)
150+
parser.add_argument("--onnx_providers", nargs="+", type=str, default=False)
151+
parser.add_argument("--onnx_sess_options", type=list, default=False)
150152

151153
# extended function
152154
parser.add_argument(
@@ -193,7 +195,16 @@ def create_predictor(args, mode, logger):
193195
model_file_path = model_dir
194196
if not os.path.exists(model_file_path):
195197
raise ValueError("not find model file path {}".format(model_file_path))
196-
if args.use_gpu:
198+
199+
sess_options = args.onnx_sess_options or None
200+
201+
if args.onnx_providers and len(args.onnx_providers) > 0:
202+
sess = ort.InferenceSession(
203+
model_file_path,
204+
providers=args.onnx_providers,
205+
sess_options=sess_options,
206+
)
207+
elif args.use_gpu:
197208
sess = ort.InferenceSession(
198209
model_file_path,
199210
providers=[
@@ -202,10 +213,13 @@ def create_predictor(args, mode, logger):
202213
{"device_id": args.gpu_id, "cudnn_conv_algo_search": "DEFAULT"},
203214
)
204215
],
216+
sess_options=sess_options,
205217
)
206218
else:
207219
sess = ort.InferenceSession(
208-
model_file_path, providers=["CPUExecutionProvider"]
220+
model_file_path,
221+
providers=["CPUExecutionProvider"],
222+
sess_options=sess_options,
209223
)
210224
return sess, sess.get_inputs()[0], None, None
211225

0 commit comments

Comments
 (0)