Browse Source

update table recognition pipeline

zhouchangda 1 year ago
parent
commit
7f910eb762

+ 32 - 23
paddlex/inference/components/task_related/det.py

@@ -19,24 +19,35 @@ from ...utils.io import ImageReader
 from ..base import BaseComponent
 
 
-def restructured_boxes(boxes, labels):
-    return [
-        {
-            "cls_id": int(box[0]),
-            "label": labels[int(box[0])],
-            "score": float(box[1]),
-            "coordinate": list(map(int, box[2:])),
-        }
-        for box in boxes
-    ]
+def restructured_boxes(boxes, labels, img_size):
+    """Convert detection rows to result dicts, clamping both box corners
+    into the image so no negative coordinate survives to be misread as a
+    from-the-end index when the box is later used to slice the image.
+    """
+    w, h = img_size
+    box_list = []
+    for box in boxes:
+        cls_id = int(box[0])
+        xmin, ymin, xmax, ymax = map(int, box[2:])
+        xmin, xmax = (min(max(0, x), w) for x in (xmin, xmax))
+        ymin, ymax = (min(max(0, y), h) for y in (ymin, ymax))
+        box_list.append(
+            {
+                "cls_id": cls_id,
+                "label": labels[cls_id],
+                "score": float(box[1]),
+                "coordinate": [xmin, ymin, xmax, ymax],
+            }
+        )
+    return box_list
 
 
 class DetPostProcess(BaseComponent):
     """Save Result Transform"""
 
-    INPUT_KEYS = ["img_path", "boxes"]
+    INPUT_KEYS = ["img_path", "boxes", "img_size"]
     OUTPUT_KEYS = ["boxes"]
-    DEAULT_INPUTS = {"boxes": "boxes"}
+    DEAULT_INPUTS = {"boxes": "boxes", "img_size": "ori_img_size"}
     DEAULT_OUTPUTS = {"boxes": "boxes"}
 
     def __init__(self, threshold=0.5, labels=None):
@@ -44,11 +55,11 @@ class DetPostProcess(BaseComponent):
         self.threshold = threshold
         self.labels = labels
 
-    def apply(self, boxes):
+    def apply(self, boxes, img_size):
         """apply"""
         expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
         boxes = boxes[expect_boxes, :]
-        boxes = restructured_boxes(boxes, self.labels)
+        boxes = restructured_boxes(boxes, self.labels, img_size)
         result = {"boxes": boxes}
 
         return result
@@ -57,25 +68,23 @@ class DetPostProcess(BaseComponent):
 class CropByBoxes(BaseComponent):
     """Crop Image by Box"""
 
-    INPUT_KEYS = ["img_path", "boxes", "labels"]
+    YIELD_BATCH = False
+    INPUT_KEYS = ["img_path", "boxes"]
     OUTPUT_KEYS = ["img", "box", "label"]
-    DEAULT_INPUTS = {"img_path": "img_path", "boxes": "boxes", "labels": "labels"}
+    DEAULT_INPUTS = {"img_path": "img_path", "boxes": "boxes"}
     DEAULT_OUTPUTS = {"img": "img", "box": "box", "label": "label"}
 
     def __init__(self):
         super().__init__()
         self._reader = ImageReader(backend="opencv")
 
-    def apply(self, img_path, boxes, labels=None):
+    def apply(self, img_path, boxes):
         output_list = []
         img = self._reader.read(img_path)
         for bbox in boxes:
-            label_id = int(bbox[0])
-            box = bbox[2:]
-            if labels is not None:
-                label = labels[label_id]
-            else:
-                label = label_id
+            label_id = bbox["cls_id"]
+            box = bbox["coordinate"]
+            label = bbox.get("label", label_id)
             xmin, ymin, xmax, ymax = [int(i) for i in box]
             img_crop = img[ymin:ymax, xmin:xmax]
             output_list.append({"img": img_crop, "box": box, "label": label})

+ 26 - 11
paddlex/inference/components/task_related/table_rec.py

@@ -24,9 +24,13 @@ class TableLabelDecode(BaseComponent):
 
     ENABLE_BATCH = True
 
-    INPUT_KEYS = ["pred", "ori_img_size"]
+    INPUT_KEYS = ["pred", "img_size", "ori_img_size"]
     OUTPUT_KEYS = ["bbox", "structure", "structure_score"]
-    DEAULT_INPUTS = {"pred": "pred", "ori_img_size": "ori_img_size"}
+    DEAULT_INPUTS = {
+        "pred": "pred",
+        "img_size": "img_size",
+        "ori_img_size": "ori_img_size",
+    }
     DEAULT_OUTPUTS = {
         "bbox": "bbox",
         "structure": "structure",
@@ -73,7 +77,7 @@ class TableLabelDecode(BaseComponent):
             assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
         return idx
 
-    def apply(self, pred, ori_img_size):
+    def apply(self, pred, img_size, ori_img_size):
         """apply"""
         bbox_preds, structure_probs = [], []
         for bbox_pred, stru_prob in pred:
@@ -83,7 +87,7 @@ class TableLabelDecode(BaseComponent):
         structure_probs = np.array(structure_probs)
 
         bbox_list, structure_str_list, structure_score = self.decode(
-            structure_probs, bbox_preds, ori_img_size
+            structure_probs, bbox_preds, img_size, ori_img_size
         )
         structure_str_list = [
             (
@@ -98,7 +102,7 @@ class TableLabelDecode(BaseComponent):
             for bbox, structure in zip(bbox_list, structure_str_list)
         ]
 
-    def decode(self, structure_probs, bbox_preds, shape_list):
+    def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
         """convert text-label into text-index."""
         ignored_tokens = self.get_ignored_tokens()
         end_idx = self.dict[self.end_str]
@@ -122,11 +126,13 @@ class TableLabelDecode(BaseComponent):
                 text = self.character[char_idx]
                 if text in self.td_token:
                     bbox = bbox_preds[batch_idx, idx]
-                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox = self._bbox_decode(
+                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
+                    )
                     bbox_list.append(bbox.tolist())
                 structure_list.append(text)
                 score_list.append(structure_probs[batch_idx, idx])
-            structure_batch_list.append([structure_list])
+            structure_batch_list.append(structure_list)
             structure_score = np.mean(score_list)
             bbox_batch_list.append(bbox_list)
 
@@ -162,8 +168,17 @@ class TableLabelDecode(BaseComponent):
             bbox_batch_list.append(bbox_list)
         return bbox_batch_list, structure_batch_list
 
-    def _bbox_decode(self, bbox, shape):
-        w, h = shape[:2]
-        bbox[0::2] *= w
-        bbox[1::2] *= h
+    def _bbox_decode(self, bbox, padding_shape, ori_shape):
+        """Rescale bbox in place and return it: even-indexed entries are
+        multiplied by pad_w and odd-indexed ones by pad_h, then both are
+        divided by the smaller of pad_w / w and pad_h / h.
+        """
+        (pad_w, pad_h), (w, h) = padding_shape, ori_shape
+        scale = min(pad_w / w, pad_h / h)
+        # NOTE(review): presumably undoes a keep-ratio resize plus padding.
+        bbox[0::2] *= pad_w
+        bbox[0::2] /= scale
+        bbox[1::2] *= pad_h
+        bbox[1::2] /= scale
+        # The same array object that was passed in is handed back.
         return bbox

+ 64 - 83
paddlex/inference/pipelines/table_recognition/table_recognition.py

@@ -33,72 +33,60 @@ class TableRecPipeline(BasePipeline):
         table_model,
         batch_size=1,
         device="gpu",
-        chat_ocr=False,
         predictor_kwargs=None,
     ):
         super().__init__(predictor_kwargs)
 
-        self.layout_predictor = self._create_model(
-            model=layout_model, device=device, batch_size=batch_size
-        )
-
+        self.layout_predictor = self._create_model(model=layout_model)
         self.ocr_pipeline = OCRPipeline(
             text_det_model,
             text_rec_model,
-            rec_batch_size=batch_size,
-            rec_device=device,
-            det_device=device,
             predictor_kwargs=predictor_kwargs,
         )
-        self.table_predictor = self._create_model(
-            model=table_model, device=device, batch_size=batch_size
-        )
+        self.table_predictor = self._create_model(model=table_model)
         self._crop_by_boxes = CropByBoxes()
         self._match = TableMatch(filter_ocr_result=False)
-        self.chat_ocr = chat_ocr
+        self.set_predictor(batch_size=batch_size, device=device)
+
+    def set_predictor(self, batch_size, device):
+        self.layout_predictor.set_predict(device=device, batch_size=batch_size)
+        self.ocr_pipeline.det_model.set_predict(device=device)
+        self.ocr_pipeline.rec_model.set_predict(device=device, batch_size=batch_size)
+        self.table_predictor.set_predict(device=device, batch_size=batch_size)
 
     def predict(self, x):
-        batch_structure_res = []
-        for batch_layout_pred, batch_ocr_pred in zip(
+        """Yield a TableResult for each (layout, OCR) prediction pair of x."""
+        for layout_pred, ocr_pred in zip(
             self.layout_predictor(x), self.ocr_pipeline(x)
         ):
-            for layout_pred, ocr_pred in zip(batch_layout_pred, batch_ocr_pred):
-                single_img_res = {
-                    "img_path": "",
-                    "layout_result": {},
-                    "ocr_result": {},
-                    "table_result": [],
-                }
-                layout_res = layout_pred["result"]
-                # update layout result
-                single_img_res["img_path"] = layout_res["img_path"]
-                single_img_res["layout_result"] = layout_res
-                ocr_res = ocr_pred["result"]
-                all_subs_of_img = list(self._crop_by_boxes(layout_res))
-                # get cropped images with label 'table'
-                table_subs = []
-                for batch_subs in all_subs_of_img:
-                    table_sub_list = []
-                    for sub in batch_subs:
-                        box = sub["box"]
-                        if sub["label"].lower() == "table":
-                            table_sub_list.append(sub)
-                            _, ocr_res = self.get_ocr_result_by_bbox(box, ocr_res)
-                    table_subs.append(table_sub_list)
-                table_res, all_table_ocr_res = self.get_table_result(table_subs)
-                for batch_table_ocr_res in all_table_ocr_res:
-                    for table_ocr_res in batch_table_ocr_res:
-                        ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
-                        ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
-                        ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
+            single_img_res = {
+                "img_path": layout_pred["img_path"],
+                "layout_result": layout_pred,
+                "ocr_result": {},
+                "table_result": [],
+            }
+            # Seed with the page-level OCR result so ocr_res is bound even
+            # when no table is found; each table then filters what is left.
+            ocr_res = ocr_pred
+            table_subs = []
+            for sub in self._crop_by_boxes(layout_pred):
+                if sub["label"].lower() != "table":
+                    continue
+                table_subs.append(sub)
+                # Chain on ocr_res (not ocr_pred) to keep earlier filtering.
+                _, ocr_res = self.get_related_ocr_result(sub["box"], ocr_res)
+            table_res, all_table_ocr_res = self.get_table_result(table_subs)
+            for table_ocr_res in all_table_ocr_res:
+                ocr_res["dt_polys"].extend(table_ocr_res["dt_polys"])
+                ocr_res["rec_text"].extend(table_ocr_res["rec_text"])
+                ocr_res["rec_score"].extend(table_ocr_res["rec_score"])
 
-                single_img_res["table_result"] = table_res
-                single_img_res["ocr_result"] = OCRResult(ocr_res)
+            single_img_res["table_result"] = table_res
+            single_img_res["ocr_result"] = OCRResult(ocr_res)
 
-                batch_structure_res.append({"result": TableResult(single_img_res)})
-        yield batch_structure_res
+            yield TableResult(single_img_res)
 
-    def get_ocr_result_by_bbox(self, box, ocr_res):
+    def get_related_ocr_result(self, box, ocr_res):
         dt_polys_list = []
         rec_text_list = []
         score_list = []
@@ -116,44 +104,37 @@ class TableRecPipeline(BasePipeline):
                 unmatched_ocr_res["rec_score"].append(ocr_res["rec_score"][i])
         return (dt_polys_list, rec_text_list, score_list), unmatched_ocr_res
 
-    def get_table_result(self, input_img):
+    def get_table_result(self, input_imgs):
         table_res_list = []
         ocr_res_list = []
         table_index = 0
-        for batch_input, batch_table_pred, batch_ocr_pred in zip(
-            input_img, self.table_predictor(input_img), self.ocr_pipeline(input_img)
+        img_list = [img["img"] for img in input_imgs]
+        for input_img, table_pred, ocr_pred in zip(
+            input_imgs, self.table_predictor(img_list), self.ocr_pipeline(img_list)
         ):
-            batch_table_res = []
-            batch_ocr_res = []
-            for input, table_pred, ocr_pred in zip(
-                batch_input, batch_table_pred, batch_ocr_pred
-            ):
-                single_table_res = table_pred["result"]
-                ocr_res = ocr_pred["result"]
-                single_table_box = single_table_res["bbox"]
-                ori_x, ori_y, _, _ = input["box"]
-                ori_bbox_list = np.array(
-                    get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
-                    dtype=np.float32,
-                )
-                ori_ocr_bbox_list = np.array(
-                    get_ori_coordinate_for_table(ori_x, ori_y, ocr_res["dt_polys"]),
-                    dtype=np.float32,
-                )
-                ocr_res["dt_polys"] = ori_ocr_bbox_list
-                html_res = self._match(single_table_res, ocr_res)
-                batch_table_res.append(
-                    StructureTableResult(
-                        {
-                            "img_path": input["img_path"],
-                            "bbox": ori_bbox_list,
-                            "img_idx": table_index,
-                            "html": html_res,
-                        }
-                    )
+            single_table_box = table_pred["bbox"]
+            ori_x, ori_y, _, _ = input_img["box"]
+            ori_bbox_list = np.array(
+                get_ori_coordinate_for_table(ori_x, ori_y, single_table_box),
+                dtype=np.float32,
+            )
+            ori_ocr_bbox_list = np.array(
+                get_ori_coordinate_for_table(ori_x, ori_y, ocr_pred["dt_polys"]),
+                dtype=np.float32,
+            )
+            html_res = self._match(table_pred, ocr_pred)
+            ocr_pred["dt_polys"] = ori_ocr_bbox_list
+            table_res_list.append(
+                StructureTableResult(
+                    {
+                        "img_path": input_img["img_path"],
+                        "layout_bbox": [int(x) for x in input_img["box"]],
+                        "bbox": ori_bbox_list,
+                        "img_idx": table_index,
+                        "html": html_res,
+                    }
                 )
-                batch_ocr_res.append(ocr_res)
-                table_index += 1
-            table_res_list.append(batch_table_res)
-            ocr_res_list.append(batch_ocr_res)
+            )
+            ocr_res_list.append(ocr_pred)
+            table_index += 1
         return table_res_list, ocr_res_list

+ 1 - 1
paddlex/inference/pipelines/table_recognition/utils.py

@@ -396,7 +396,7 @@ class TableMatch(object):
         td_index = 0
         head_structure = pred_structures[0:3]
         html = "".join(head_structure)
-        table_structure = pred_structures[3]
+        table_structure = pred_structures[3:-3]
         for tag in table_structure:
             if "</td>" in tag:
                 if "<td></td>" == tag:

+ 2 - 1
paddlex/inference/results/base.py

@@ -17,6 +17,7 @@ from pathlib import Path
 import numpy as np
 import json
 
+from copy import deepcopy
 from ...utils import logging
 from ..utils.io import JsonWriter, ImageReader, ImageWriter
 
@@ -36,7 +37,7 @@ def format_data(obj):
 
 class BaseResult(dict):
     def __init__(self, data):
-        super().__init__(format_data(data))
+        super().__init__(format_data(deepcopy(data)))
         self._check_res()
         self._json_writer = JsonWriter()
         self._img_reader = ImageReader(backend="opencv")

+ 2 - 3
paddlex/inference/results/table_rec.py

@@ -133,6 +133,5 @@ class TableResult(BaseResult):
         layout_result.save_to_img(layout_save_path)
         ocr_result = self["ocr_result"]
         ocr_result.save_to_img(ocr_save_path)
-        for batch_table_result in self["table_result"]:
-            for table_result in batch_table_result:
-                table_result.save_to_img(table_save_path)
+        for table_result in self["table_result"]:
+            table_result.save_to_img(table_save_path)

+ 10 - 2
paddlex/inference/utils/io/__init__.py

@@ -13,5 +13,13 @@
 # limitations under the License.
 
 
-from .readers import ReaderType, ImageReader, VideoReader, TSReader
-from .writers import WriterType, ImageWriter, TextWriter, JsonWriter, TSWriter, HtmlWriter, XlsxWriter
+from .readers import ReaderType, ImageReader, VideoReader, TSReader, PDFReader
+from .writers import (
+    WriterType,
+    ImageWriter,
+    TextWriter,
+    JsonWriter,
+    TSWriter,
+    HtmlWriter,
+    XlsxWriter,
+)

+ 39 - 1
paddlex/inference/utils/io/readers.py

@@ -16,10 +16,12 @@
 import enum
 import itertools
 import cv2
+import fitz
 from PIL import Image, ImageOps
 import pandas as pd
+import numpy as np
 
-__all__ = ["ReaderType", "ImageReader", "VideoReader", "TSReader"]
+__all__ = ["ReaderType", "ImageReader", "VideoReader", "TSReader", "PDFReader"]
 
 
 class ReaderType(enum.Enum):
@@ -30,6 +32,7 @@ class ReaderType(enum.Enum):
     POINT_CLOUD = 3
     JSON = 4
     TS = 5
+    PDF = 6
 
 
 class _BaseReader(object):
@@ -71,6 +74,22 @@ class _BaseReader(object):
         return {}
 
 
+class PDFReader(_BaseReader):
+    """Reader whose read() delegates to the backend's read_file()."""
+
+    def __init__(self, backend="fitz", **bk_args):
+        super().__init__(backend, **bk_args)
+
+    def get_type(self):
+        return ReaderType.PDF
+
+    def _init_backend(self, bk_type, bk_args):
+        return PDFReaderBackend(**bk_args)
+
+    def read(self, in_path):
+        return self._backend.read_file(in_path)
+
+
 class ImageReader(_BaseReader):
     """ImageReader"""
 
@@ -180,6 +199,25 @@ class PILImageReaderBackend(_ImageReaderBackend):
         return ImageOps.exif_transpose(Image.open(in_path))
 
 
+class PDFReaderBackend(_BaseReaderBackend):
+    """Render PDF pages with fitz and decode them into cv2 images."""
+
+    def __init__(self, rotate=0, zoom_x=2.0, zoom_y=2.0):
+        super().__init__()
+        self.mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
+
+    def read_file(self, in_path):
+        """Return a list with one decoded image per page of in_path."""
+        images = []
+        with fitz.open(in_path) as doc:  # release the file handle when done
+            for page in doc:
+                pix = page.get_pixmap(matrix=self.mat, alpha=False)
+                # PNG bytes -> uint8 buffer -> image array decoded by OpenCV
+                png = np.frombuffer(pix.tobytes(output="png"), dtype=np.uint8)
+                images.append(cv2.imdecode(png, cv2.IMREAD_ANYCOLOR))
+        return images
+
+
 class _VideoReaderBackend(_BaseReaderBackend):
     """_VideoReaderBackend"""
 

+ 11 - 1
paddlex/inference/utils/io/writers.py

@@ -25,7 +25,15 @@ import pandas as pd
 from .tablepyxl import document_to_xl
 
 
-__all__ = ["WriterType", "ImageWriter", "TextWriter", "JsonWriter", "TSWriter", "HtmlWriter", "XlsxWriter"]
+__all__ = [
+    "WriterType",
+    "ImageWriter",
+    "TextWriter",
+    "JsonWriter",
+    "TSWriter",
+    "HtmlWriter",
+    "XlsxWriter",
+]
 
 
 class WriterType(enum.Enum):
@@ -261,6 +269,8 @@ class PILImageWriterBackend(_ImageWriterBackend):
             img = Image.fromarray(obj)
         else:
             raise TypeError("Unsupported object type")
+        if len(img.getbands()) == 4:
+            self.format = "PNG"
         return img.save(out_path, format=self.format)