فهرست منبع

fix slanet inference (#2276)

changdazhou 1 سال پیش
والد
کامیت
e624dd52e6
2فایلهای تغییر یافته به همراه23 افزوده شده و 13 حذف شده
  1. 18 11
      paddlex/inference/components/task_related/table_rec.py
  2. 5 2
      paddlex/inference/models/table_recognition.py

+ 18 - 11
paddlex/inference/components/task_related/table_rec.py

@@ -37,7 +37,7 @@ class TableLabelDecode(BaseComponent):
         "structure_score": "structure_score",
     }
 
-    def __init__(self, merge_no_span_structure=True, dict_character=[]):
+    def __init__(self, model_name, merge_no_span_structure=True, dict_character=[]):
         super().__init__()
 
         if merge_no_span_structure:
@@ -46,6 +46,8 @@ class TableLabelDecode(BaseComponent):
             if "<td>" in dict_character:
                 dict_character.remove("<td>")
 
+        self.model_name = model_name
+
         dict_character = self.add_special_char(dict_character)
         self.dict = {}
         for i, char in enumerate(dict_character):
@@ -170,15 +172,20 @@ class TableLabelDecode(BaseComponent):
 
     def _bbox_decode(self, bbox, padding_shape, ori_shape):
 
-        pad_w, pad_h = padding_shape
-        w, h = ori_shape
-        ratio_w = pad_w / w
-        ratio_h = pad_h / h
-        ratio = min(ratio_w, ratio_h)
-
-        bbox[0::2] *= pad_w
-        bbox[1::2] *= pad_h
-        bbox[0::2] /= ratio
-        bbox[1::2] /= ratio
+        if self.model_name == "SLANet":
+            w, h = ori_shape
+            bbox[0::2] *= w
+            bbox[1::2] *= h
+        else:
+            w, h = padding_shape
+            ori_w, ori_h = ori_shape
+            ratio_w = w / ori_w
+            ratio_h = h / ori_h
+            ratio = min(ratio_w, ratio_h)
+
+            bbox[0::2] *= w
+            bbox[1::2] *= h
+            bbox[0::2] /= ratio
+            bbox[1::2] /= ratio
 
         return bbox

+ 5 - 2
paddlex/inference/models/table_recognition.py

@@ -46,12 +46,15 @@ class TablePredictor(BasicPredictor):
         )
         self._add_component(predictor)
 
-        op = self.build_postprocess(**self.config["PostProcess"])
+        op = self.build_postprocess(
+            model_name=self.config["Global"]["model_name"], **self.config["PostProcess"]
+        )
         self._add_component(op)
 
-    def build_postprocess(self, **kwargs):
+    def build_postprocess(self, model_name, **kwargs):
         if kwargs.get("name") == "TableLabelDecode":
             return TableLabelDecode(
+                model_name=model_name,
                 merge_no_span_structure=kwargs.get("merge_no_span_structure"),
                 dict_character=kwargs.get("character_dict"),
             )