|
|
@@ -37,7 +37,7 @@ class TableLabelDecode(BaseComponent):
|
|
|
"structure_score": "structure_score",
|
|
|
}
|
|
|
|
|
|
- def __init__(self, merge_no_span_structure=True, dict_character=[]):
|
|
|
+ def __init__(self, model_name, merge_no_span_structure=True, dict_character=[]):
|
|
|
super().__init__()
|
|
|
|
|
|
if merge_no_span_structure:
|
|
|
@@ -46,6 +46,8 @@ class TableLabelDecode(BaseComponent):
|
|
|
if "<td>" in dict_character:
|
|
|
dict_character.remove("<td>")
|
|
|
|
|
|
+ self.model_name = model_name
|
|
|
+
|
|
|
dict_character = self.add_special_char(dict_character)
|
|
|
self.dict = {}
|
|
|
for i, char in enumerate(dict_character):
|
|
|
@@ -170,15 +172,20 @@ class TableLabelDecode(BaseComponent):
|
|
|
|
|
|
def _bbox_decode(self, bbox, padding_shape, ori_shape):
|
|
|
|
|
|
- pad_w, pad_h = padding_shape
|
|
|
- w, h = ori_shape
|
|
|
- ratio_w = pad_w / w
|
|
|
- ratio_h = pad_h / h
|
|
|
- ratio = min(ratio_w, ratio_h)
|
|
|
-
|
|
|
- bbox[0::2] *= pad_w
|
|
|
- bbox[1::2] *= pad_h
|
|
|
- bbox[0::2] /= ratio
|
|
|
- bbox[1::2] /= ratio
|
|
|
+ if self.model_name == "SLANet":
|
|
|
+ w, h = ori_shape
|
|
|
+ bbox[0::2] *= w
|
|
|
+ bbox[1::2] *= h
|
|
|
+ else:
|
|
|
+ w, h = padding_shape
|
|
|
+ ori_w, ori_h = ori_shape
|
|
|
+ ratio_w = w / ori_w
|
|
|
+ ratio_h = h / ori_h
|
|
|
+ ratio = min(ratio_w, ratio_h)
|
|
|
+
|
|
|
+ bbox[0::2] *= w
|
|
|
+ bbox[1::2] *= h
|
|
|
+ bbox[0::2] /= ratio
|
|
|
+ bbox[1::2] /= ratio
|
|
|
|
|
|
return bbox
|