ソースを参照

update table recognition pipeline

zhouchangda 1 年間 前
コミット
0b8147d5bc

+ 0 - 1
paddlex/inference/components/task_related/det.py

@@ -14,7 +14,6 @@
 
 import os
 
-from ....utils import logging
 from ...utils.io import ImageReader
 from ..base import BaseComponent
 

+ 24 - 71
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -12,14 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import re
 import numpy as np
 from ..base import BasePipeline
 from ...predictors import create_predictor
 from ..ocr import OCRPipeline
 from ...components import CropByBoxes
-from ...results import OCRResult, TableResult, StructureTableResult
-from copy import deepcopy
+from ...results import TableResult, StructureTableResult
 from .utils import *
 
 
@@ -57,7 +55,7 @@ class TableRecPipeline(BasePipeline):
             self.layout_predictor(x), self.ocr_pipeline(x)
         ):
             for layout_pred, ocr_pred in zip(batch_layout_pred, batch_ocr_pred):
-                single_img_structure_res = {
+                single_img_res = {
                     "img_path": "",
                     "layout_result": {},
                     "ocr_result": {},
@@ -65,67 +63,22 @@ class TableRecPipeline(BasePipeline):
                 }
                 layout_res = layout_pred["result"]
                 # update layout result
-                single_img_structure_res["img_path"] = layout_res["img_path"]
-                single_img_structure_res["layout_result"] = layout_res
-                single_img_ocr_res = ocr_pred["result"]
+                single_img_res["img_path"] = layout_res["img_path"]
+                single_img_res["layout_result"] = layout_res
+                ocr_res = ocr_pred["result"]
+                single_img_res["ocr_result"] = ocr_res
                 all_subs_of_img = list(self._crop_by_boxes(layout_res))
-                table_subs_of_img = []
-                seal_subs_of_img = []
-                # ocr result without table and seal
-                ocr_res = deepcopy(single_img_ocr_res)
-                # ocr result in table and seal, is for batch
-                table_ocr_res, seal_ocr_res = [], []
-                # get cropped images and ocr result
+                # get cropped images with label 'table'
+                table_subs = []
                 for batch_subs in all_subs_of_img:
-                    table_batch_list, seal_batch_list = [], []
-                    table_batch_ocr_res, seal_batch_ocr_res = [], []
+                    table_sub_list = []
                     for sub in batch_subs:
-                        box = sub["box"]
                         if sub["label"].lower() == "table":
-                            table_batch_list.append(sub)
-                            relative_res, ocr_res = self.get_ocr_result_by_bbox(
-                                box, ocr_res
-                            )
-                            table_batch_ocr_res.append(
-                                {
-                                    "dt_polys": relative_res[0],
-                                    "rec_text": relative_res[1],
-                                }
-                            )
-                        elif sub["label"].lower() == "seal":
-                            seal_batch_list.append(sub)
-                            relative_res, ocr_res = self.get_ocr_result_by_bbox(
-                                box, ocr_res
-                            )
-                            seal_batch_ocr_res.append(
-                                {
-                                    "dt_polys": relative_res[0],
-                                    "rec_text": relative_res[1],
-                                }
-                            )
-                        elif sub["label"].lower() == "figure":
-                            # remove ocr result in figure
-                            _, ocr_res = self.get_ocr_result_by_bbox(box, ocr_res)
-                    table_subs_of_img.append(table_batch_list)
-                    table_ocr_res.append(table_batch_ocr_res)
-                    seal_subs_of_img.append(seal_batch_list)
-                    seal_ocr_res.append(seal_batch_ocr_res)
+                            table_sub_list.append(sub)
+                    table_subs.append(table_sub_list)
+                single_img_res["table_result"] = self.get_table_result(table_subs)
 
-                # get table result
-                table_res = self.get_table_result(table_subs_of_img, table_ocr_res)
-                # get seal result
-                if seal_subs_of_img:
-                    pass
-
-                if self.chat_ocr:
-                    # chat ocr does not visualize table results in ocr result
-                    single_img_structure_res["ocr_result"] = OCRResult(ocr_res)
-                else:
-                    single_img_structure_res["ocr_result"] = single_img_ocr_res
-                single_img_structure_res["table_result"] = table_res
-                batch_structure_res.append(
-                    {"result": TableResult(single_img_structure_res)}
-                )
+                batch_structure_res.append({"result": TableResult(single_img_res)})
         yield batch_structure_res
 
     def get_ocr_result_by_bbox(self, box, ocr_res):
@@ -142,33 +95,33 @@ class TableRecPipeline(BasePipeline):
                 unmatched_ocr_res["rec_text"].append(text_res)
         return (dt_polys_list, rec_text_list), unmatched_ocr_res
 
-    def get_table_result(self, input_img, table_ocr_res):
+    def get_table_result(self, input_img):
         table_res_list = []
         table_index = 0
-        for batch_input, batch_table_res, batch_ocr_res in zip(
-            input_img, self.table_predictor(input_img), table_ocr_res
+        for batch_input, batch_table_pred, batch_ocr_pred in zip(
+            input_img, self.table_predictor(input_img), self.ocr_pipeline(input_img)
         ):
             batch_res_list = []
-            for roi_img, table_res, ocr_res in zip(
-                batch_input, batch_table_res, batch_ocr_res
+            for input, table_pred, ocr_pred in zip(
+                batch_input, batch_table_pred, batch_ocr_pred
             ):
-                single_table_res = table_res["result"]
+                single_table_res = table_pred["result"]
+                ocr_res = ocr_pred["result"]
                 single_table_box = single_table_res["bbox"]
-                ori_x, ori_y, _, _ = roi_img["box"]
+                ori_x, ori_y, _, _ = input["box"]
                 ori_bbox_list = np.array(
                     get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
                     dtype=np.float32,
                 )
-                single_table_res["bbox"] = ori_bbox_list
                 html_res = self._match(single_table_res, ocr_res)
                 batch_res_list.append(
                     StructureTableResult(
                         {
-                            "img_path": roi_img["img_path"],
-                            "img_idx": table_index,
+                            "img_path": input["img_path"],
                             "bbox": ori_bbox_list,
+                            "img_idx": table_index,
+                            "ocr_res": ocr_res,
                             "html": html_res,
-                            "structure": single_table_res["structure"],
                         }
                     )
                 )