diff --git a/ddddocr/__init__.py b/ddddocr/__init__.py index eac66cc..41aa654 100644 --- a/ddddocr/__init__.py +++ b/ddddocr/__init__.py @@ -5,7 +5,7 @@ warnings.filterwarnings('ignore') import io import os import base64 -import json +import pathlib import onnxruntime from PIL import Image, ImageChops import numpy as np @@ -29,26 +29,13 @@ class TypeError(Exception): class DdddOcr(object): def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, use_gpu: bool = False, - device_id: int = 0, show_ad=True, import_onnx_path: str = "", charsets_path: str = ""): + device_id: int = 0, show_ad=True): if show_ad: print("欢迎使用ddddocr,本项目专注带动行业内卷,个人博客:wenanzhe.com") print("训练数据支持来源于:http://146.56.204.113:19199/preview") print("爬虫框架feapder可快速一键接入,快速开启爬虫之旅:https://github.com/Boris-code/feapder") - self.use_import_onnx = False - self.__word = False - self.__resize = [] - self.__channel = 1 - if import_onnx_path != "": - det = False - ocr = False - self.__graph_path = import_onnx_path - with open(charsets_path, 'r', encoding="utf-8") as f: - info = json.loads(f.read()) - self.__charset = info['charset'] - self.__word = info['word'] - self.__resize = info['image'] - self.__channel = info['channel'] - self.use_import_onnx = True + + if det: ocr = False print("开启det后自动关闭ocr") @@ -1453,7 +1440,7 @@ class DdddOcr(object): self.__providers = [ 'CPUExecutionProvider', ] - if ocr or det or self.use_import_onnx: + if ocr or det: self.__ort_session = onnxruntime.InferenceSession(self.__graph_path, providers=self.__providers) def preproc(self, img, input_size, swap=(2, 0, 1)): @@ -1594,53 +1581,35 @@ class DdddOcr(object): return [] return result - def classification(self, img_bytes: bytes = None, img_base64: str = None): + def classification(self, img): if self.det: raise TypeError("当前识别类型为目标检测") - if img_bytes: - image = Image.open(io.BytesIO(img_bytes)) + if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)): + raise TypeError("未知图片类型") + if isinstance(img, bytes): + image = Image.open(io.BytesIO(img)) + elif isinstance(img, Image.Image): + image = img.copy() + elif isinstance(img, str): + image = base64_to_image(img) else: - image = base64_to_image(img_base64) - if not self.use_import_onnx: - image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L') - else: - if self.__resize[0] == -1: - if self.__word: - image = image.resize((self.__resize[1], self.__resize[1]), Image.ANTIALIAS) - else: - image = image.resize((int(image.size[0] * (self.__resize[1] / image.size[1])), self.__resize[1]), Image.ANTIALIAS) - else: - image = image.resize((self.__resize[0], self.__resize[1]), Image.ANTIALIAS) - if self.__channel == 1: - image = image.convert('L') - else: - image = image.convert('RGB') + assert isinstance(img, pathlib.PurePath) + image = Image.open(img) + image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L') image = np.array(image).astype(np.float32) image = np.expand_dims(image, axis=0) / 255. - if not self.use_import_onnx: - image = (image - 0.5) / 0.5 - else: - if self.__channel == 1: - image = (image - 0.456) / 0.224 - else: - image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225]) - + image = (image - 0.5) / 0.5 ort_inputs = {'input1': np.array([image])} ort_outs = self.__ort_session.run(None, ort_inputs) result = [] - last_item = 0 - if self.__word: - for item in ort_outs[1]: + for item in ort_outs[0][0]: + if item == last_item: + continue + else: + last_item = item + if item != 0: result.append(self.__charset[item]) - else: - for item in ort_outs[0][0]: - if item == last_item: - continue - else: - last_item = item - if item != 0: - result.append(self.__charset[item]) return ''.join(result)