diff --git a/ddddocr/__init__.py b/ddddocr/__init__.py index eac66cc..6379509 100644 --- a/ddddocr/__init__.py +++ b/ddddocr/__init__.py @@ -6,6 +6,7 @@ import io import os import base64 import json +import pathlib import onnxruntime from PIL import Image, ImageChops import numpy as np @@ -1594,13 +1595,20 @@ class DdddOcr(object): return [] return result - def classification(self, img_bytes: bytes = None, img_base64: str = None): + def classification(self, img): if self.det: raise TypeError("当前识别类型为目标检测") - if img_bytes: - image = Image.open(io.BytesIO(img_bytes)) + if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)): + raise TypeError("未知图片类型") + if isinstance(img, bytes): + image = Image.open(io.BytesIO(img)) + elif isinstance(img, Image.Image): + image = img.copy() + elif isinstance(img, str): + image = base64_to_image(img) else: - image = base64_to_image(img_base64) + assert isinstance(img, pathlib.PurePath) + image = Image.open(img) if not self.use_import_onnx: image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L') else: