mirror of
https://github.com/sml2h3/ddddocr.git
synced 2025-05-03 21:21:49 +08:00
Merge branch 'master' of https://github.com/sml2h3/ddddocr
Conflicts: ddddocr/__init__.py
This commit is contained in:
commit
2cf2924ed1
@ -5,7 +5,7 @@ warnings.filterwarnings('ignore')
|
||||
import io
|
||||
import os
|
||||
import base64
|
||||
import json
|
||||
import pathlib
|
||||
import onnxruntime
|
||||
from PIL import Image, ImageChops
|
||||
import numpy as np
|
||||
@ -29,26 +29,13 @@ class TypeError(Exception):
|
||||
|
||||
class DdddOcr(object):
|
||||
def __init__(self, ocr: bool = True, det: bool = False, old: bool = False, use_gpu: bool = False,
|
||||
device_id: int = 0, show_ad=True, import_onnx_path: str = "", charsets_path: str = ""):
|
||||
device_id: int = 0, show_ad=True):
|
||||
if show_ad:
|
||||
print("欢迎使用ddddocr,本项目专注带动行业内卷,个人博客:wenanzhe.com")
|
||||
print("训练数据支持来源于:http://146.56.204.113:19199/preview")
|
||||
print("爬虫框架feapder可快速一键接入,快速开启爬虫之旅:https://github.com/Boris-code/feapder")
|
||||
self.use_import_onnx = False
|
||||
self.__word = False
|
||||
self.__resize = []
|
||||
self.__channel = 1
|
||||
if import_onnx_path != "":
|
||||
det = False
|
||||
ocr = False
|
||||
self.__graph_path = import_onnx_path
|
||||
with open(charsets_path, 'r', encoding="utf-8") as f:
|
||||
info = json.loads(f.read())
|
||||
self.__charset = info['charset']
|
||||
self.__word = info['word']
|
||||
self.__resize = info['image']
|
||||
self.__channel = info['channel']
|
||||
self.use_import_onnx = True
|
||||
|
||||
|
||||
if det:
|
||||
ocr = False
|
||||
print("开启det后自动关闭ocr")
|
||||
@ -1453,7 +1440,7 @@ class DdddOcr(object):
|
||||
self.__providers = [
|
||||
'CPUExecutionProvider',
|
||||
]
|
||||
if ocr or det or self.use_import_onnx:
|
||||
if ocr or det:
|
||||
self.__ort_session = onnxruntime.InferenceSession(self.__graph_path, providers=self.__providers)
|
||||
|
||||
def preproc(self, img, input_size, swap=(2, 0, 1)):
|
||||
@ -1594,53 +1581,35 @@ class DdddOcr(object):
|
||||
return []
|
||||
return result
|
||||
|
||||
def classification(self, img_bytes: bytes = None, img_base64: str = None):
|
||||
def classification(self, img):
|
||||
if self.det:
|
||||
raise TypeError("当前识别类型为目标检测")
|
||||
if img_bytes:
|
||||
image = Image.open(io.BytesIO(img_bytes))
|
||||
if not isinstance(img, (bytes, str, pathlib.PurePath, Image.Image)):
|
||||
raise TypeError("未知图片类型")
|
||||
if isinstance(img, bytes):
|
||||
image = Image.open(io.BytesIO(img))
|
||||
elif isinstance(img, Image.Image):
|
||||
image = img.copy()
|
||||
elif isinstance(img, str):
|
||||
image = base64_to_image(img)
|
||||
else:
|
||||
image = base64_to_image(img_base64)
|
||||
if not self.use_import_onnx:
|
||||
image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L')
|
||||
else:
|
||||
if self.__resize[0] == -1:
|
||||
if self.__word:
|
||||
image = image.resize((self.__resize[1], self.__resize[1]), Image.ANTIALIAS)
|
||||
else:
|
||||
image = image.resize((int(image.size[0] * (self.__resize[1] / image.size[1])), self.__resize[1]), Image.ANTIALIAS)
|
||||
else:
|
||||
image = image.resize((self.__resize[0], self.__resize[1]), Image.ANTIALIAS)
|
||||
if self.__channel == 1:
|
||||
image = image.convert('L')
|
||||
else:
|
||||
image = image.convert('RGB')
|
||||
assert isinstance(img, pathlib.PurePath)
|
||||
image = Image.open(img)
|
||||
image = image.resize((int(image.size[0] * (64 / image.size[1])), 64), Image.ANTIALIAS).convert('L')
|
||||
image = np.array(image).astype(np.float32)
|
||||
image = np.expand_dims(image, axis=0) / 255.
|
||||
if not self.use_import_onnx:
|
||||
image = (image - 0.5) / 0.5
|
||||
else:
|
||||
if self.__channel == 1:
|
||||
image = (image - 0.456) / 0.224
|
||||
else:
|
||||
image = (image - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])
|
||||
|
||||
image = (image - 0.5) / 0.5
|
||||
ort_inputs = {'input1': np.array([image])}
|
||||
ort_outs = self.__ort_session.run(None, ort_inputs)
|
||||
result = []
|
||||
|
||||
last_item = 0
|
||||
if self.__word:
|
||||
for item in ort_outs[1]:
|
||||
for item in ort_outs[0][0]:
|
||||
if item == last_item:
|
||||
continue
|
||||
else:
|
||||
last_item = item
|
||||
if item != 0:
|
||||
result.append(self.__charset[item])
|
||||
else:
|
||||
for item in ort_outs[0][0]:
|
||||
if item == last_item:
|
||||
continue
|
||||
else:
|
||||
last_item = item
|
||||
if item != 0:
|
||||
result.append(self.__charset[item])
|
||||
|
||||
return ''.join(result)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user