5
0
mirror of https://github.com/sml2h3/ddddocr.git synced 2025-05-01 11:31:48 +08:00

修复调用set_range没有得到期望结果的BUG

This commit is contained in:
lotomer@163.com 2024-10-18 09:02:09 +08:00
parent db75d4ac99
commit 90811d985f

View File

@ -54,6 +54,7 @@ class DdddOcr(object):
self.__word = False
self.__resize = []
self.__charset_range = []
self.__valid_charset_range_index = [] # 指定字符对应的有效索引
self.__channel = 1
if import_onnx_path != "":
det = False
@ -2593,6 +2594,16 @@ class DdddOcr(object):
# 去重
self.__charset_range = list(set(self.__charset_range)) + [""]
# 根据指定字符获取对应的索引
valid_charset_range_index = []
if len(self.__charset_range) > 0:
for item in self.__charset_range:
if item in self.__charset:
valid_charset_range_index.append(self.__charset.index(item))
else:
# 未知字符没有索引,直接忽略
pass
self.__valid_charset_range_index = valid_charset_range_index
def classification(self, img, png_fix: bool = False, probability=False):
@ -2667,28 +2678,36 @@ class DdddOcr(object):
result['probability'] = ort_outs_probability
else:
result['charsets'] = self.__charset_range
probability_result_index = []
for item in self.__charset_range:
if item in self.__charset:
probability_result_index.append(self.__charset.index(item))
else:
# 未知字符
probability_result_index.append(-1)
valid_charset_range_index = self.__valid_charset_range_index
probability_result = []
for item in ort_outs_probability:
probability_result.append([item[i] if i != -1 else -1 for i in probability_result_index ])
probability_result.append([item[i] for i in valid_charset_range_index ])
result['probability'] = probability_result
return result
else:
last_item = 0
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
for item in argmax_result:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
if len(self.__charset_range) == 0:
# 没有指定特定的字符集合,直接获取结果
last_item = 0
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
for item in argmax_result:
if item == last_item:
continue
else:
last_item = item
if item != 0:
result.append(self.__charset[item])
else:
# 指定了特定的字符集合
last_item = 0
valid_charset_range_index = self.__valid_charset_range_index
for row in np.squeeze(ort_outs[0]):
# 仅在指定字符集合中寻找最大值
idx = np.argmax(row[list(valid_charset_range_index)])
if idx == last_item:
continue
else:
last_item = idx
result.append(self.__charset[valid_charset_range_index[idx]])
return ''.join(result)
else: