mirror of
https://github.com/sml2h3/ddddocr.git
synced 2025-05-03 00:01:11 +08:00
修复调用set_range没有得到期望结果的BUG
This commit is contained in:
parent
db75d4ac99
commit
90811d985f
@ -54,6 +54,7 @@ class DdddOcr(object):
|
|||||||
self.__word = False
|
self.__word = False
|
||||||
self.__resize = []
|
self.__resize = []
|
||||||
self.__charset_range = []
|
self.__charset_range = []
|
||||||
|
self.__valid_charset_range_index = [] # 指定字符对应的有效索引
|
||||||
self.__channel = 1
|
self.__channel = 1
|
||||||
if import_onnx_path != "":
|
if import_onnx_path != "":
|
||||||
det = False
|
det = False
|
||||||
@ -2593,6 +2594,16 @@ class DdddOcr(object):
|
|||||||
|
|
||||||
# 去重
|
# 去重
|
||||||
self.__charset_range = list(set(self.__charset_range)) + [""]
|
self.__charset_range = list(set(self.__charset_range)) + [""]
|
||||||
|
# 根据指定字符获取对应的索引
|
||||||
|
valid_charset_range_index = []
|
||||||
|
if len(self.__charset_range) > 0:
|
||||||
|
for item in self.__charset_range:
|
||||||
|
if item in self.__charset:
|
||||||
|
valid_charset_range_index.append(self.__charset.index(item))
|
||||||
|
else:
|
||||||
|
# 未知字符没有索引,直接忽略
|
||||||
|
pass
|
||||||
|
self.__valid_charset_range_index = valid_charset_range_index
|
||||||
|
|
||||||
|
|
||||||
def classification(self, img, png_fix: bool = False, probability=False):
|
def classification(self, img, png_fix: bool = False, probability=False):
|
||||||
@ -2667,28 +2678,36 @@ class DdddOcr(object):
|
|||||||
result['probability'] = ort_outs_probability
|
result['probability'] = ort_outs_probability
|
||||||
else:
|
else:
|
||||||
result['charsets'] = self.__charset_range
|
result['charsets'] = self.__charset_range
|
||||||
probability_result_index = []
|
valid_charset_range_index = self.__valid_charset_range_index
|
||||||
for item in self.__charset_range:
|
|
||||||
if item in self.__charset:
|
|
||||||
probability_result_index.append(self.__charset.index(item))
|
|
||||||
else:
|
|
||||||
# 未知字符
|
|
||||||
probability_result_index.append(-1)
|
|
||||||
probability_result = []
|
probability_result = []
|
||||||
for item in ort_outs_probability:
|
for item in ort_outs_probability:
|
||||||
probability_result.append([item[i] if i != -1 else -1 for i in probability_result_index ])
|
probability_result.append([item[i] for i in valid_charset_range_index ])
|
||||||
result['probability'] = probability_result
|
result['probability'] = probability_result
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
last_item = 0
|
if len(self.__charset_range) == 0:
|
||||||
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
|
# 没有指定特定的字符集合,直接获取结果
|
||||||
for item in argmax_result:
|
last_item = 0
|
||||||
if item == last_item:
|
argmax_result = np.squeeze(np.argmax(ort_outs[0], axis=2))
|
||||||
continue
|
for item in argmax_result:
|
||||||
else:
|
if item == last_item:
|
||||||
last_item = item
|
continue
|
||||||
if item != 0:
|
else:
|
||||||
result.append(self.__charset[item])
|
last_item = item
|
||||||
|
if item != 0:
|
||||||
|
result.append(self.__charset[item])
|
||||||
|
else:
|
||||||
|
# 指定了特定的字符集合
|
||||||
|
last_item = 0
|
||||||
|
valid_charset_range_index = self.__valid_charset_range_index
|
||||||
|
for row in np.squeeze(ort_outs[0]):
|
||||||
|
# 仅在指定字符集合中寻找最大值
|
||||||
|
idx = np.argmax(row[list(valid_charset_range_index)])
|
||||||
|
if idx == last_item:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
last_item = idx
|
||||||
|
result.append(self.__charset[valid_charset_range_index[idx]])
|
||||||
return ''.join(result)
|
return ''.join(result)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user