mirror of
https://github.com/sml2h3/ddddocr.git
synced 2025-05-01 11:31:48 +08:00
修复调用set_range没有得到期望结果的BUG
This commit is contained in:
parent
db75d4ac99
commit
90811d985f
@ -54,6 +54,7 @@ class DdddOcr(object):
|
||||
self.__word = False
|
||||
self.__resize = []
|
||||
self.__charset_range = []
|
||||
self.__valid_charset_range_index = [] # 指定字符对应的有效索引
|
||||
self.__channel = 1
|
||||
if import_onnx_path != "":
|
||||
det = False
|
||||
@ -2593,6 +2594,16 @@ class DdddOcr(object):
|
||||
|
||||
# 去重
|
||||
self.__charset_range = list(set(self.__charset_range)) + [""]
|
||||
# 根据指定字符获取对应的索引
|
||||
valid_charset_range_index = []
|
||||
if len(self.__charset_range) > 0:
|
||||
for item in self.__charset_range:
|
||||
if item in self.__charset:
|
||||
valid_charset_range_index.append(self.__charset.index(item))
|
||||
else:
|
||||
# 未知字符没有索引,直接忽略
|
||||
pass
|
||||
self.__valid_charset_range_index = valid_charset_range_index
|
||||
|
||||
|
||||
def classification(self, img, png_fix: bool = False, probability=False):
|
||||
@ -2667,28 +2678,36 @@ class DdddOcr(object):
|
||||
result['probability'] = ort_outs_probability
|
||||
else:
|
||||
result['charsets'] = self.__charset_range
|
||||
probability_result_index = []
|
||||
for item in self.__charset_range:
|
||||
if item in self.__charset:
|
||||
probability_result_index.append(self.__charset.index(item))
|
||||
else:
|
||||
# 未知字符
|
||||
probability_result_index.append(-1)
|
||||
valid_charset_range_index = self.__valid_charset_range_index
|
||||
probability_result = []
|
||||
for item in ort_outs_probability:
|
||||
probability_result.append([item[i] if i != -1 else -1 for i in probability_result_index ])
|
||||
probability_result.append([item[i] for i in valid_charset_range_index ])
|
||||
result['probability'] = probability_result
|
||||
return result
|
||||
else:
|
||||
last_item = 0
|
||||
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
|
||||
for item in argmax_result:
|
||||
if item == last_item:
|
||||
continue
|
||||
else:
|
||||
last_item = item
|
||||
if item != 0:
|
||||
result.append(self.__charset[item])
|
||||
if len(self.__charset_range) == 0:
|
||||
# 没有指定特定的字符集合,直接获取结果
|
||||
last_item = 0
|
||||
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
|
||||
for item in argmax_result:
|
||||
if item == last_item:
|
||||
continue
|
||||
else:
|
||||
last_item = item
|
||||
if item != 0:
|
||||
result.append(self.__charset[item])
|
||||
else:
|
||||
# 指定了特定的字符集合
|
||||
last_item = 0
|
||||
valid_charset_range_index = self.__valid_charset_range_index
|
||||
for row in np.squeeze(ort_outs[0]):
|
||||
# 仅在指定字符集合中寻找最大值
|
||||
idx = np.argmax(row[list(valid_charset_range_index)])
|
||||
if idx == last_item:
|
||||
continue
|
||||
else:
|
||||
last_item = idx
|
||||
result.append(self.__charset[valid_charset_range_index[idx]])
|
||||
return ''.join(result)
|
||||
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user