|
|
@@ -27,7 +27,6 @@ import tempfile
|
|
|
from tokenizers import Tokenizer as TokenizerFast
|
|
|
|
|
|
from ....utils import logging
|
|
|
-from ...results import TextRecResult
|
|
|
from ..base import BaseComponent
|
|
|
|
|
|
__all__ = [
|
|
|
@@ -192,10 +191,10 @@ class OCRReisizeNormImg(BaseComponent):
|
|
|
class BaseRecLabelDecode(BaseComponent):
|
|
|
"""Convert between text-label and text-index"""
|
|
|
|
|
|
- INPUT_KEYS = ["pred", "img_path"]
|
|
|
- OUTPUT_KEYS = ["text_rec_res"]
|
|
|
- DEAULT_INPUTS = {"pred": "pred", "img_path": "img_path"}
|
|
|
- DEAULT_OUTPUTS = {"text_rec_res": "text_rec_res"}
|
|
|
+ INPUT_KEYS = ["pred"]
|
|
|
+ OUTPUT_KEYS = ["rec_text", "rec_score"]
|
|
|
+ DEAULT_INPUTS = {"pred": "pred"}
|
|
|
+ DEAULT_OUTPUTS = {"rec_text": "rec_text", "rec_score": "rec_score"}
|
|
|
|
|
|
ENABLE_BATCH = True
|
|
|
|
|
|
@@ -271,7 +270,7 @@ class BaseRecLabelDecode(BaseComponent):
|
|
|
"""get_ignored_tokens"""
|
|
|
return [0] # for ctc blank
|
|
|
|
|
|
- def apply(self, pred, img_path):
|
|
|
+ def apply(self, pred):
|
|
|
"""apply"""
|
|
|
preds = np.array(pred)
|
|
|
if isinstance(preds, tuple) or isinstance(preds, list):
|
|
|
@@ -279,14 +278,7 @@ class BaseRecLabelDecode(BaseComponent):
|
|
|
preds_idx = preds.argmax(axis=2)
|
|
|
preds_prob = preds.max(axis=2)
|
|
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
|
|
- return [
|
|
|
- {
|
|
|
- "text_rec_res": TextRecResult(
|
|
|
- {"img_path": path, "rec_text": t[0], "rec_score": t[1]}
|
|
|
- )
|
|
|
- }
|
|
|
- for path, t in zip(img_path, text)
|
|
|
- ]
|
|
|
+ return [{"rec_text": t[0], "rec_score": t[1]} for t in text]
|
|
|
|
|
|
|
|
|
class CTCLabelDecode(BaseRecLabelDecode):
|
|
|
@@ -295,20 +287,13 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|
|
def __init__(self, character_list=None, use_space_char=True):
|
|
|
super().__init__(character_list, use_space_char=use_space_char)
|
|
|
|
|
|
- def apply(self, pred, img_path):
|
|
|
+ def apply(self, pred):
|
|
|
"""apply"""
|
|
|
- preds = np.array(pred)
|
|
|
+ preds = np.array(pred[0])
|
|
|
preds_idx = preds.argmax(axis=2)
|
|
|
preds_prob = preds.max(axis=2)
|
|
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
|
|
- return [
|
|
|
- {
|
|
|
- "text_rec_res": TextRecResult(
|
|
|
- {"img_path": path, "rec_text": t[0], "rec_score": t[1]}
|
|
|
- )
|
|
|
- }
|
|
|
- for path, t in zip(img_path, text)
|
|
|
- ]
|
|
|
+ return [{"rec_text": t[0], "rec_score": t[1]} for t in text]
|
|
|
|
|
|
def add_special_char(self, character_list):
|
|
|
"""add_special_char"""
|