5
0
mirror of https://github.com/sml2h3/ddddocr.git synced 2025-05-04 09:49:47 +08:00
 Conflicts:
	ddddocr/__init__.py
This commit is contained in:
sml2h3 2022-02-26 15:59:08 +08:00
commit 2cf2924ed1

View File

@ -5,7 +5,7 @@ warnings.filterwarnings('ignore')
import io import io
import os import os
import base64 import base64
import json import pathlib
import onnxruntime import onnxruntime
from PIL import Image, ImageChops from PIL import Image, ImageChops
import numpy as np import numpy as np
@ -29,26 +29,13 @@ class TypeError(Exception):
class DdddOcr(object): class DdddOcr(object):
def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, use_gpu: bool = False, def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, use_gpu: bool = False,
device_id: int = 0, show_ad=True, import_onnx_path: str = "", charsets_path: str = ""): device_id: int = 0, show_ad=True):
if show_ad: if show_ad:
print("欢迎使用ddddocr本项目专注带动行业内卷个人博客:wenanzhe.com") print("欢迎使用ddddocr本项目专注带动行业内卷个人博客:wenanzhe.com")
print("训练数据支持来源于:http://146.56.204.113:19199/preview") print("训练数据支持来源于:http://146.56.204.113:19199/preview")
print("爬虫框架feapder可快速一键接入快速开启爬虫之旅https://github.com/Boris-code/feapder") print("爬虫框架feapder可快速一键接入快速开启爬虫之旅https://github.com/Boris-code/feapder")
self.use_import_onnx = False
self.__word = False
self.__resize = []
self.__channel = 1
if import_onnx_path != "":
det = False
ocr = False
self.__graph_path = import_onnx_path
with open(charsets_path, 'r', encoding="utf-8") as f:
info = json.loads(f.read())
self.__charset = info['charset']
self.__word = info['word']
self.__resize = info['image']
self.__channel = info['channel']
self.use_import_onnx = True
if det: if det:
ocr = False ocr = False
print("开启det后自动关闭ocr") print("开启det后自动关闭ocr")
@ -1453,7 +1440,7 @@ class DdddOcr(object):
self.__providers = [ self.__providers = [
'CPUExecutionProvider', 'CPUExecutionProvider',
] ]
if ocr or det or self.use_import_onnx: if ocr or det:
self.__ort_session = onnxruntime.InferenceSession(self.__graph_path, providers=self.__providers) self.__ort_session = onnxruntime.InferenceSession(self.__graph_path, providers=self.__providers)
def preproc(self, img, input_size, swap=(2, 0, 1)): def preproc(self, img, input_size, swap=(2, 0, 1)):
@ -1594,53 +1581,35 @@ class DdddOcr(object):
return [] return []
return result return result
def classification(self, img_bytes: bytes = None, img_base64: str = None): def classification(self, img):
if self.det: if self.det:
raise TypeError("当前识别类型为目标检测") raise TypeError("当前识别类型为目标检测")
if img_bytes: if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)):
image = Image.open(io.BytesIO(img_bytes)) raise TypeError("未知图片类型")
if isinstance(img, bytes):
image = Image.open(io.BytesIO(img))
elif isinstance(img, Image.Image):
image = img.copy()
elif isinstance(img, str):
image = base64_to_image(img)
else: else:
image = base64_to_image(img_base64) assert isinstance(img, pathlib.PurePath)
if not self.use_import_onnx: image = Image.open(img)
image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L') image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L')
else:
if self.__resize[0] == -1:
if self.__word:
image = image.resize((self.__resize[1], self.__resize[1]), Image.ANTIALIAS)
else:
image = image.resize((int(image.size[0] * (self.__resize[1] / image.size[1])), self.__resize[1]), Image.ANTIALIAS)
else:
image = image.resize((self.__resize[0], self.__resize[1]), Image.ANTIALIAS)
if self.__channel == 1:
image = image.convert('L')
else:
image = image.convert('RGB')
image = np.array(image).astype(np.float32) image = np.array(image).astype(np.float32)
image = np.expand_dims(image, axis=0) / 255. image = np.expand_dims(image, axis=0) / 255.
if not self.use_import_onnx: image = (image - 0.5) / 0.5
image = (image - 0.5) / 0.5
else:
if self.__channel == 1:
image = (image - 0.456) / 0.224
else:
image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
ort_inputs = {'input1': np.array([image])} ort_inputs = {'input1': np.array([image])}
ort_outs = self.__ort_session.run(None, ort_inputs) ort_outs = self.__ort_session.run(None, ort_inputs)
result = [] result = []
last_item = 0 last_item = 0
if self.__word: for item in ort_outs[0][0]:
for item in ort_outs[1]: if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item]) result.append(self.__charset[item])
else:
for item in ort_outs[0][0]:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
return ''.join(result) return ''.join(result)