
new table structure recognition models

Liu Jiaxuan · 10 months ago
commit 459a8e6d6a
22 changed files with 1647 additions and 1 deletion
  1. paddlex/configs/modules/table_cells_detection/RT-DETR-L_wired_table_cell_det.yaml (+40, -0)
  2. paddlex/configs/modules/table_cells_detection/RT-DETR-L_wireless_table_cell_det.yaml (+40, -0)
  3. paddlex/configs/modules/table_classification/PP-LCNet_x1_0_table_cls.yaml (+41, -0)
  4. paddlex/configs/modules/table_structure_recognition/SLANeXt_wired.yaml (+39, -0)
  5. paddlex/configs/modules/table_structure_recognition/SLANeXt_wireless.yaml (+39, -0)
  6. paddlex/inference/models_new/__init__.py (+1, -0)
  7. paddlex/inference/models_new/table_structure_recognition/__init__.py (+15, -0)
  8. paddlex/inference/models_new/table_structure_recognition/predictor.py (+170, -0)
  9. paddlex/inference/models_new/table_structure_recognition/processors.py (+240, -0)
  10. paddlex/inference/models_new/table_structure_recognition/result.py (+108, -0)
  11. paddlex/modules/image_classification/model_list.py (+1, -0)
  12. paddlex/modules/object_detection/model_list.py (+2, -0)
  13. paddlex/modules/table_recognition/model_list.py (+2, -0)
  14. paddlex/repo_apis/PaddleClas_api/cls/register.py (+10, -0)
  15. paddlex/repo_apis/PaddleClas_api/configs/PP-LCNet_x1_0_table_cls.yaml (+142, -0)
  16. paddlex/repo_apis/PaddleDetection_api/configs/RT-DETR-L_wired_table_cell_det.yaml (+173, -0)
  17. paddlex/repo_apis/PaddleDetection_api/configs/RT-DETR-L_wireless_table_cell_det.yaml (+173, -0)
  18. paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py (+2, -0)
  19. paddlex/repo_apis/PaddleDetection_api/object_det/register.py (+33, -1)
  20. paddlex/repo_apis/PaddleOCR_api/configs/SLANeXt_wired.yaml (+179, -0)
  21. paddlex/repo_apis/PaddleOCR_api/configs/SLANeXt_wireless.yaml (+179, -0)
  22. paddlex/repo_apis/PaddleOCR_api/table_rec/register.py (+18, -0)

+ 40 - 0
paddlex/configs/modules/table_cells_detection/RT-DETR-L_wired_table_cell_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: RT-DETR-L_wired_table_cell_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/cells_det/cells_det_coco_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 1
+  epochs_iters: 40
+  batch_size: 2
+  learning_rate: 0.0001
+  pretrain_weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/RT-DETR-L_wired_table_cell_det_pretrained.pdparams"
+  warmup_steps: 100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+
+Export:
+  weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/RT-DETR-L_wired_table_cell_det_pretrained.pdparams"
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg"
+  kernel_option:
+    run_mode: paddle

+ 40 - 0
paddlex/configs/modules/table_cells_detection/RT-DETR-L_wireless_table_cell_det.yaml

@@ -0,0 +1,40 @@
+Global:
+  model: RT-DETR-L_wireless_table_cell_det
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/cells_det/cells_det_coco_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 1
+  epochs_iters: 40
+  batch_size: 2
+  learning_rate: 0.0001
+  pretrain_weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/RT-DETR-L_wireless_table_cell_det_pretrained.pdparams"
+  warmup_steps: 100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+
+Export:
+  weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/RT-DETR-L_wireless_table_cell_det_pretrained.pdparams"
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg"
+  kernel_option:
+    run_mode: paddle
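
Note: the two cell-detection configs above differ only in the model name and pretrained weights. Below is a minimal sketch of running inference with one of them, assuming the standard paddlex.create_model entry point resolves the model names registered in this commit (the image filename is illustrative):

# Sketch only: create_model / predict / save_to_img are assumed to follow the
# usual PaddleX module API; the model name matches Global.model above.
from paddlex import create_model

model = create_model("RT-DETR-L_wired_table_cell_det")
for res in model.predict("table_recognition.jpg", batch_size=1):
    res.print()                   # detected cell boxes with scores
    res.save_to_img("./output/")  # visualization of the predicted cells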

+ 41 - 0
paddlex/configs/modules/table_classification/PP-LCNet_x1_0_table_cls.yaml

@@ -0,0 +1,41 @@
+Global:
+  model: PP-LCNet_x1_0_table_cls
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/cls/table_cls_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert: 
+    enable: False
+    src_dataset_type: null
+  split: 
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  num_classes: 2
+  epochs_iters: 20
+  batch_size: 128
+  learning_rate: 0.1
+  pretrain_weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-LCNet_x1_0_table_cls_pretrained.pdparams"
+  warmup_steps: 5
+  resume_path: null
+  log_interval: 1
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/PP-LCNet_x1_0_table_cls_pretrained.pdparams"
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg"
+  kernel_option:
+    run_mode: paddle
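
The classifier above is a two-class model (num_classes: 2) whose job is to tell wired tables from wireless ones so that the matching cell-detection and SLANeXt structure models can be chosen. A hedged sketch of that routing follows; the label strings and the result key are assumptions, not taken from this commit:

# Assumed labels "wired_table"/"wireless_table" and an assumed 'label_names'
# result key; only the model names are defined by this commit.
from paddlex import create_model

cls_model = create_model("PP-LCNet_x1_0_table_cls")
cls_res = next(iter(cls_model.predict("table_recognition.jpg", batch_size=1)))
label = cls_res["label_names"][0]

structure_model = create_model(
    "SLANeXt_wired" if label == "wired_table" else "SLANeXt_wireless"
)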

+ 39 - 0
paddlex/configs/modules/table_structure_recognition/SLANeXt_wired.yaml

@@ -0,0 +1,39 @@
+Global:
+  model: SLANeXt_wired
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/table_rec/table_rec_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 10
+  batch_size: 16
+  learning_rate: 0.001
+  pretrain_weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/SLANeXt_wired_pretrained.pdparams"
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/SLANeXt_wired_pretrained.pdparams"
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg"
+  kernel_option:
+    run_mode: paddle

+ 39 - 0
paddlex/configs/modules/table_structure_recognition/SLANeXt_wireless.yaml

@@ -0,0 +1,39 @@
+Global:
+  model: SLANeXt_wireless
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/table_rec/table_rec_dataset_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+
+Train:
+  epochs_iters: 10
+  batch_size: 16
+  learning_rate: 0.001
+  pretrain_weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/SLANeXt_wireless_pretrained.pdparams"
+  resume_path: null
+  log_interval: 20
+  eval_interval: 1
+  save_interval: 1
+
+Evaluate:
+  weight_path: "output/best_accuracy/best_accuracy.pdparams"
+  log_interval: 1
+
+Export:
+  weight_path: "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_pretrained_model/SLANeXt_wireless_pretrained.pdparams"
+
+Predict:
+  batch_size: 1
+  model_dir: "output/best_accuracy/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/table_recognition.jpg"
+  kernel_option:
+    run_mode: paddle
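
For the two SLANeXt module configs, a short sketch of what structure prediction returns, assuming create_model wires up the TablePredictor added later in this commit; the result keys mirror TablePredictor.process():

# "structure" holds HTML tokens, "bbox" one box per cell token, and
# "structure_score" the mean token confidence (see processors.py below).
from paddlex import create_model

model = create_model("SLANeXt_wireless")
for res in model.predict("table_recognition.jpg", batch_size=1):
    print("".join(res["structure"]))   # <html><body><table>...</table></body></html>
    print(len(res["bbox"]), "cells, score:", res["structure_score"])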

+ 1 - 0
paddlex/inference/models_new/__init__.py

@@ -24,6 +24,7 @@ from .image_classification import ClasPredictor
 from .object_detection import DetPredictor
 from .text_detection import TextDetPredictor
 from .text_recognition import TextRecPredictor
+from .table_structure_recognition import TablePredictor
 from .formula_recognition import FormulaRecPredictor
 from .instance_segmentation import InstanceSegPredictor
 from .semantic_segmentation import SegPredictor

+ 15 - 0
paddlex/inference/models_new/table_structure_recognition/__init__.py

@@ -0,0 +1,15 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .predictor import TablePredictor

+ 170 - 0
paddlex/inference/models_new/table_structure_recognition/predictor.py

@@ -0,0 +1,170 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Union, Dict, List, Tuple
+import numpy as np
+
+from ....utils.func_register import FuncRegister
+from ....modules.table_recognition.model_list import MODELS
+from ...common.batch_sampler import ImageBatchSampler
+from ...common.reader import ReadImage
+from ..common import (
+    Resize,
+    ResizeByLong,
+    Normalize,
+    ToCHWImage,
+    ToBatch,
+    StaticInfer,
+)
+from ..base import BasicPredictor
+from .processors import Pad, TableLabelDecode
+from .result import TableRecResult
+
+
+class TablePredictor(BasicPredictor):
+    entities = MODELS
+
+    _FUNC_MAP = {}
+    register = FuncRegister(_FUNC_MAP)
+
+    def __init__(self, *args: List, **kwargs: Dict) -> None:
+        super().__init__(*args, **kwargs)
+        self.preprocessors, self.infer, self.postprocessors = self._build()
+
+    def _build_batch_sampler(self) -> ImageBatchSampler:
+        return ImageBatchSampler()
+
+    def _get_result_class(self) -> type:
+        return TableRecResult
+
+    def _build(self) -> Tuple:
+        preprocessors = []
+        for cfg in self.config["PreProcess"]["transform_ops"]:
+            tf_key = list(cfg.keys())[0]
+            func = self._FUNC_MAP[tf_key]
+            args = cfg.get(tf_key, {})
+            op = func(self, **args) if args else func(self)
+            if op:
+                preprocessors.append(op)
+        preprocessors.append(ToBatch())
+
+        infer = StaticInfer(
+            model_dir=self.model_dir,
+            model_prefix=self.MODEL_FILE_PREFIX,
+            option=self.pp_option,
+        )
+
+        postprocessors = TableLabelDecode(
+            model_name="SLANet",
+            merge_no_span_structure=self.config["PreProcess"]["transform_ops"][1][
+                "TableLabelEncode"
+            ]["merge_no_span_structure"],
+            dict_character=self.config["PostProcess"]["character_dict"],
+        )
+        return preprocessors, infer, postprocessors
+
+    def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
+        """
+        Process a batch of inputs through preprocessing, inference, and postprocessing.
+
+        Args:
+            batch_data (List[Union[str, np.ndarray]]): A batch of input data (e.g., image file paths or decoded images).
+
+        Returns:
+            dict: A dictionary containing, for every instance of the batch, the input path, raw image, cell boxes, structure tokens, and structure score. Keys are 'input_path', 'input_img', 'bbox', 'structure', and 'structure_score'.
+        """
+        batch_raw_imgs = self.preprocessors[0](imgs=batch_data)  # ReadImage
+        ori_shapes = []
+        for s in range(len(batch_raw_imgs)):
+            ori_shapes.append([batch_raw_imgs[s].shape[1], batch_raw_imgs[s].shape[0]])
+        batch_imgs = self.preprocessors[1](imgs=batch_raw_imgs)  # ResizeByLong
+        batch_imgs = self.preprocessors[2](imgs=batch_imgs)  # Normalize
+        pad_results = self.preprocessors[3](imgs=batch_imgs)  # Pad
+        pad_imgs = []
+        padding_sizes = []
+        for pad_img, padding_size in pad_results:
+            pad_imgs.append(pad_img)
+            padding_sizes.append(padding_size)
+        batch_imgs = self.preprocessors[4](imgs=pad_imgs)  # ToCHWImage
+        x = self.preprocessors[5](imgs=batch_imgs)  # ToBatch
+
+        batch_preds = self.infer(x=x)
+
+        table_result = self.postprocessors(
+            pred=batch_preds,
+            img_size=padding_sizes,
+            ori_img_size=ori_shapes,
+        )
+
+        table_result_bbox = []
+        table_result_structure = []
+        table_result_structure_score = []
+        for i in range(len(table_result)):
+            table_result_bbox.append(table_result[i]["bbox"])
+            table_result_structure.append(table_result[i]["structure"])
+            table_result_structure_score.append(table_result[i]["structure_score"])
+
+        final_result = {
+            "input_path": batch_data,
+            "input_img": batch_raw_imgs,
+            "bbox": table_result_bbox,
+            "structure": table_result_structure,
+            "structure_score": table_result_structure_score,
+        }
+
+        return final_result
+
+    @register("DecodeImage")
+    def build_readimg(self, channel_first=False, img_mode="BGR"):
+        assert channel_first is False
+        assert img_mode == "BGR"
+        return ReadImage(format=img_mode)
+
+    @register("TableLabelEncode")
+    def foo(self, *args, **kwargs):
+        return None
+
+    @register("TableBoxEncode")
+    def foo(self, *args, **kwargs):
+        return None
+
+    @register("ResizeTableImage")
+    def build_resize_table(self, max_len=488, resize_bboxes=True):
+        return ResizeByLong(target_long_edge=max_len)
+
+    @register("NormalizeImage")
+    def build_normalize(
+        self,
+        mean=[0.485, 0.456, 0.406],
+        std=[0.229, 0.224, 0.225],
+        scale=1 / 255,
+        order="hwc",
+    ):
+        return Normalize(mean=mean, std=std)
+
+    @register("PaddingTableImage")
+    def build_padding(self, size=[488, 448], pad_value=0):
+        return Pad(target_size=size[0], val=pad_value)
+
+    @register("ToCHWImage")
+    def build_to_chw(self):
+        return ToCHWImage()
+
+    @register("KeepKeys")
+    def foo(self, *args, **kwargs):
+        return None
+
+    def _pack_res(self, single):
+        keys = ["input_path", "bbox", "structure"]
+        return TableRecResult({key: single[key] for key in keys})
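
In _build(), each entry of PreProcess.transform_ops from the inference config is looked up in _FUNC_MAP, which FuncRegister fills via the @register(...) decorators above; ops that only matter for training (TableLabelEncode, TableBoxEncode, KeepKeys) are registered as no-ops and dropped. Here is a standalone sketch of that dispatch with toy builders (the stand-in op objects are illustrative, not the real PaddleX classes):

# Minimal, self-contained version of the register-and-dispatch pattern.
_FUNC_MAP = {}

def register(name):
    def deco(fn):
        _FUNC_MAP[name] = fn
        return fn
    return deco

@register("ResizeTableImage")
def build_resize(max_len=488, **_):
    return ("resize_by_long_edge", max_len)   # stand-in for ResizeByLong(target_long_edge=max_len)

@register("KeepKeys")
def build_keep_keys(**_):
    return None                               # training-only op, skipped at inference

transform_ops = [{"ResizeTableImage": {"max_len": 512}}, {"KeepKeys": {"keep_keys": ["image"]}}]
pipeline = []
for cfg in transform_ops:
    op_name = next(iter(cfg))
    op = _FUNC_MAP[op_name](**(cfg[op_name] or {}))
    if op is not None:
        pipeline.append(op)
print(pipeline)  # [('resize_by_long_edge', 512)]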

+ 240 - 0
paddlex/inference/models_new/table_structure_recognition/processors.py

@@ -0,0 +1,240 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import cv2
+import numpy as np
+from numpy import ndarray
+from ..common.vision import funcs as F
+
+
+class Pad:
+    """Pad the image."""
+
+    INPUT_KEYS = "img"
+    OUTPUT_KEYS = ["img", "img_size"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {"img": "img", "img_size": "img_size"}
+
+    def __init__(self, target_size, val=127.5):
+        """
+        Initialize the instance.
+
+        Args:
+            target_size (list|tuple|int): Target width and height of the image after
+                padding.
+            val (float, optional): Value to fill the padded area. Default: 127.5.
+        """
+        super().__init__()
+
+        if isinstance(target_size, int):
+            target_size = [target_size, target_size]
+        self.target_size = target_size
+
+        self.val = val
+
+    def apply(self, img):
+        """apply"""
+        h, w = img.shape[:2]
+        tw, th = self.target_size
+        ph = th - h
+        pw = tw - w
+
+        if ph < 0 or pw < 0:
+            raise ValueError(
+                f"Input image ({w}, {h}) smaller than the target size ({tw}, {th})."
+            )
+        else:
+            img = F.pad(img, pad=(0, ph, 0, pw), val=self.val)
+
+        return [img, [img.shape[1], img.shape[0]]]
+
+    def __call__(self, imgs):
+        """apply"""
+        return [self.apply(img) for img in imgs]
+
+
+class TableLabelDecode:
+    """decode the table model outputs(probs) to character str"""
+
+    ENABLE_BATCH = True
+
+    INPUT_KEYS = ["pred", "img_size", "ori_img_size"]
+    OUTPUT_KEYS = ["bbox", "structure", "structure_score"]
+    DEAULT_INPUTS = {
+        "pred": "pred",
+        "img_size": "img_size",
+        "ori_img_size": "ori_img_size",
+    }
+    DEAULT_OUTPUTS = {
+        "bbox": "bbox",
+        "structure": "structure",
+        "structure_score": "structure_score",
+    }
+
+    def __init__(self, model_name, merge_no_span_structure=True, dict_character=[]):
+        super().__init__()
+
+        if merge_no_span_structure:
+            if "<td></td>" not in dict_character:
+                dict_character.append("<td></td>")
+            if "<td>" in dict_character:
+                dict_character.remove("<td>")
+        self.model_name = model_name
+
+        dict_character = self.add_special_char(dict_character)
+        self.dict = {}
+        for i, char in enumerate(dict_character):
+            self.dict[char] = i
+        self.character = dict_character
+        self.td_token = ["<td>", "<td", "<td></td>"]
+
+    def add_special_char(self, dict_character):
+        """add_special_char"""
+        self.beg_str = "sos"
+        self.end_str = "eos"
+        dict_character = dict_character
+        dict_character = [self.beg_str] + dict_character + [self.end_str]
+        return dict_character
+
+    def get_ignored_tokens(self):
+        """get_ignored_tokens"""
+        beg_idx = self.get_beg_end_flag_idx("beg")
+        end_idx = self.get_beg_end_flag_idx("end")
+        return [beg_idx, end_idx]
+
+    def get_beg_end_flag_idx(self, beg_or_end):
+        """get_beg_end_flag_idx"""
+        if beg_or_end == "beg":
+            idx = np.array(self.dict[self.beg_str])
+        elif beg_or_end == "end":
+            idx = np.array(self.dict[self.end_str])
+        else:
+            assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
+        return idx
+
+    def __call__(self, pred, img_size, ori_img_size):
+        """apply"""
+        bbox_preds, structure_probs = [], []
+
+        for i in range(len(pred[0][0])):
+            bbox_preds.append(pred[0][0][i])
+            structure_probs.append(pred[1][0][i])
+        bbox_preds = [bbox_preds]
+        structure_probs = [structure_probs]
+
+        bbox_preds = np.array(bbox_preds)
+        structure_probs = np.array(structure_probs)
+
+        bbox_list, structure_str_list, structure_score = self.decode(
+            structure_probs, bbox_preds, img_size, ori_img_size
+        )
+        structure_str_list = [
+            (
+                ["<html>", "<body>", "<table>"]
+                + structure
+                + ["</table>", "</body>", "</html>"]
+            )
+            for structure in structure_str_list
+        ]
+        return [
+            {"bbox": bbox, "structure": structure, "structure_score": structure_score}
+            for bbox, structure in zip(bbox_list, structure_str_list)
+        ]
+
+    def decode(self, structure_probs, bbox_preds, padding_size, ori_img_size):
+        """convert text-label into text-index."""
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_idx = structure_probs.argmax(axis=2)
+        structure_probs = structure_probs.max(axis=2)
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            score_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+                if char_idx in ignored_tokens:
+                    continue
+                text = self.character[char_idx]
+                if text in self.td_token:
+                    bbox = bbox_preds[batch_idx, idx]
+                    bbox = self._bbox_decode(
+                        bbox, padding_size[batch_idx], ori_img_size[batch_idx]
+                    )
+                    bbox_list.append(bbox.astype(int))
+                structure_list.append(text)
+                score_list.append(structure_probs[batch_idx, idx])
+            structure_batch_list.append(structure_list)
+            structure_score = np.mean(score_list)
+            bbox_batch_list.append(bbox_list)
+
+        return bbox_batch_list, structure_batch_list, structure_score
+
+    def decode_label(self, batch):
+        """convert text-label into text-index."""
+        structure_idx = batch[1]
+        gt_bbox_list = batch[2]
+        shape_list = batch[-1]
+        ignored_tokens = self.get_ignored_tokens()
+        end_idx = self.dict[self.end_str]
+
+        structure_batch_list = []
+        bbox_batch_list = []
+        batch_size = len(structure_idx)
+        for batch_idx in range(batch_size):
+            structure_list = []
+            bbox_list = []
+            for idx in range(len(structure_idx[batch_idx])):
+                char_idx = int(structure_idx[batch_idx][idx])
+                if idx > 0 and char_idx == end_idx:
+                    break
+                if char_idx in ignored_tokens:
+                    continue
+                structure_list.append(self.character[char_idx])
+
+                bbox = gt_bbox_list[batch_idx][idx]
+                if bbox.sum() != 0:
+                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+                    bbox_list.append(bbox.astype(int))
+            structure_batch_list.append(structure_list)
+            bbox_batch_list.append(bbox_list)
+        return bbox_batch_list, structure_batch_list
+
+    def _bbox_decode(self, bbox, padding_shape, ori_shape):
+
+        if self.model_name == "SLANet":
+            w, h = ori_shape
+            bbox[0::2] *= w
+            bbox[1::2] *= h
+        else:
+            w, h = padding_shape
+            ori_w, ori_h = ori_shape
+            ratio_w = w / ori_w
+            ratio_h = h / ori_h
+            ratio = min(ratio_w, ratio_h)
+
+            bbox[0::2] *= w
+            bbox[1::2] *= h
+            bbox[0::2] /= ratio
+            bbox[1::2] /= ratio
+
+        return bbox
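
As a quick numeric check of the non-SLANet branch of _bbox_decode: boxes come out of the model normalized to the padded input, are first scaled by the padded size, then divided by the resize ratio min(w/ori_w, h/ori_h) to land back in original-image pixels. The shapes below are illustrative:

# Original image 1000x750, resized with ratio 0.512 and padded to 512x512.
import numpy as np

bbox = np.array([0.25, 0.25, 0.5, 0.5], dtype=float)   # normalized x1, y1, x2, y2
pad_w = pad_h = 512
ori_w, ori_h = 1000, 750
ratio = min(pad_w / ori_w, pad_h / ori_h)               # 0.512

bbox[0::2] *= pad_w    # to padded-image pixels
bbox[1::2] *= pad_h
bbox[0::2] /= ratio    # back to original-image pixels
bbox[1::2] /= ratio
print(bbox)            # [250. 250. 500. 500.]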

+ 108 - 0
paddlex/inference/models_new/table_structure_recognition/result.py

@@ -0,0 +1,108 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+from pathlib import Path
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ...common.result import BaseResult, BaseCVResult, HtmlMixin, XlsxMixin
+
+
+class TableRecResult(BaseCVResult):
+    """SaveTableResults"""
+
+    def __init__(self, data):
+        super().__init__(data)
+
+    def _to_img(self):
+        image = self["input_img"]
+        bbox_res = self["bbox"]
+        if len(bbox_res) > 0 and len(bbox_res[0]) == 4:
+            vis_img = self.draw_rectangle(image, bbox_res)
+        else:
+            vis_img = self.draw_bbox(image, bbox_res)
+        return vis_img
+
+    def draw_rectangle(self, image, boxes):
+        """draw_rectangle"""
+        boxes = np.array(boxes)
+        img_show = image.copy()
+        for box in boxes.astype(int):
+            x1, y1, x2, y2 = box
+            cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
+        return img_show
+
+    def draw_bbox(self, image, boxes):
+        """draw_bbox"""
+        for box in boxes:
+            box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
+            image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
+        return image
+
+
+class StructureTableResult(TableRecResult, HtmlMixin, XlsxMixin):
+    """StructureTableResult"""
+
+    def __init__(self, data):
+        super().__init__(data)
+        HtmlMixin.__init__(self)
+        XlsxMixin.__init__(self)
+
+    def _to_html(self):
+        return self["html"]
+
+
+class TableResult(BaseCVResult, HtmlMixin, XlsxMixin):
+    """TableResult"""
+
+    def __init__(self, data):
+        super().__init__(data)
+        HtmlMixin.__init__(self)
+        XlsxMixin.__init__(self)
+
+    def save_to_html(self, save_path):
+        if not save_path.lower().endswith(("html")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        for table_result in self["table_result"]:
+            table_result.save_to_html(save_path)
+
+    def save_to_xlsx(self, save_path):
+        if not save_path.lower().endswith(("xlsx")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        for table_result in self["table_result"]:
+            table_result.save_to_xlsx(save_path)
+
+    def save_to_img(self, save_path):
+        if not save_path.lower().endswith((".jpg", ".png")):
+            input_path = self["input_path"]
+            save_path = Path(save_path) / f"{Path(input_path).stem}"
+        else:
+            save_path = Path(save_path).stem
+        layout_save_path = f"{save_path}_layout.jpg"
+        ocr_save_path = f"{save_path}_ocr.jpg"
+        table_save_path = f"{save_path}_table"
+        layout_result = self["layout_result"]
+        layout_result.save_to_img(layout_save_path)
+        ocr_result = self["ocr_result"]
+        ocr_result.save_to_img(ocr_save_path)
+        for idx, table_result in enumerate(self["table_result"]):
+            table_result.save_to_img(f"{table_save_path}_{idx}.jpg")
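
The drawing logic in TableRecResult._to_img dispatches on box length: 4-value boxes are drawn as axis-aligned rectangles, anything longer as polygons. A self-contained sketch of that dispatch on a blank canvas (dummy boxes, not real predictions):

# Standalone illustration of the 4-point vs. polygon drawing split.
import cv2
import numpy as np

def draw_cells(image, boxes):
    img_show = image.copy()
    if len(boxes) > 0 and len(boxes[0]) == 4:                  # x1, y1, x2, y2 cells
        for x1, y1, x2, y2 in np.array(boxes).astype(int):
            cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
    else:                                                       # 8-point polygon cells
        for box in boxes:
            pts = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
            img_show = cv2.polylines(img_show, [pts], True, (255, 0, 0), 2)
    return img_show

canvas = np.full((200, 300, 3), 255, dtype=np.uint8)
vis = draw_cells(canvas, [[10, 10, 120, 60], [130, 10, 290, 60]])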

+ 1 - 0
paddlex/modules/image_classification/model_list.py

@@ -95,4 +95,5 @@ MODELS = [
     "FasterNet-T0",
     "FasterNet-T1",
     "FasterNet-T2",
+    "PP-LCNet_x1_0_table_cls",
 ]

+ 2 - 0
paddlex/modules/object_detection/model_list.py

@@ -72,4 +72,6 @@ MODELS = [
     "BlazeFace-FPN-SSH",
     "PP-YOLOE_plus-S_face",
     "PP-YOLOE-R_L",
+    "RT-DETR-L_wired_table_cell_det",
+    "RT-DETR-L_wireless_table_cell_det",
 ]

+ 2 - 0
paddlex/modules/table_recognition/model_list.py

@@ -16,4 +16,6 @@
 MODELS = [
     "SLANet",
     "SLANet_plus",
+    "SLANeXt_wired",
+    "SLANeXt_wireless",
 ]

+ 10 - 0
paddlex/repo_apis/PaddleClas_api/cls/register.py

@@ -896,3 +896,13 @@ register_model_info(
         "infer_config": "deploy/configs/inference_cls.yaml",
     }
 )
+
+register_model_info(
+    {
+        "model_name": "PP-LCNet_x1_0_table_cls",
+        "suite": "Cls",
+        "config_path": osp.join(PDX_CONFIG_DIR, "PP-LCNet_x1_0_table_cls.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "infer_config": "deploy/configs/inference_cls.yaml",
+    }
+)

+ 142 - 0
paddlex/repo_apis/PaddleClas_api/configs/PP-LCNet_x1_0_table_cls.yaml

@@ -0,0 +1,142 @@
+# global configs
+Global:
+  checkpoints: null
+  pretrained_model: null
+  output_dir: ./output/
+  device: gpu
+  save_interval: 1
+  eval_during_train: True
+  eval_interval: 1
+  epochs: 100
+  print_batch_step: 10
+  use_visualdl: True
+  # used for static mode and model export
+  image_shape: [3, 224, 224]
+  save_inference_dir: ./inference
+
+# mixed precision
+AMP:
+  use_amp: False
+  use_fp16_test: False
+  scale_loss: 128.0
+  use_dynamic_loss_scaling: True
+  use_promote: False
+  # O1: mixed fp16, O2: pure fp16
+  level: O1
+
+
+# model architecture
+Arch:
+  name: PPLCNet_x1_0
+  class_num: 2
+  pretrained: True
+ 
+# loss function config for training/eval process
+Loss:
+  Train:
+    - CELoss:
+        weight: 1.0
+        epsilon: 0.1
+  Eval:
+    - CELoss:
+        weight: 1.0
+
+
+Optimizer:
+  name: Momentum
+  momentum: 0.9
+  lr:
+    name: Cosine
+    learning_rate: 0.1
+    warmup_epoch: 5
+  regularizer:
+    name: 'L2'
+    coeff: 0.00003
+
+
+# data loader for train and eval
+DataLoader:
+  Train:
+    dataset:
+      name: ClsDataset
+      image_root: ./dataset/table_classification/
+      cls_label_path: ./dataset/table_classification/train_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - RandCropImage:
+            size: 224
+        - RandFlipImage:
+            flip_code: 1
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: True
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+  Eval:
+    dataset: 
+      name: ClsDataset
+      image_root: ./dataset/table_classification/
+      cls_label_path: ./dataset/table_classification/val_list.txt
+      transform_ops:
+        - DecodeImage:
+            to_rgb: True
+            channel_first: False
+        - ResizeImage:
+            resize_short: 256
+        - CropImage:
+            size: 224
+        - NormalizeImage:
+            scale: 1.0/255.0
+            mean: [0.485, 0.456, 0.406]
+            std: [0.229, 0.224, 0.225]
+            order: ''
+    sampler:
+      name: DistributedBatchSampler
+      batch_size: 64
+      drop_last: False
+      shuffle: False
+    loader:
+      num_workers: 4
+      use_shared_memory: True
+
+Infer:
+  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
+  batch_size: 10
+  transforms:
+    - DecodeImage:
+        to_rgb: True
+        channel_first: False
+    - ResizeImage:
+        resize_short: 256
+    - CropImage:
+        size: 224
+    - NormalizeImage:
+        scale: 1.0/255.0
+        mean: [0.485, 0.456, 0.406]
+        std: [0.229, 0.224, 0.225]
+        order: ''
+    - ToCHWImage:
+  PostProcess:
+    name: Topk
+    topk: 5
+    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
+
+Metric:
+  Train:
+    - TopkAcc:
+        topk: [1, 5]
+  Eval:
+    - TopkAcc:
+        topk: [1, 5]

+ 173 - 0
paddlex/repo_apis/PaddleDetection_api/configs/RT-DETR-L_wired_table_cell_det.yaml

@@ -0,0 +1,173 @@
+# Runtime
+epoch: 40
+log_iter: 10
+find_unused_parameters: true
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+use_ema: true
+ema_decay: 0.9999
+ema_decay_type: "exponential"
+ema_filter_no_grad: true
+save_dir: output
+snapshot_epoch: 1
+print_flops: false
+print_params: false
+eval_size: [640, 640]
+
+# Dataset
+metric: COCO
+num_classes: 1
+
+worker_num: 4
+
+TrainDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_train.json
+  dataset_dir: datasets/COCO
+  data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+  allow_empty: true
+
+TestDataset:
+  name: ImageFolder
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+
+TrainReader:
+  sample_transforms:
+    - Decode: {}
+    - RandomDistort: {prob: 0.8}
+    - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+    - RandomCrop: {prob: 0.8}
+    - RandomFlip: {}
+  batch_transforms:
+    - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - NormalizeBox: {}
+    - BboxXYXY2XYWH: {}
+    - Permute: {}
+  batch_size: 8
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+  use_shared_memory: true
+
+EvalReader:
+  sample_transforms:
+    - Decode: {}
+    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - Permute: {}
+  batch_size: 4
+  shuffle: false
+  drop_last: false
+
+TestReader:
+  inputs_def:
+    image_shape: [3, 640, 640]
+  sample_transforms:
+    - Decode: {}
+    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - Permute: {}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+
+# Model
+architecture: DETR
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/PPHGNetV2_L_ssld_pretrained.pdparams
+
+norm_type: sync_bn
+hidden_dim: 256
+use_focal_loss: True
+
+DETR:
+  backbone: PPHGNetV2
+  neck: HybridEncoder
+  transformer: RTDETRTransformer
+  detr_head: DINOHead
+  post_process: DETRPostProcess
+
+PPHGNetV2:
+  arch: 'L'
+  return_idx: [1, 2, 3]
+  freeze_stem_only: true
+  freeze_at: 0
+  freeze_norm: true
+  lr_mult_list: [0., 0.05, 0.05, 0.05, 0.05]
+
+HybridEncoder:
+  hidden_dim: 256
+  use_encoder_idx: [2]
+  num_encoder_layers: 1
+  encoder_layer:
+    name: TransformerLayer
+    d_model: 256
+    nhead: 8
+    dim_feedforward: 1024
+    dropout: 0.
+    activation: 'gelu'
+  expansion: 1.0
+
+RTDETRTransformer:
+  num_queries: 300
+  position_embed_type: sine
+  feat_strides: [8, 16, 32]
+  num_levels: 3
+  nhead: 8
+  num_decoder_layers: 6
+  dim_feedforward: 1024
+  dropout: 0.0
+  activation: relu
+  num_denoising: 100
+  label_noise_ratio: 0.5
+  box_noise_scale: 1.0
+  learnt_init_query: false
+
+DINOHead:
+  loss:
+    name: DINOLoss
+    loss_coeff: {class: 1, bbox: 5, giou: 2}
+    aux_loss: true
+    use_vfl: true
+    matcher:
+      name: HungarianMatcher
+      matcher_coeff: {class: 2, bbox: 5, giou: 2}
+
+DETRPostProcess:
+  num_top_queries: 300
+
+# Optimizer
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 1.0
+    milestones: [100]
+    use_warmup: true
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 100
+
+OptimizerBuilder:
+  clip_grad_by_norm: 0.1
+  regularizer: false
+  optimizer:
+    type: AdamW
+    weight_decay: 0.0001
+
+# Export
+export:
+  post_process: true
+  nms: true
+  benchmark: false
+  fuse_conv_bn: false

+ 173 - 0
paddlex/repo_apis/PaddleDetection_api/configs/RT-DETR-L_wireless_table_cell_det.yaml

@@ -0,0 +1,173 @@
+# Runtime
+epoch: 40
+log_iter: 10
+find_unused_parameters: true
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+use_ema: true
+ema_decay: 0.9999
+ema_decay_type: "exponential"
+ema_filter_no_grad: true
+save_dir: output
+snapshot_epoch: 1
+print_flops: false
+print_params: false
+eval_size: [640, 640]
+
+# Dataset
+metric: COCO
+num_classes: 1
+
+worker_num: 4
+
+TrainDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_train.json
+  dataset_dir: datasets/COCO
+  data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset:
+  name: COCODetDataset
+  image_dir: images
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+  allow_empty: true
+
+TestDataset:
+  name: ImageFolder
+  anno_path: annotations/instance_val.json
+  dataset_dir: datasets/COCO
+
+TrainReader:
+  sample_transforms:
+    - Decode: {}
+    - RandomDistort: {prob: 0.8}
+    - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+    - RandomCrop: {prob: 0.8}
+    - RandomFlip: {}
+  batch_transforms:
+    - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - NormalizeBox: {}
+    - BboxXYXY2XYWH: {}
+    - Permute: {}
+  batch_size: 8
+  shuffle: true
+  drop_last: true
+  collate_batch: false
+  use_shared_memory: true
+
+EvalReader:
+  sample_transforms:
+    - Decode: {}
+    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - Permute: {}
+  batch_size: 4
+  shuffle: false
+  drop_last: false
+
+TestReader:
+  inputs_def:
+    image_shape: [3, 640, 640]
+  sample_transforms:
+    - Decode: {}
+    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+    - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
+    - Permute: {}
+  batch_size: 1
+  shuffle: false
+  drop_last: false
+
+# Model
+architecture: DETR
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/PPHGNetV2_L_ssld_pretrained.pdparams
+
+norm_type: sync_bn
+hidden_dim: 256
+use_focal_loss: True
+
+DETR:
+  backbone: PPHGNetV2
+  neck: HybridEncoder
+  transformer: RTDETRTransformer
+  detr_head: DINOHead
+  post_process: DETRPostProcess
+
+PPHGNetV2:
+  arch: 'L'
+  return_idx: [1, 2, 3]
+  freeze_stem_only: true
+  freeze_at: 0
+  freeze_norm: true
+  lr_mult_list: [0., 0.05, 0.05, 0.05, 0.05]
+
+HybridEncoder:
+  hidden_dim: 256
+  use_encoder_idx: [2]
+  num_encoder_layers: 1
+  encoder_layer:
+    name: TransformerLayer
+    d_model: 256
+    nhead: 8
+    dim_feedforward: 1024
+    dropout: 0.
+    activation: 'gelu'
+  expansion: 1.0
+
+RTDETRTransformer:
+  num_queries: 300
+  position_embed_type: sine
+  feat_strides: [8, 16, 32]
+  num_levels: 3
+  nhead: 8
+  num_decoder_layers: 6
+  dim_feedforward: 1024
+  dropout: 0.0
+  activation: relu
+  num_denoising: 100
+  label_noise_ratio: 0.5
+  box_noise_scale: 1.0
+  learnt_init_query: false
+
+DINOHead:
+  loss:
+    name: DINOLoss
+    loss_coeff: {class: 1, bbox: 5, giou: 2}
+    aux_loss: true
+    use_vfl: true
+    matcher:
+      name: HungarianMatcher
+      matcher_coeff: {class: 2, bbox: 5, giou: 2}
+
+DETRPostProcess:
+  num_top_queries: 300
+
+# Optimizer
+LearningRate:
+  base_lr: 0.0001
+  schedulers:
+  - !PiecewiseDecay
+    gamma: 1.0
+    milestones: [100]
+    use_warmup: true
+  - !LinearWarmup
+    start_factor: 0.001
+    steps: 100
+
+OptimizerBuilder:
+  clip_grad_by_norm: 0.1
+  regularizer: false
+  optimizer:
+    type: AdamW
+    weight_decay: 0.0001
+
+# Export
+export:
+  post_process: true
+  nms: true
+  benchmark: false
+  fuse_conv_bn: false

+ 2 - 0
paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py

@@ -140,4 +140,6 @@ official_categories = {
     "BlazeFace-FPN-SSH": [{"name": "face", "id": 0}],
     "PicoDet_LCNet_x2_5_face": [{"name": "face", "id": 0}],
     "PP-YOLOE_plus-S_face": [{"name": "face", "id": 0}],
+    "RT-DETR-L_wired_table_cell_det": [{"name": "cell", "id": 0}],
+    "RT-DETR-L_wireless_table_cell_det": [{"name": "cell", "id": 0}],
 }

+ 33 - 1
paddlex/repo_apis/PaddleDetection_api/object_det/register.py

@@ -939,4 +939,36 @@ register_model_info(
             "amp": ["OFF"],
         },
     }
-)
+)
+
+register_model_info(
+    {
+        "model_name": "RT-DETR-L_wired_table_cell_det",
+        "suite": "Det",
+        "config_path": osp.join(PDX_CONFIG_DIR, "RT-DETR-L_wired_table_cell_det.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["COCODetDataset"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "RT-DETR-L_wireless_table_cell_det",
+        "suite": "Det",
+        "config_path": osp.join(
+            PDX_CONFIG_DIR, "RT-DETR-L_wireless_table_cell_det.yaml"
+        ),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["COCODetDataset"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+    }
+)

+ 179 - 0
paddlex/repo_apis/PaddleOCR_api/configs/SLANeXt_wired.yaml

@@ -0,0 +1,179 @@
+Global:
+  use_gpu: true
+  epoch_num: 400
+  log_smooth_window: 20
+  print_batch_step: 20
+  save_model_dir: ./output/SLANeXt_wired
+  save_epoch_step: 400
+  eval_batch_step:
+  - 0
+  - 331
+  cal_metric_during_train: true
+  pretrained_model: null
+  checkpoints: null
+  save_inference_dir: ./output/SLANeXt_wired/infer
+  use_visualdl: false
+  infer_img: ppstructure/docs/table/table.jpg
+  character_dict_path: ppocr/utils/dict/table_structure_dict_ch.txt
+  character_type: en
+  max_text_length: 500
+  box_format: xyxyxyxy
+  infer_mode: false
+  use_sync_bn: true
+  save_res_path: output/infer
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  clip_norm: 5.0
+  lr:
+    name: Cosine
+    learning_rate: 0.0001
+    warmup_epoch: 1
+  regularizer:
+    name: L2
+    factor: 0.0
+
+Architecture:
+  model_type: table
+  algorithm: SLANeXt
+  Backbone:
+    name: Vary_VIT_B
+    image_size: 512 
+    encoder_embed_dim: 768
+    encoder_depth: 12
+    encoder_num_heads: 12
+    encoder_global_attn_indexes: [2, 5, 8, 11]
+  Head:
+    name: SLAHead
+    hidden_size: 512
+    max_text_length: 500
+    loc_reg_num: 8
+
+Loss:
+  name: SLALoss
+  structure_weight: 1.0
+  # SLANeXt does not train the cell location task by default, set the loc_weight if needed.
+  loc_weight: 0.0
+  loc_loss: smooth_l1
+
+PostProcess:
+  name: TableLabelDecode
+  merge_no_span_structure: true
+
+Metric:
+  name: TableMetric
+  main_indicator: acc
+  compute_bbox_metric: false
+  loc_reg_num: 8
+  box_format: xyxyxyxy
+  del_thead_tbody: true
+
+Train:
+  dataset:
+    name: PubTabDataSet
+    data_dir: train_data/table/train/
+    label_file_list:
+    - train_data/table/train.txt
+    ratio_list:
+    - 1
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - TableLabelEncode:
+        learn_empty_box: false
+        merge_no_span_structure: true
+        replace_empty_cell_token: false
+        loc_reg_num: 8
+        max_text_length: 500
+    - TableBoxEncode:
+        in_box_format: xyxyxyxy
+        out_box_format: xyxyxyxy
+    - ResizeTableImage:
+        max_len: 512
+        resize_bboxes: true
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - PaddingTableImage:
+        size:
+        - 512
+        - 512
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - structure
+        - bboxes
+        - bbox_masks
+        - length
+        - shape
+  loader:
+    shuffle: true
+    batch_size_per_card: 48
+    drop_last: true
+    num_workers: 1
+
+Eval:
+  dataset:
+    name: PubTabDataSet
+    data_dir: train_data/table/val/
+    label_file_list:
+    - train_data/table/val.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - TableLabelEncode:
+        learn_empty_box: false
+        merge_no_span_structure: true
+        replace_empty_cell_token: false
+        loc_reg_num: 8
+        max_text_length: 500
+    - TableBoxEncode:
+        in_box_format: xyxyxyxy
+        out_box_format: xyxyxyxy
+    - ResizeTableImage:
+        max_len: 512
+        resize_bboxes: true
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - PaddingTableImage:
+        size:
+        - 512
+        - 512
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - structure
+        - bboxes
+        - bbox_masks
+        - length
+        - shape
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 48
+    num_workers: 1
+
+profiler_options: null

+ 179 - 0
paddlex/repo_apis/PaddleOCR_api/configs/SLANeXt_wireless.yaml

@@ -0,0 +1,179 @@
+Global:
+  use_gpu: true
+  epoch_num: 400
+  log_smooth_window: 20
+  print_batch_step: 20
+  save_model_dir: ./output/SLANeXt_wireless
+  save_epoch_step: 400
+  eval_batch_step:
+  - 0
+  - 331
+  cal_metric_during_train: true
+  pretrained_model: null
+  checkpoints: null
+  save_inference_dir: ./output/SLANeXt_wireless/infer
+  use_visualdl: false
+  infer_img: ppstructure/docs/table/table.jpg
+  character_dict_path: ppocr/utils/dict/table_structure_dict_ch.txt
+  character_type: en
+  max_text_length: 500
+  box_format: xyxyxyxy
+  infer_mode: false
+  use_sync_bn: true
+  save_res_path: output/infer
+
+Optimizer:
+  name: AdamW
+  beta1: 0.9
+  beta2: 0.999
+  clip_norm: 5.0
+  lr:
+    name: Cosine
+    learning_rate: 0.0001
+    warmup_epoch: 1
+  regularizer:
+    name: L2
+    factor: 0.0
+
+Architecture:
+  model_type: table
+  algorithm: SLANeXt
+  Backbone:
+    name: Vary_VIT_B
+    image_size: 512 
+    encoder_embed_dim: 768
+    encoder_depth: 12
+    encoder_num_heads: 12
+    encoder_global_attn_indexes: [2, 5, 8, 11]
+  Head:
+    name: SLAHead
+    hidden_size: 512
+    max_text_length: 500
+    loc_reg_num: 8
+
+Loss:
+  name: SLALoss
+  structure_weight: 1.0
+  # SLANeXt does not train the cell location task by default, set the loc_weight if needed.
+  loc_weight: 0.0
+  loc_loss: smooth_l1
+
+PostProcess:
+  name: TableLabelDecode
+  merge_no_span_structure: true
+
+Metric:
+  name: TableMetric
+  main_indicator: acc
+  compute_bbox_metric: false
+  loc_reg_num: 8
+  box_format: xyxyxyxy
+  del_thead_tbody: true
+
+Train:
+  dataset:
+    name: PubTabDataSet
+    data_dir: train_data/table/train/
+    label_file_list:
+    - train_data/table/train.txt
+    ratio_list:
+    - 1
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - TableLabelEncode:
+        learn_empty_box: false
+        merge_no_span_structure: true
+        replace_empty_cell_token: false
+        loc_reg_num: 8
+        max_text_length: 500
+    - TableBoxEncode:
+        in_box_format: xyxyxyxy
+        out_box_format: xyxyxyxy
+    - ResizeTableImage:
+        max_len: 512
+        resize_bboxes: true
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - PaddingTableImage:
+        size:
+        - 512
+        - 512
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - structure
+        - bboxes
+        - bbox_masks
+        - length
+        - shape
+  loader:
+    shuffle: true
+    batch_size_per_card: 48
+    drop_last: true
+    num_workers: 1
+
+Eval:
+  dataset:
+    name: PubTabDataSet
+    data_dir: train_data/table/val/
+    label_file_list:
+    - train_data/table/val.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - TableLabelEncode:
+        learn_empty_box: false
+        merge_no_span_structure: true
+        replace_empty_cell_token: false
+        loc_reg_num: 8
+        max_text_length: 500
+    - TableBoxEncode:
+        in_box_format: xyxyxyxy
+        out_box_format: xyxyxyxy
+    - ResizeTableImage:
+        max_len: 512
+        resize_bboxes: true
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - PaddingTableImage:
+        size:
+        - 512
+        - 512
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - structure
+        - bboxes
+        - bbox_masks
+        - length
+        - shape
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 48
+    num_workers: 1
+
+profiler_options: null

+ 18 - 0
paddlex/repo_apis/PaddleOCR_api/table_rec/register.py

@@ -51,3 +51,21 @@ register_model_info(
         "supported_apis": ["train", "evaluate", "predict", "export"],
     }
 )
+
+register_model_info(
+    {
+        "model_name": "SLANeXt_wired",
+        "suite": "TableRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "SLANeXt_wired.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "SLANeXt_wireless",
+        "suite": "TableRec",
+        "config_path": osp.join(PDX_CONFIG_DIR, "SLANeXt_wireless.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export"],
+    }
+)