This commit is contained in:
lsylsy2 2024-01-22 12:59:46 +08:00 committed by GitHub
commit df93cf0fe6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -11,6 +11,8 @@ parser.add_argument("-p", "--port", type=int, default=9898)
parser.add_argument("--ocr", action="store_true", help="开启ocr识别") parser.add_argument("--ocr", action="store_true", help="开启ocr识别")
parser.add_argument("--old", action="store_true", help="OCR是否启动旧模型") parser.add_argument("--old", action="store_true", help="OCR是否启动旧模型")
parser.add_argument("--det", action="store_true", help="开启目标检测") parser.add_argument("--det", action="store_true", help="开启目标检测")
parser.add_argument("--custom-onnx", type=str, help="使用dddd_trainer训练出的自定义ONNX模型文件")
parser.add_argument("--charsets-path", type=str, help="使用dddd_trainer训练出的自定义charsets文件")
args = parser.parse_args() args = parser.parse_args()
@ -18,7 +20,7 @@ app = Flask(__name__)
class Server(object): class Server(object):
def __init__(self, ocr=True, det=False, old=False): def __init__(self, ocr=True, det=False, old=False, custom_onnx=None, charsets_path=None):
self.ocr_option = ocr self.ocr_option = ocr
self.det_option = det self.det_option = det
self.old_option = old self.old_option = old
@ -26,7 +28,10 @@ class Server(object):
self.det = None self.det = None
if self.ocr_option: if self.ocr_option:
print("ocr模块开启") print("ocr模块开启")
if self.old_option: if custom_onnx and charsets_path:
print(f"使用自定义ONNX模型: {custom_onnx} 和字符集: {charsets_path}")
self.ocr = ddddocr.DdddOcr(det=False, ocr=False, import_onnx_path=custom_onnx, charsets_path=charsets_path)
elif self.old_option:
print("使用OCR旧模型启动") print("使用OCR旧模型启动")
self.ocr = ddddocr.DdddOcr(old=True) self.ocr = ddddocr.DdddOcr(old=True)
else: else:
@ -61,7 +66,7 @@ class Server(object):
else: else:
raise Exception(f"不支持的滑块算法类型: {algo_type}") raise Exception(f"不支持的滑块算法类型: {algo_type}")
server = Server(ocr=args.ocr, det=args.det, old=args.old) server = Server(ocr=args.ocr, det=args.det, old=args.old, custom_onnx=args.custom_onnx, charsets_path=args.charsets_path)
def get_img(request, img_type='file', img_name='image'): def get_img(request, img_type='file', img_name='image'):