diff --git a/ocr_server.py b/ocr_server.py index b3fe8f0..06194cb 100644 --- a/ocr_server.py +++ b/ocr_server.py @@ -11,6 +11,8 @@ parser.add_argument("-p", "--port", type=int, default=9898) parser.add_argument("--ocr", action="store_true", help="开启ocr识别") parser.add_argument("--old", action="store_true", help="OCR是否启动旧模型") parser.add_argument("--det", action="store_true", help="开启目标检测") +parser.add_argument("--custom-onnx", type=str, help="使用dddd_trainer训练出的自定义ONNX模型文件") +parser.add_argument("--charsets-path", type=str, help="使用dddd_trainer训练出的自定义charsets文件") args = parser.parse_args() @@ -18,7 +20,7 @@ app = Flask(__name__) class Server(object): - def __init__(self, ocr=True, det=False, old=False): + def __init__(self, ocr=True, det=False, old=False, custom_onnx=None, charsets_path=None): self.ocr_option = ocr self.det_option = det self.old_option = old @@ -26,7 +28,10 @@ class Server(object): self.det = None if self.ocr_option: print("ocr模块开启") - if self.old_option: + if custom_onnx and charsets_path: + print(f"使用自定义ONNX模型: {custom_onnx} 和字符集: {charsets_path}") + self.ocr = ddddocr.DdddOcr(det=det, ocr=ocr, import_onnx_path=custom_onnx, charsets_path=charsets_path) + elif self.old_option: print("使用OCR旧模型启动") self.ocr = ddddocr.DdddOcr(old=True) else: @@ -61,7 +66,7 @@ class Server(object): else: raise Exception(f"不支持的滑块算法类型: {algo_type}") -server = Server(ocr=args.ocr, det=args.det, old=args.old) +server = Server(ocr=args.ocr, det=args.det, old=args.old, custom_onnx=args.custom_onnx, charsets_path=args.charsets_path) def get_img(request, img_type='file', img_name='image'):