changdazhou пре 1 година
родитељ
комит
e53c9907dd

+ 9 - 7
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -103,14 +103,16 @@ class TableRecPipeline(BasePipeline):
             # update layout result
             single_img_res["input_path"] = layout_pred["input_path"]
             single_img_res["layout_result"] = layout_pred
-            subs_of_img = list(self._crop_by_boxes(layout_pred))
-            # get cropped images with label "table"
+            ocr_res = ocr_pred
             table_subs = []
-            for sub in subs_of_img:
-                box = sub["box"]
-                if sub["label"].lower() == "table":
-                    table_subs.append(sub)
-                    _, ocr_res = self.get_related_ocr_result(box, ocr_pred)
+            if len(layout_pred["boxes"]) > 0:
+                subs_of_img = list(self._crop_by_boxes(layout_pred))
+                # get cropped images with label "table"
+                for sub in subs_of_img:
+                    box = sub["box"]
+                    if sub["label"].lower() == "table":
+                        table_subs.append(sub)
+                        _, ocr_res = self.get_related_ocr_result(box, ocr_res)
             table_res, all_table_ocr_res = self.get_table_result(table_subs)
             for table_ocr_res in all_table_ocr_res:
                 ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])

+ 7 - 7
paddlex/inference/results/table_rec.py

@@ -22,18 +22,13 @@ from .utils.mixin import HtmlMixin, XlsxMixin
 from .base import BaseResult, CVResult
 
 
-class TableRecResult(CVResult, HtmlMixin):
+class TableRecResult(CVResult):
     """SaveTableResults"""
 
     _HARD_FLAG = False
 
     def __init__(self, data):
         super().__init__(data)
-        HtmlMixin.__init__(self)
-        self._show_func_register("save_to_html")(self.save_to_html)
-
-    def _to_html(self):
-        return self["html"]
 
     def _to_img(self):
         image = self._img_reader.read(self["input_path"])
@@ -64,13 +59,18 @@ class TableRecResult(CVResult, HtmlMixin):
         return image
 
 
-class StructureTableResult(TableRecResult, XlsxMixin):
+class StructureTableResult(TableRecResult, HtmlMixin, XlsxMixin):
     """StructureTableResult"""
 
     def __init__(self, data):
         super().__init__(data)
+        HtmlMixin.__init__(self)
+        self._show_func_register("save_to_html")(self.save_to_html)
         XlsxMixin.__init__(self)
 
+    def _to_html(self):
+        return self["html"]
+
 
 class TableResult(BaseResult):
     """TableResult"""