Sfoglia il codice sorgente

support table recognition pipeline

zhouchangda 1 anno fa
parent
commit
f3e488fd9c

+ 2 - 2
paddlex/inference/components/task_related/__init__.py

@@ -15,6 +15,6 @@
 from .clas import Topk, MultiLabelThreshOutput, NormalizeFeatures
 from .text_det import DetResizeForTest, NormalizeImage, DBPostProcess, CropByPolys
 from .text_rec import OCRReisizeNormImg, CTCLabelDecode
-from .table_rec import TableLabelDecode, TableMasterLabelDecode
-from .det import DetPostProcess
+from .table_rec import TableLabelDecode
+from .det import DetPostProcess, CropByBoxes
 from .instance_seg import InstanceSegPostProcess

+ 30 - 0
paddlex/inference/components/task_related/det.py

@@ -15,6 +15,7 @@
 import os
 
 from ....utils import logging
+from ...utils.io import ImageReader
 from ..base import BaseComponent
 
 
@@ -41,3 +42,32 @@ class DetPostProcess(BaseComponent):
         result = {"boxes": boxes, "labels": self.labels}
 
         return result
+
+
+class CropByBoxes(BaseComponent):
+    """Crop Image by Box"""
+
+    INPUT_KEYS = ["img_path", "boxes", "labels"]
+    OUTPUT_KEYS = ["img", "box", "label"]
+    DEAULT_INPUTS = {"img_path": "img_path", "boxes": "boxes", "labels": "labels"}
+    DEAULT_OUTPUTS = {"img": "img", "box": "box", "label": "label"}
+
+    def __init__(self):
+        super().__init__()
+        self._reader = ImageReader(backend="opencv")
+
+    def apply(self, img_path, boxes, labels=None):
+        output_list = []
+        img = self._reader.read(img_path)
+        for bbox in boxes:
+            label_id = int(bbox[0])
+            box = bbox[2:]
+            if labels is not None:
+                label = labels[label_id]
+            else:
+                label = label_id
+            xmin, ymin, xmax, ymax = [int(i) for i in box]
+            img_crop = img[ymin:ymax, xmin:xmax]
+            output_list.append({"img": img_crop, "box": box, "label": label})
+
+        return output_list

+ 12 - 64
paddlex/inference/components/task_related/table_rec.py

@@ -16,7 +16,7 @@ import numpy as np
 
 from ..base import BaseComponent
 
-__all__ = ["TableLabelDecode", "TableMasterLabelDecode"]
+__all__ = ["TableLabelDecode"]
 
 
 class TableLabelDecode(BaseComponent):
@@ -25,9 +25,13 @@ class TableLabelDecode(BaseComponent):
     ENABLE_BATCH = True
 
     INPUT_KEYS = ["pred", "ori_img_size"]
-    OUTPUT_KEYS = ["bbox", "structure"]
+    OUTPUT_KEYS = ["bbox", "structure", "structure_score"]
     DEAULT_INPUTS = {"pred": "pred", "ori_img_size": "ori_img_size"}
-    DEAULT_OUTPUTS = {"bbox": "bbox", "structure": "structure"}
+    DEAULT_OUTPUTS = {
+        "bbox": "bbox",
+        "structure": "structure",
+        "structure_score": "structure_score",
+    }
 
     def __init__(self, merge_no_span_structure=True, dict_character=[]):
         super().__init__()
@@ -78,7 +82,7 @@ class TableLabelDecode(BaseComponent):
         bbox_preds = np.array(bbox_preds)
         structure_probs = np.array(structure_probs)
 
-        bbox_list, structure_str_list = self.decode(
+        bbox_list, structure_str_list, structure_score = self.decode(
             structure_probs, bbox_preds, ori_img_size
         )
         structure_str_list = [
@@ -89,9 +93,8 @@ class TableLabelDecode(BaseComponent):
             )
             for structure in structure_str_list
         ]
-
         return [
-            {"bbox": bbox, "structure": structure}
+            {"bbox": bbox, "structure": structure, "structure_score": structure_score}
             for bbox, structure in zip(bbox_list, structure_str_list)
         ]
 
@@ -123,10 +126,11 @@ class TableLabelDecode(BaseComponent):
                     bbox_list.append(bbox.tolist())
                 structure_list.append(text)
                 score_list.append(structure_probs[batch_idx, idx])
-            structure_batch_list.append([structure_list, float(np.mean(score_list))])
+            structure_batch_list.append([structure_list])
+            structure_score = np.mean(score_list)
             bbox_batch_list.append(bbox_list)
 
-        return bbox_batch_list, structure_batch_list
+        return bbox_batch_list, structure_batch_list, structure_score
 
     def decode_label(self, batch):
         """convert text-label into text-index."""
@@ -163,59 +167,3 @@ class TableLabelDecode(BaseComponent):
         bbox[0::2] *= w
         bbox[1::2] *= h
         return bbox
-
-
-class TableMasterLabelDecode(TableLabelDecode):
-    """decode the table model outputs(probs) to character str"""
-
-    def __init__(
-        self,
-        character_dict_type="TableMaster",
-        box_shape="pad",
-        merge_no_span_structure=True,
-    ):
-        super(TableMasterLabelDecode, self).__init__(
-            character_dict_type, merge_no_span_structure
-        )
-        self.box_shape = box_shape
-        assert box_shape in [
-            "ori",
-            "pad",
-        ], "The shape used for box normalization must be ori or pad"
-
-    def add_special_char(self, dict_character):
-        """add_special_char"""
-        self.beg_str = "<SOS>"
-        self.end_str = "<EOS>"
-        self.unknown_str = "<UKN>"
-        self.pad_str = "<PAD>"
-        dict_character = dict_character
-        dict_character = dict_character + [
-            self.unknown_str,
-            self.beg_str,
-            self.end_str,
-            self.pad_str,
-        ]
-        return dict_character
-
-    def get_ignored_tokens(self):
-        """get_ignored_tokens"""
-        pad_idx = self.dict[self.pad_str]
-        start_idx = self.dict[self.beg_str]
-        end_idx = self.dict[self.end_str]
-        unknown_idx = self.dict[self.unknown_str]
-        return [start_idx, end_idx, pad_idx, unknown_idx]
-
-    def _bbox_decode(self, bbox, shape):
-        """_bbox_decode"""
-        h, w, ratio_h, ratio_w, pad_h, pad_w = shape
-        if self.box_shape == "pad":
-            h, w = pad_h, pad_w
-        bbox[0::2] *= w
-        bbox[1::2] *= h
-        bbox[0::2] /= ratio_w
-        bbox[1::2] /= ratio_h
-        x, y, w, h = bbox
-        x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
-        bbox = np.array([x1, y1, x2, y2])
-        return bbox

+ 1 - 0
paddlex/inference/pipelines/__init__.py

@@ -14,6 +14,7 @@
 
 from .image_classification import ClasPipeline
 from .ocr import OCRPipeline
+from .table_recognition import TableRecPipeline
 from .object_detection import DetPipeline
 from .instance_segmentation import InstanceSegPipeline
 from .semantic_segmentation import SegPipeline

+ 1 - 0
paddlex/inference/pipelines/ocr.py

@@ -14,6 +14,7 @@
 
 from .base import BasePipeline
 from ..predictors import create_predictor
+from ...utils import logging
 from ..components import CropByPolys
 from ..results import OCRResult
 

+ 15 - 0
paddlex/inference/pipelines/table_recognition/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .table_recognition import TableRecPipeline

+ 177 - 0
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -0,0 +1,177 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+import numpy as np
+from ..base import BasePipeline
+from ...predictors import create_predictor
+from ..ocr import OCRPipeline
+from ...components import CropByBoxes
+from ...results import OCRResult, TableResult, StructureTableResult
+from copy import deepcopy
+from .utils import *
+
+
+class TableRecPipeline(BasePipeline):
+    """Table Recognition Pipeline"""
+
+    def __init__(
+        self,
+        layout_model,
+        text_det_model,
+        text_rec_model,
+        table_model,
+        batch_size=1,
+        device="gpu",
+        chat_ocr=False,
+    ):
+
+        self.layout_predictor = create_predictor(
+            model=layout_model, device=device, batch_size=batch_size
+        )
+        self.ocr_pipeline = OCRPipeline(
+            text_det_model, text_rec_model, batch_size, device
+        )
+        self.table_predictor = create_predictor(
+            model=table_model, device=device, batch_size=batch_size
+        )
+        self._crop_by_boxes = CropByBoxes()
+        self._match = TableMatch(filter_ocr_result=False)
+        self.chat_ocr = chat_ocr
+        super().__init__()
+
+    def predict(self, x):
+        batch_structure_res = []
+        for batch_layout_pred, batch_ocr_pred in zip(
+            self.layout_predictor(x), self.ocr_pipeline(x)
+        ):
+            for layout_pred, ocr_pred in zip(batch_layout_pred, batch_ocr_pred):
+                single_img_structure_res = {
+                    "img_path": "",
+                    "layout_result": {},
+                    "ocr_result": {},
+                    "table_result": [],
+                }
+                layout_res = layout_pred["result"]
+                # update layout result
+                single_img_structure_res["img_path"] = layout_res["img_path"]
+                single_img_structure_res["layout_result"] = layout_res
+                single_img_ocr_res = ocr_pred["result"]
+                all_subs_of_img = list(self._crop_by_boxes(layout_res))
+                table_subs_of_img = []
+                seal_subs_of_img = []
+                # ocr result without table and seal
+                ocr_res = deepcopy(single_img_ocr_res)
+                # ocr result in table and seal, is for batch
+                table_ocr_res, seal_ocr_res = [], []
+                # get cropped images and ocr result
+                for batch_subs in all_subs_of_img:
+                    table_batch_list, seal_batch_list = [], []
+                    table_batch_ocr_res, seal_batch_ocr_res = [], []
+                    for sub in batch_subs:
+                        box = sub["box"]
+                        if sub["label"].lower() == "table":
+                            table_batch_list.append(sub)
+                            relative_res, ocr_res = self.get_ocr_result_by_bbox(
+                                box, ocr_res
+                            )
+                            table_batch_ocr_res.append(
+                                {
+                                    "dt_polys": relative_res[0],
+                                    "rec_text": relative_res[1],
+                                }
+                            )
+                        elif sub["label"].lower() == "seal":
+                            seal_batch_list.append(sub)
+                            relative_res, ocr_res = self.get_ocr_result_by_bbox(
+                                box, ocr_res
+                            )
+                            seal_batch_ocr_res.append(
+                                {
+                                    "dt_polys": relative_res[0],
+                                    "rec_text": relative_res[1],
+                                }
+                            )
+                        elif sub["label"].lower() == "figure":
+                            # remove ocr result in figure
+                            _, ocr_res = self.get_ocr_result_by_bbox(box, ocr_res)
+                    table_subs_of_img.append(table_batch_list)
+                    table_ocr_res.append(table_batch_ocr_res)
+                    seal_subs_of_img.append(seal_batch_list)
+                    seal_ocr_res.append(seal_batch_ocr_res)
+
+                # get table result
+                table_res = self.get_table_result(table_subs_of_img, table_ocr_res)
+                # get seal result
+                if seal_subs_of_img:
+                    pass
+
+                if self.chat_ocr:
+                    # chat ocr does not visualize table results in ocr result
+                    single_img_structure_res["ocr_result"] = OCRResult(ocr_res)
+                else:
+                    single_img_structure_res["ocr_result"] = single_img_ocr_res
+                single_img_structure_res["table_result"] = table_res
+                batch_structure_res.append(
+                    {"result": TableResult(single_img_structure_res)}
+                )
+        yield batch_structure_res
+
+    def get_ocr_result_by_bbox(self, box, ocr_res):
+        dt_polys_list = []
+        rec_text_list = []
+        unmatched_ocr_res = {"dt_polys": [], "rec_text": []}
+        for text_box, text_res in zip(ocr_res["dt_polys"], ocr_res["rec_text"]):
+            text_box_area = convert_4point2rect(text_box)
+            if is_inside(box, text_box_area):
+                dt_polys_list.append(text_box)
+                rec_text_list.append(text_res)
+            else:
+                unmatched_ocr_res["dt_polys"].append(text_box)
+                unmatched_ocr_res["rec_text"].append(text_res)
+        return (dt_polys_list, rec_text_list), unmatched_ocr_res
+
+    def get_table_result(self, input_img, table_ocr_res):
+        table_res_list = []
+        table_index = 0
+        for batch_input, batch_table_res, batch_ocr_res in zip(
+            input_img, self.table_predictor(input_img), table_ocr_res
+        ):
+            batch_res_list = []
+            for roi_img, table_res, ocr_res in zip(
+                batch_input, batch_table_res, batch_ocr_res
+            ):
+                single_table_res = table_res["result"]
+                single_table_box = single_table_res["bbox"]
+                ori_x, ori_y, _, _ = roi_img["box"]
+                ori_bbox_list = np.array(
+                    get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
+                    dtype=np.float32,
+                )
+                single_table_res["bbox"] = ori_bbox_list
+                html_res = self._match(single_table_res, ocr_res)
+                batch_res_list.append(
+                    StructureTableResult(
+                        {
+                            "img_path": roi_img["img_path"],
+                            "img_idx": table_index,
+                            "bbox": ori_bbox_list,
+                            "html": html_res,
+                            "structure": single_table_res["structure"],
+                        }
+                    )
+                )
+                table_index += 1
+            table_res_list.append(batch_res_list)
+        return table_res_list

+ 462 - 0
paddlex/inference/pipelines/table_recognition/utils.py

@@ -0,0 +1,462 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import re
+import copy
+
+__all__ = [
+    "TableMatch",
+    "convert_4point2rect",
+    "get_ori_coordinate_for_table",
+    "is_inside",
+]
+
+
+def deal_eb_token(master_token):
+    """
+    post process with <eb></eb>, <eb1></eb1>, ...
+    emptyBboxTokenDict = {
+        "[]": '<eb></eb>',
+        "[' ']": '<eb1></eb1>',
+        "['<b>', ' ', '</b>']": '<eb2></eb2>',
+        "['\\u2028', '\\u2028']": '<eb3></eb3>',
+        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+        "['<b>', '</b>']": '<eb5></eb5>',
+        "['<i>', ' ', '</i>']": '<eb6></eb6>',
+        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+        "['<i>', '</i>']": '<eb9></eb9>',
+        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
+    }
+    :param master_token:
+    :return:
+    """
+    master_token = master_token.replace("<eb></eb>", "<td></td>")
+    master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
+    master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
+    master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
+    master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
+    master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
+    master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
+    master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
+    master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
+    master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
+    master_token = master_token.replace(
+        "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
+    )
+    return master_token
+
+
+def deal_bb(result_token):
+    """
+    In our opinion, <b></b> always occurs in <thead></thead> text's context.
+    This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
+    :param result_token:
+    :return:
+    """
+    # find out <thead></thead> parts.
+    thead_pattern = "<thead>(.*?)</thead>"
+    if re.search(thead_pattern, result_token) is None:
+        return result_token
+    thead_part = re.search(thead_pattern, result_token).group()
+    origin_thead_part = copy.deepcopy(thead_part)
+
+    # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
+    span_pattern = (
+        '<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan'
+        '="(\d)+">|<td colspan="(\d)+">'
+    )
+    span_iter = re.finditer(span_pattern, thead_part)
+    span_list = [s.group() for s in span_iter]
+    has_span_in_head = True if len(span_list) > 0 else False
+
+    if not has_span_in_head:
+        # <thead></thead> not include "rowspan" or "colspan" branch 1.
+        # 1. replace <td> to <td><b>, and </td> to </b></td>
+        # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, and </b></b> to </b>
+        thead_part = (
+            thead_part.replace("<td>", "<td><b>")
+            .replace("</td>", "</b></td>")
+            .replace("<b><b>", "<b>")
+            .replace("</b></b>", "</b>")
+        )
+    else:
+        # <thead></thead> include "rowspan" or "colspan" branch 2.
+        # Firstly, we deal rowspan or colspan cases.
+        # 1. replace > to ><b>
+        # 2. replace </td> to </b></td>
+        # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, and </b><b> to </b>
+
+        # Secondly, deal ordinary cases like branch 1
+
+        # replace ">" to "<b>"
+        replaced_span_list = []
+        for sp in span_list:
+            replaced_span_list.append(sp.replace(">", "><b>"))
+        for sp, rsp in zip(span_list, replaced_span_list):
+            thead_part = thead_part.replace(sp, rsp)
+
+        # replace "</td>" to "</b></td>"
+        thead_part = thead_part.replace("</td>", "</b></td>")
+
+        # remove duplicated <b> by re.sub
+        mb_pattern = "(<b>)+"
+        single_b_string = "<b>"
+        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
+
+        mgb_pattern = "(</b>)+"
+        single_gb_string = "</b>"
+        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
+
+        # ordinary cases like branch 1
+        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
+
+    # convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
+    # but space cell(<tb> </tb>)  is suitable for <td><b> </b></td>
+    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
+    # deal with duplicated <b></b>
+    thead_part = deal_duplicate_bb(thead_part)
+    # deal with isolate span tokens, which causes by wrong predict by structure prediction.
+    # eg.PMC5994107_011_00.png
+    thead_part = deal_isolate_span(thead_part)
+    # replace original result with new thead part.
+    result_token = result_token.replace(origin_thead_part, thead_part)
+    return result_token
+
+
+def deal_isolate_span(thead_part):
+    """
+    Deal with isolate span cases in this function.
+    It causes by wrong prediction in structure recognition model.
+    eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out isolate span tokens.
+    isolate_pattern = (
+        '<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
+        '<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
+        '<td></td> rowspan="(\d)+"></b></td>|'
+        '<td></td> colspan="(\d)+"></b></td>'
+    )
+    isolate_iter = re.finditer(isolate_pattern, thead_part)
+    isolate_list = [i.group() for i in isolate_iter]
+
+    # 2. find out span number, by step 1 results.
+    span_pattern = (
+        ' rowspan="(\d)+" colspan="(\d)+"|'
+        ' colspan="(\d)+" rowspan="(\d)+"|'
+        ' rowspan="(\d)+"|'
+        ' colspan="(\d)+"'
+    )
+    corrected_list = []
+    for isolate_item in isolate_list:
+        span_part = re.search(span_pattern, isolate_item)
+        spanStr_in_isolateItem = span_part.group()
+        # 3. merge the span number into the span token format string.
+        if spanStr_in_isolateItem is not None:
+            corrected_item = "<td{}></td>".format(spanStr_in_isolateItem)
+            corrected_list.append(corrected_item)
+        else:
+            corrected_list.append(None)
+
+    # 4. replace original isolated token.
+    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
+        if corrected_item is not None:
+            thead_part = thead_part.replace(isolate_item, corrected_item)
+        else:
+            pass
+    return thead_part
+
+
+def deal_duplicate_bb(thead_part):
+    """
+    Deal duplicate <b> or </b> after replace.
+    Keep one <b></b> in a <td></td> token.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out <td></td> in <thead></thead>.
+    td_pattern = (
+        '<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
+        '<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
+        '<td rowspan="(\d)+">(.+?)</td>|'
+        '<td colspan="(\d)+">(.+?)</td>|'
+        "<td>(.*?)</td>"
+    )
+    td_iter = re.finditer(td_pattern, thead_part)
+    td_list = [t.group() for t in td_iter]
+
+    # 2. is multiply <b></b> in <td></td> or not?
+    new_td_list = []
+    for td_item in td_list:
+        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
+            # multiply <b></b> in <td></td> case.
+            # 1. remove all <b></b>
+            td_item = td_item.replace("<b>", "").replace("</b>", "")
+            # 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
+            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
+            new_td_list.append(td_item)
+        else:
+            new_td_list.append(td_item)
+
+    # 3. replace original thead part.
+    for td_item, new_td_item in zip(td_list, new_td_list):
+        thead_part = thead_part.replace(td_item, new_td_item)
+    return thead_part
+
+
+def distance(box_1, box_2):
+    """
+    compute the distance between two boxes
+
+    Args:
+        box_1 (list): first rectangle box,eg.(x1, y1, x2, y2)
+        box_2 (list): second rectangle box,eg.(x1, y1, x2, y2)
+
+    Returns:
+        int: the distance between two boxes
+
+    """
+    x1, y1, x2, y2 = box_1
+    x3, y3, x4, y4 = box_2
+    dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+    dis_2 = abs(x3 - x1) + abs(y3 - y1)
+    dis_3 = abs(x4 - x2) + abs(y4 - y2)
+    return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+    """
+    computing IoU
+    Args:
+        rec1 (list): (x1, y1, x2, y2)
+        rec2 (list): (x1, y1, x2, y2)
+    Returns:
+        float: Intersection over Union
+    """
+    # computing area of each rectangles
+    S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+    S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+    # computing the sum_area
+    sum_area = S_rec1 + S_rec2
+
+    # find the each edge of intersect rectangle
+    left_line = max(rec1[0], rec2[0])
+    right_line = min(rec1[2], rec2[2])
+    top_line = max(rec1[1], rec2[1])
+    bottom_line = min(rec1[3], rec2[3])
+
+    # judge if there is an intersect
+    if left_line >= right_line or top_line >= bottom_line:
+        return 0.0
+    else:
+        intersect = (right_line - left_line) * (bottom_line - top_line)
+        return (intersect / (sum_area - intersect)) * 1.0
+
+
+def convert_4point2rect(bbox):
+    """
+    Convert 4 point coordinate to rectangle coordinate
+    Args:
+        bbox (list): list of 4 points, eg. [x1, y1, x2, y2,...] or [[x1,y1],[x2,y2],...]
+    """
+    if isinstance(bbox, list):
+        bbox = np.array(bbox)
+    if bbox.shape[0] == 8:
+        bbox = np.reshape(bbox, (4, 2))
+    x1 = min(bbox[:, 0])
+    y1 = min(bbox[:, 1])
+    x2 = max(bbox[:, 0])
+    y2 = max(bbox[:, 1])
+    return [x1, y1, x2, y2]
+
+
+def get_ori_coordinate_for_table(x, y, table_bbox):
+    """
+    get the original coordinate from Cropped image to Original image.
+    Args:
+        x (int): x coordinate of cropped image
+        y (int): y coordinate of cropped image
+        table_bbox (list): list of table bounding boxes, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    Returns:
+        list: list of original coordinates, eg. [[x1, y1, x2, y2, x3, y3, x4, y4]]
+    """
+    bbox_list = []
+    for x1, y1, x2, y2, x3, y3, x4, y4 in table_bbox:
+        x1 = x + x1
+        y1 = y + y1
+        x2 = x + x2
+        y2 = y + y2
+        x3 = x + x3
+        y3 = y + y3
+        x4 = x + x4
+        y4 = y + y4
+        bbox_list.append([x1, y1, x2, y2, x3, y3, x4, y4])
+    return bbox_list
+
+
+def is_inside(target_box, text_box):
+    """
+    check if text box is inside target box
+    Args:
+        target_box (list): target box where we want to detect, eg. [x1, y1, x2, y2]
+        text_box (list): text box, eg. [x1, y1, x2, y2]
+    Returns:
+        bool: True if text box is inside target box
+    """
+
+    x1_1, y1_1, x2_1, y2_1 = target_box
+    x1_2, y1_2, x2_2, y2_2 = text_box
+
+    inter_x1 = max(x1_1, x1_2)
+    inter_y1 = max(y1_1, y1_2)
+    inter_x2 = min(x2_1, x2_2)
+    inter_y2 = min(y2_1, y2_2)
+
+    if inter_x1 < inter_x2 and inter_y1 < inter_y2:
+        inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
+    else:
+        inter_area = 0
+
+    area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
+    area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
+
+    union_area = area1 + area2 - inter_area
+
+    iou = inter_area / union_area if union_area != 0 else 0
+    return iou > 0
+
+
+class TableMatch(object):
+    """
+    match table html and ocr res
+    """
+
+    def __init__(self, filter_ocr_result=False):
+        self.filter_ocr_result = filter_ocr_result
+
+    def __call__(self, table_pred, ocr_pred):
+        structures = table_pred["structure"]
+        table_boxes = table_pred["bbox"]
+        ocr_dt_ploys = ocr_pred["dt_polys"]
+        ocr_text_res = ocr_pred["rec_text"]
+        if self.filter_ocr_result:
+            ocr_dt_ploys, ocr_text_res = self._filter_ocr_result(
+                table_boxes, ocr_dt_ploys, ocr_text_res
+            )
+        matched_index = self.metch_table_and_ocr(table_boxes, ocr_dt_ploys)
+        pred_html = self.get_html_result(matched_index, ocr_text_res, structures)
+        return pred_html
+
+    def metch_table_and_ocr(self, table_boxes, ocr_boxes):
+        """
+        match table bo
+
+        Args:
+            table_boxes (list): bbox for table, 4 points, [x1,y1,x2,y2,x3,y3,x4,y4]
+            ocr_boxes (list): bbox for ocr, 4 points, [[x1,y1],[x2,y2],[x3,y3],[x4,y4]]
+
+        Returns:
+            dict: matched dict, key is table index, value is ocr index
+        """
+        matched = {}
+        for i, ocr_box in enumerate(np.array(ocr_boxes)):
+            ocr_box = convert_4point2rect(ocr_box)
+            distances = []
+            for j, table_box in enumerate(table_boxes):
+                table_box = convert_4point2rect(table_box)
+                distances.append(
+                    (
+                        distance(table_box, ocr_box),
+                        1.0 - compute_iou(table_box, ocr_box),
+                    )
+                )  # compute iou and l1 distance
+            sorted_distances = distances.copy()
+            # select det box by iou and l1 distance
+            sorted_distances = sorted(
+                sorted_distances, key=lambda item: (item[1], item[0])
+            )
+            if distances.index(sorted_distances[0]) not in matched.keys():
+                matched[distances.index(sorted_distances[0])] = [i]
+            else:
+                matched[distances.index(sorted_distances[0])].append(i)
+        return matched
+
+    def get_html_result(self, matched_index, ocr_contents, pred_structures):
+        pred_html = []
+        td_index = 0
+        head_structure = pred_structures[0:3]
+        html = "".join(head_structure)
+        table_structure = pred_structures[3]
+        for tag in table_structure:
+            if "</td>" in tag:
+                if "<td></td>" == tag:
+                    pred_html.extend("<td>")
+                if td_index in matched_index.keys():
+                    b_with = False
+                    if (
+                        "<b>" in ocr_contents[matched_index[td_index][0]]
+                        and len(matched_index[td_index]) > 1
+                    ):
+                        b_with = True
+                        pred_html.extend("<b>")
+                    for i, td_index_index in enumerate(matched_index[td_index]):
+                        content = ocr_contents[td_index_index]
+                        if len(matched_index[td_index]) > 1:
+                            if len(content) == 0:
+                                continue
+                            if content[0] == " ":
+                                content = content[1:]
+                            if "<b>" in content:
+                                content = content[3:]
+                            if "</b>" in content:
+                                content = content[:-4]
+                            if len(content) == 0:
+                                continue
+                            if (
+                                i != len(matched_index[td_index]) - 1
+                                and " " != content[-1]
+                            ):
+                                content += " "
+                        pred_html.extend(content)
+                    if b_with:
+                        pred_html.extend("</b>")
+                if "<td></td>" == tag:
+                    pred_html.append("</td>")
+                else:
+                    pred_html.append(tag)
+                td_index += 1
+            else:
+                pred_html.append(tag)
+        html += "".join(pred_html)
+        end_structure = pred_structures[-3:]
+        html += "".join(end_structure)
+        return html
+
+    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
+        y1 = pred_bboxes[:, 1::2].min()
+        new_dt_boxes = []
+        new_rec_res = []
+
+        for box, rec in zip(dt_boxes, rec_res):
+            if np.max(box[1::2]) < y1:
+                continue
+            new_dt_boxes.append(box)
+            new_rec_res.append(rec)
+        return new_dt_boxes, new_rec_res

+ 2 - 2
paddlex/inference/predictors/table_recognition.py

@@ -19,11 +19,11 @@ from ...utils.func_register import FuncRegister
 from ...modules.table_recognition.model_list import MODELS
 from ..components import *
 from ..results import TableRecResult
-from .base import BasePredictor
+from .base import BasicPredictor
 from ..utils.process_hook import batchable_method
 
 
-class TablePredictor(BasePredictor):
+class TablePredictor(BasicPredictor):
     """table recognition predictor"""
 
     entities = MODELS

+ 1 - 1
paddlex/inference/results/__init__.py

@@ -16,7 +16,7 @@ from .base import BaseResult
 from .topk import TopkResult
 from .text_det import TextDetResult
 from .text_rec import TextRecResult
-from .table_rec import TableRecResult
+from .table_rec import TableRecResult, StructureTableResult, TableResult
 from .ocr import OCRResult
 from .det import DetResult
 from .seg import SegResult

+ 5 - 13
paddlex/inference/results/base.py

@@ -18,16 +18,10 @@ import numpy as np
 import json
 
 from ...utils import logging
+import numpy as np
 from ..utils.io import JsonWriter, ImageReader, ImageWriter
 
 
-class NumpyEncoder(json.JSONEncoder):
-    def default(self, obj):
-        if isinstance(obj, np.ndarray):
-            return obj.tolist()
-        return super(NumpyEncoder, self).default(obj)
-
-
 class BaseResult(dict):
     def __init__(self, data):
         super().__init__(data)
@@ -39,13 +33,13 @@ class BaseResult(dict):
     def save_to_json(self, save_path, indent=4, ensure_ascii=False):
         if not save_path.endswith(".json"):
             save_path = Path(save_path) / f"{Path(self['img_path']).stem}.json"
-        self._json_writer.write(
-            save_path, self, indent=4, ensure_ascii=False, cls=NumpyEncoder
-        )
+        self._json_writer.write(save_path, self, indent=4, ensure_ascii=False)
 
     def save_to_img(self, save_path):
         if not save_path.lower().endswith((".jpg", ".png")):
             save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
+        else:
+            save_path = Path(save_path)
         res_img = self._get_res_img()
         if res_img is not None:
             self._img_writer.write(save_path.as_posix(), res_img)
@@ -54,9 +48,7 @@ class BaseResult(dict):
     def print(self, json_format=True, indent=4, ensure_ascii=False):
         str_ = self
         if json_format:
-            str_ = json.dumps(
-                str_, indent=indent, ensure_ascii=ensure_ascii, cls=NumpyEncoder
-            )
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
         logging.info(str_)
 
     def _check_res(self):

+ 0 - 3
paddlex/inference/results/det.py

@@ -44,8 +44,6 @@ def draw_box(img, np_boxes, labels):
     clsid2color = {}
     catid2fontcolor = {}
     color_list = get_colormap(rgb=True)
-    expect_boxes = np_boxes[:, 0] > -1
-    np_boxes = np_boxes[expect_boxes, :]
 
     for i, dt in enumerate(np_boxes):
         clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
@@ -100,6 +98,5 @@ class DetResult(BaseResult):
 
         image = self._img_reader.read(img_path)
         image = draw_box(image, boxes, labels=labels)
-        self["boxes"] = boxes.tolist()
 
         return image

+ 1 - 1
paddlex/inference/results/ocr.py

@@ -38,7 +38,7 @@ class OCRResult(BaseResult):
     ):
         """draw ocr result"""
         boxes = self["dt_polys"]
-        txts = (self["rec_text"],)
+        txts = self["rec_text"]
         scores = self["rec_score"]
         img = self._img_reader.read(self["img_path"])
         image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

+ 81 - 9
paddlex/inference/results/table_rec.py

@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-import json
-import numpy as np
 import cv2
+import numpy as np
+from pathlib import Path
 
-from ...utils import logging
-from ..utils.io import JsonWriter, ImageWriter, ImageReader
 from .base import BaseResult
+from ...utils import logging
+from ..utils.io import HtmlWriter, XlsxWriter
 
 
 class TableRecResult(BaseResult):
@@ -55,12 +54,85 @@ class TableRecResult(BaseResult):
         return image
 
 
-class StructureResult(BaseResult):
-    """StructureResult"""
+class StructureTableResult(TableRecResult):
+    """StructureTableResult"""
 
     def __init__(self, data):
+        """__init__"""
         super().__init__(data)
         self._img_writer.set_backend("pillow")
+        self._html_writer = HtmlWriter()
+        self._xlsx_writer = XlsxWriter()
 
-    def _get_res_img(self):
-        return self._img_reader.read(self["img_path"])
+    def save_to_html(self, save_path):
+        """save_to_html"""
+        img_idx = self["img_idx"]
+        if not save_path.endswith(".html"):
+            if img_idx > 0:
+                save_path = (
+                    Path(save_path) / f"{Path(self['img_path']).stem}_{img_idx}.html"
+                )
+            else:
+                save_path = Path(save_path) / f"{Path(self['img_path']).stem}.html"
+        elif img_idx > 0:
+            save_path = Path(save_path).stem / f"_{img_idx}.html"
+        self._html_writer.write(save_path.as_posix(), self["html"])
+        logging.info(f"The result has been saved in {save_path}.")
+
+    def save_to_excel(self, save_path):
+        """save_to_excel"""
+        img_idx = self["img_idx"]
+        if not save_path.endswith(".xlsx"):
+            if img_idx > 0:
+                save_path = (
+                    Path(save_path) / f"{Path(self['img_path']).stem}_{img_idx}.xlsx"
+                )
+            else:
+                save_path = Path(save_path) / f"{Path(self['img_path']).stem}.xlsx"
+        elif img_idx > 0:
+            save_path = Path(save_path).stem / f"_{img_idx}.xlsx"
+        self._xlsx_writer.write(save_path.as_posix(), self["html"])
+        logging.info(f"The result has been saved in {save_path}.")
+
+    def save_to_img(self, save_path):
+        img_idx = self["img_idx"]
+        if not save_path.endswith((".jpg", ".png")):
+            if img_idx > 0:
+                save_path = (
+                    Path(save_path) / f"{Path(self['img_path']).stem}_{img_idx}.jpg"
+                )
+            else:
+                save_path = Path(save_path) / f"{Path(self['img_path']).stem}.jpg"
+        elif img_idx > 0:
+            save_path = Path(save_path).stem / f"_{img_idx}.jpg"
+        else:
+            save_path = Path(save_path)
+        res_img = self._get_res_img()
+        if res_img is not None:
+            self._img_writer.write(save_path.as_posix(), res_img)
+            logging.info(f"The result has been saved in {save_path}.")
+
+
+class TableResult(BaseResult):
+    """TableResult"""
+
+    def __init__(self, data):
+        """__init__"""
+        super().__init__(data)
+
+    def save_to_img(self, save_path):
+        if not save_path.lower().endswith((".jpg", ".png")):
+            img_path = self["img_path"]
+            save_path = Path(save_path) / f"{Path(img_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        layout_save_path = f"{save_path}_layout.jpg"
+        ocr_save_path = f"{save_path}_ocr.jpg"
+        table_save_path = f"{save_path}_table.jpg"
+        layout_result = self["layout_result"]
+        layout_result.save_to_img(layout_save_path)
+        ocr_result = self["ocr_result"]
+        ocr_result.save_to_img(ocr_save_path)
+        for batch_table_result in self["table_result"]:
+            for table_result in batch_table_result:
+                table_result.save_to_img(table_save_path)

+ 8 - 1
paddlex/inference/utils/io/__init__.py

@@ -14,4 +14,11 @@
 
 
 from .readers import ImageReader, VideoReader, ReaderType
-from .writers import ImageWriter, TextWriter, JsonWriter, WriterType
+from .writers import (
+    ImageWriter,
+    TextWriter,
+    JsonWriter,
+    WriterType,
+    HtmlWriter,
+    XlsxWriter,
+)

+ 374 - 0
paddlex/inference/utils/io/style.py

@@ -0,0 +1,374 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from openpyxl.cell import cell
+from openpyxl.styles import (
+    Font,
+    Alignment,
+    PatternFill,
+    NamedStyle,
+    Border,
+    Side,
+    Color,
+)
+from openpyxl.styles.fills import FILL_SOLID
+from openpyxl.styles.numbers import FORMAT_CURRENCY_USD_SIMPLE, FORMAT_PERCENTAGE
+from openpyxl.styles.colors import BLACK
+
+FORMAT_DATE_MMDDYYYY = "mm/dd/yyyy"
+
+
+def colormap(color):
+    """
+    Convenience for looking up known colors
+    """
+    cmap = {"black": BLACK}
+    return cmap.get(color, color)
+
+
+def style_string_to_dict(style):
+    """
+    Convert css style string to a python dictionary
+    """
+
+    def clean_split(string, delim):
+        """
+        Clean up a string by removing all spaces and splitting on delim.
+        """
+        return (s.strip() for s in string.split(delim))
+
+    styles = [clean_split(s, ":") for s in style.split(";") if ":" in s]
+    return dict(styles)
+
+
+def get_side(style, name):
+    """
+    get side
+    """
+    return {
+        "border_style": style.get("border-{}-style".format(name)),
+        "color": colormap(style.get("border-{}-color".format(name))),
+    }
+
+
+known_styles = {}
+
+
+def style_dict_to_named_style(style_dict, number_format=None):
+    """
+    Change css style (stored in a python dictionary) to openpyxl NamedStyle
+    """
+
+    style_and_format_string = str(
+        {
+            "style_dict": style_dict,
+            "parent": style_dict.parent,
+            "number_format": number_format,
+        }
+    )
+
+    if style_and_format_string not in known_styles:
+        # Font
+        font = Font(
+            bold=style_dict.get("font-weight") == "bold",
+            color=style_dict.get_color("color", None),
+            size=style_dict.get("font-size"),
+        )
+
+        # Alignment
+        alignment = Alignment(
+            horizontal=style_dict.get("text-align", "general"),
+            vertical=style_dict.get("vertical-align"),
+            wrap_text=style_dict.get("white-space", "nowrap") == "normal",
+        )
+
+        # Fill
+        bg_color = style_dict.get_color("background-color")
+        fg_color = style_dict.get_color("foreground-color", Color())
+        fill_type = style_dict.get("fill-type")
+        if bg_color and bg_color != "transparent":
+            fill = PatternFill(
+                fill_type=fill_type or FILL_SOLID,
+                start_color=bg_color,
+                end_color=fg_color,
+            )
+        else:
+            fill = PatternFill()
+
+        # Border
+        border = Border(
+            left=Side(**get_side(style_dict, "left")),
+            right=Side(**get_side(style_dict, "right")),
+            top=Side(**get_side(style_dict, "top")),
+            bottom=Side(**get_side(style_dict, "bottom")),
+            diagonal=Side(**get_side(style_dict, "diagonal")),
+            diagonal_direction=None,
+            outline=Side(**get_side(style_dict, "outline")),
+            vertical=None,
+            horizontal=None,
+        )
+
+        name = "Style {}".format(len(known_styles) + 1)
+
+        pyxl_style = NamedStyle(
+            name=name,
+            font=font,
+            fill=fill,
+            alignment=alignment,
+            border=border,
+            number_format=number_format,
+        )
+
+        known_styles[style_and_format_string] = pyxl_style
+
+    return known_styles[style_and_format_string]
+
+
+class StyleDict(dict):
+    """
+    It's like a dictionary, but it looks for items in the parent dictionary
+    """
+
+    def __init__(self, *args, **kwargs):
+        self.parent = kwargs.pop("parent", None)
+        super(StyleDict, self).__init__(*args, **kwargs)
+
+    def __getitem__(self, item):
+        if item in self:
+            return super(StyleDict, self).__getitem__(item)
+        elif self.parent:
+            return self.parent[item]
+        else:
+            raise KeyError("{} not found".format(item))
+
+    def __hash__(self):
+        return hash(tuple([(k, self.get(k)) for k in self._keys()]))
+
+    # Yielding the keys avoids creating unnecessary data structures
+    # and happily works with both python2 and python3 where the
+    # .keys() method is a dictionary_view in python3 and a list in python2.
+    def _keys(self):
+        yielded = set()
+        for k in self.keys():
+            yielded.add(k)
+            yield k
+        if self.parent:
+            for k in self.parent._keys():
+                if k not in yielded:
+                    yielded.add(k)
+                    yield k
+
+    def get(self, k, d=None):
+        try:
+            return self[k]
+        except KeyError:
+            return d
+
+    def get_color(self, k, d=None):
+        """
+        Strip leading # off colors if necessary
+        """
+        color = self.get(k, d)
+        if hasattr(color, "startswith") and color.startswith("#"):
+            color = color[1:]
+            if (
+                len(color) == 3
+            ):  # Premailers reduces colors like #00ff00 to #0f0, openpyxl doesn't like that
+                color = "".join(2 * c for c in color)
+        return color
+
+
+class Element(object):
+    """
+    Our base class for representing an html element along with a cascading style.
+    The element is created along with a parent so that the StyleDict that we store
+    can point to the parent's StyleDict.
+    """
+
+    def __init__(self, element, parent=None):
+        self.element = element
+        self.number_format = None
+        parent_style = parent.style_dict if parent else None
+        self.style_dict = StyleDict(
+            style_string_to_dict(element.get("style", "")), parent=parent_style
+        )
+        self._style_cache = None
+
+    def style(self):
+        """
+        Turn the css styles for this element into an openpyxl NamedStyle.
+        """
+        if not self._style_cache:
+            self._style_cache = style_dict_to_named_style(
+                self.style_dict, number_format=self.number_format
+            )
+        return self._style_cache
+
+    def get_dimension(self, dimension_key):
+        """
+        Extracts the dimension from the style dict of the Element and returns it as a float.
+        """
+        dimension = self.style_dict.get(dimension_key)
+        if dimension:
+            if dimension[-2:] in ["px", "em", "pt", "in", "cm"]:
+                dimension = dimension[:-2]
+            dimension = float(dimension)
+        return dimension
+
+
+class Table(Element):
+    """
+    The concrete implementations of Elements are semantically named for the types of elements we are interested in.
+    This defines a very concrete tree structure for html tables that we expect to deal with. I prefer this compared to
+    allowing Element to have an arbitrary number of children and dealing with an abstract element tree.
+    """
+
+    def __init__(self, table):
+        """
+        takes an html table object (from lxml)
+        """
+        super(Table, self).__init__(table)
+        table_head = table.find("thead")
+        self.head = (
+            TableHead(table_head, parent=self) if table_head is not None else None
+        )
+        table_body = table.find("tbody")
+        self.body = TableBody(
+            table_body if table_body is not None else table, parent=self
+        )
+
+
+class TableHead(Element):
+    """
+    This class maps to the `<th>` element of the html table.
+    """
+
+    def __init__(self, head, parent=None):
+        super(TableHead, self).__init__(head, parent=parent)
+        self.rows = [TableRow(tr, parent=self) for tr in head.findall("tr")]
+
+
+class TableBody(Element):
+    """
+    This class maps to the `<tbody>` element of the html table.
+    """
+
+    def __init__(self, body, parent=None):
+        super(TableBody, self).__init__(body, parent=parent)
+        self.rows = [TableRow(tr, parent=self) for tr in body.findall("tr")]
+
+
+class TableRow(Element):
+    """
+    This class maps to the `<tr>` element of the html table.
+    """
+
+    def __init__(self, tr, parent=None):
+        super(TableRow, self).__init__(tr, parent=parent)
+        self.cells = [
+            TableCell(cell, parent=self) for cell in tr.findall("th") + tr.findall("td")
+        ]
+
+
+def element_to_string(el):
+    """
+    element to string
+    """
+    return _element_to_string(el).strip()
+
+
+def _element_to_string(el):
+    """
+    element to string
+    """
+    string = ""
+
+    for x in el.iterchildren():
+        string += "\n" + _element_to_string(x)
+
+    text = el.text.strip() if el.text else ""
+    tail = el.tail.strip() if el.tail else ""
+
+    return text + string + "\n" + tail
+
+
+class TableCell(Element):
+    """
+    This class maps to the `<td>` element of the html table.
+    """
+
+    CELL_TYPES = {
+        "TYPE_STRING",
+        "TYPE_FORMULA",
+        "TYPE_NUMERIC",
+        "TYPE_BOOL",
+        "TYPE_CURRENCY",
+        "TYPE_PERCENTAGE",
+        "TYPE_NULL",
+        "TYPE_INLINE",
+        "TYPE_ERROR",
+        "TYPE_FORMULA_CACHE_STRING",
+        "TYPE_INTEGER",
+    }
+
+    def __init__(self, cell, parent=None):
+        super(TableCell, self).__init__(cell, parent=parent)
+        self.value = element_to_string(cell)
+        self.number_format = self.get_number_format()
+
+    def data_type(self):
+        """
+        get data type
+        """
+        cell_types = self.CELL_TYPES & set(self.element.get("class", "").split())
+        if cell_types:
+            if "TYPE_FORMULA" in cell_types:
+                # Make sure TYPE_FORMULA takes precedence over the other classes in the set.
+                cell_type = "TYPE_FORMULA"
+            elif cell_types & {"TYPE_CURRENCY", "TYPE_INTEGER", "TYPE_PERCENTAGE"}:
+                cell_type = "TYPE_NUMERIC"
+            else:
+                cell_type = cell_types.pop()
+        else:
+            cell_type = "TYPE_STRING"
+        return getattr(cell, cell_type)
+
+    def get_number_format(self):
+        """
+        get number format
+        """
+        if "TYPE_CURRENCY" in self.element.get("class", "").split():
+            return FORMAT_CURRENCY_USD_SIMPLE
+        if "TYPE_INTEGER" in self.element.get("class", "").split():
+            return "#,##0"
+        if "TYPE_PERCENTAGE" in self.element.get("class", "").split():
+            return FORMAT_PERCENTAGE
+        if "TYPE_DATE" in self.element.get("class", "").split():
+            return FORMAT_DATE_MMDDYYYY
+        if self.data_type() == cell.TYPE_NUMERIC:
+            try:
+                int(self.value)
+            except ValueError:
+                return "#,##0.##"
+            else:
+                return "#,##0"
+
+    def format(self, cell):
+        """
+        format
+        """
+        cell.style = self.style()
+        data_type = self.data_type()
+        if data_type:
+            cell.data_type = data_type

+ 149 - 0
paddlex/inference/utils/io/tablepyxl.py

@@ -0,0 +1,149 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+
+from lxml import html
+from openpyxl import Workbook
+from openpyxl.utils import get_column_letter
+from premailer import Premailer
+from .style import Table
+
+
+def string_to_int(s):
+    """
+    Convert a string to an integer
+    """
+    if s.isdigit():
+        return int(s)
+    return 0
+
+
+def get_Tables(doc):
+    """
+    Find all the tables in the doc
+    """
+    tree = html.fromstring(doc)
+    comments = tree.xpath("//comment()")
+    for comment in comments:
+        comment.drop_tag()
+    return [Table(table) for table in tree.xpath("//table")]
+
+
+def write_rows(worksheet, elem, row, column=1):
+    """
+    Writes every tr child element of elem to a row in the worksheet
+    returns the next row after all rows are written
+    """
+    from openpyxl.cell.cell import MergedCell
+
+    initial_column = column
+    for table_row in elem.rows:
+        for table_cell in table_row.cells:
+            cell = worksheet.cell(row=row, column=column)
+            while isinstance(cell, MergedCell):
+                column += 1
+                cell = worksheet.cell(row=row, column=column)
+
+            colspan = string_to_int(table_cell.element.get("colspan", "1"))
+            rowspan = string_to_int(table_cell.element.get("rowspan", "1"))
+            if rowspan > 1 or colspan > 1:
+                worksheet.merge_cells(
+                    start_row=row,
+                    start_column=column,
+                    end_row=row + rowspan - 1,
+                    end_column=column + colspan - 1,
+                )
+
+            cell.value = table_cell.value
+            table_cell.format(cell)
+            min_width = table_cell.get_dimension("min-width")
+            max_width = table_cell.get_dimension("max-width")
+
+            if colspan == 1:
+                # Initially, when iterating for the first time through the loop, the width of all the cells is None.
+                # As we start filling in contents, the initial width of the cell (which can be retrieved by:
+                # worksheet.column_dimensions[get_column_letter(column)].width) is equal to the width of the previous
+                # cell in the same column (i.e. width of A2 = width of A1)
+                width = max(
+                    worksheet.column_dimensions[get_column_letter(column)].width or 0,
+                    len(table_cell.value) + 2,
+                )
+                if max_width and width > max_width:
+                    width = max_width
+                elif min_width and width < min_width:
+                    width = min_width
+                worksheet.column_dimensions[get_column_letter(column)].width = width
+            column += colspan
+        row += 1
+        column = initial_column
+    return row
+
+
+def table_to_sheet(table, wb):
+    """
+    Takes a table and workbook and writes the table to a new sheet.
+    The sheet title will be the same as the table attribute name.
+    """
+    ws = wb.create_sheet(title=table.element.get("name"))
+    insert_table(table, ws, 1, 1)
+
+
+def document_to_workbook(doc, wb=None, base_url=None):
+    """
+    Takes a string representation of an html document and writes one sheet for
+    every table in the document.
+    The workbook is returned
+    """
+    if not wb:
+        wb = Workbook()
+        wb.remove(wb.active)
+
+    inline_styles_doc = Premailer(
+        doc, base_url=base_url, remove_classes=False
+    ).transform()
+    tables = get_Tables(inline_styles_doc)
+
+    for table in tables:
+        table_to_sheet(table, wb)
+
+    return wb
+
+
+def document_to_xl(doc, filename, base_url=None):
+    """
+    Takes a string representation of an html document and writes one sheet for
+    every table in the document. The workbook is written out to a file called filename
+    """
+    wb = document_to_workbook(doc, base_url=base_url)
+    wb.save(filename)
+
+
+def insert_table(table, worksheet, column, row):
+    """
+    Inserts a table into the worksheet at the specified column and row
+    """
+    if table.head:
+        row = write_rows(worksheet, table.head, row, column)
+    if table.body:
+        row = write_rows(worksheet, table.body, row, column)
+
+
+def insert_table_at_cell(table, cell):
+    """
+    Inserts a table at the location of an openpyxl Cell object.
+    """
+    ws = cell.parent
+    column, row = cell.column, cell.row
+    insert_table(table, ws, column, row)

+ 64 - 1
paddlex/inference/utils/io/writers.py

@@ -21,8 +21,16 @@ from pathlib import Path
 import cv2
 import numpy as np
 from PIL import Image
+from .tablepyxl import document_to_xl
 
-__all__ = ["ImageWriter", "TextWriter", "JsonWriter", "WriterType"]
+__all__ = [
+    "ImageWriter",
+    "TextWriter",
+    "JsonWriter",
+    "WriterType",
+    "HtmlWriter",
+    "XlsxWriter",
+]
 
 
 class WriterType(enum.Enum):
@@ -32,6 +40,8 @@ class WriterType(enum.Enum):
     VIDEO = 2
     TEXT = 3
     JSON = 4
+    HTML = 5
+    XLSX = 6
 
 
 class _BaseWriter(object):
@@ -139,6 +149,42 @@ class JsonWriter(_BaseWriter):
         return WriterType.JSON
 
 
+class HtmlWriter(_BaseWriter):
+    def __init__(self, backend="html", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj, **bk_args):
+        return self._backend.write_obj(out_path, obj, **bk_args)
+
+    def _init_backend(self, bk_type, bk_args):
+        if bk_type == "html":
+            return HtmlWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.HTML
+
+
+class XlsxWriter(_BaseWriter):
+    def __init__(self, backend="xlsx", **bk_args):
+        super().__init__(backend=backend, **bk_args)
+
+    def write(self, out_path, obj, **bk_args):
+        return self._backend.write_obj(out_path, obj, **bk_args)
+
+    def _init_backend(self, bk_type, bk_args):
+        if bk_type == "xlsx":
+            return XlsxWriterBackend(**bk_args)
+        else:
+            raise ValueError("Unsupported backend type")
+
+    def get_type(self):
+        """get type"""
+        return WriterType.XLSX
+
+
 class _BaseWriterBackend(object):
     """_BaseWriterBackend"""
 
@@ -167,6 +213,23 @@ class TextWriterBackend(_BaseWriterBackend):
             f.write(obj)
 
 
+class HtmlWriterBackend(_BaseWriterBackend):
+
+    def __init__(self, mode="w", encoding="utf-8"):
+        super().__init__()
+        self.mode = mode
+        self.encoding = encoding
+
+    def _write_obj(self, out_path, obj, **bk_args):
+        with open(out_path, mode=self.mode, encoding=self.encoding) as f:
+            f.write(obj)
+
+
+class XlsxWriterBackend(_BaseWriterBackend):
+    def _write_obj(self, out_path, obj, **bk_args):
+        document_to_xl(obj, out_path)
+
+
 class _ImageWriterBackend(_BaseWriterBackend):
     """_ImageWriterBackend"""