5
0
mirror of https://github.com/sml2h3/ddddocr.git synced 2025-05-04 13:51:12 +08:00

Merge pull request #39 from lededev/mtp-2

classification() add image type Image.Image and pathlib.Path
This commit is contained in:
Sml2h3 2022-03-01 13:00:18 +08:00 committed by GitHub
commit fa5db35dff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,6 +6,7 @@ import io
import os import os
import base64 import base64
import json import json
import pathlib
import onnxruntime import onnxruntime
from PIL import Image, ImageChops from PIL import Image, ImageChops
import numpy as np import numpy as np
@ -1594,13 +1595,20 @@ class DdddOcr(object):
return [] return []
return result return result
def classification(self, img_bytes: bytes = None, img_base64: str = None): def classification(self, img):
if self.det: if self.det:
raise TypeError("当前识别类型为目标检测") raise TypeError("当前识别类型为目标检测")
if img_bytes: if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)):
image = Image.open(io.BytesIO(img_bytes)) raise TypeError("未知图片类型")
if isinstance(img, bytes):
image = Image.open(io.BytesIO(img))
elif isinstance(img, Image.Image):
image = img.copy()
elif isinstance(img, str):
image = base64_to_image(img)
else: else:
image = base64_to_image(img_base64) assert isinstance(img, pathlib.PurePath)
image = Image.open(img)
if not self.use_import_onnx: if not self.use_import_onnx:
image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L') image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L')
else: else: