predictor.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Tuple, Union
  15. import numpy as np
  16. from ....modules.table_recognition.model_list import MODELS
  17. from ....utils.func_register import FuncRegister
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..base import BasePredictor
  21. from ..common import Normalize, ResizeByLong, ToBatch, ToCHWImage
  22. from .processors import Pad, TableLabelDecode
  23. from .result import TableRecResult
  24. class TablePredictor(BasePredictor):
  25. entities = MODELS
  26. _FUNC_MAP = {}
  27. register = FuncRegister(_FUNC_MAP)
  28. def __init__(self, *args: List, **kwargs: Dict) -> None:
  29. super().__init__(*args, **kwargs)
  30. self.preprocessors, self.infer, self.postprocessors = self._build()
  31. def _build_batch_sampler(self) -> ImageBatchSampler:
  32. return ImageBatchSampler()
  33. def _get_result_class(self) -> type:
  34. return TableRecResult
  35. def _build(self) -> Tuple:
  36. preprocessors = []
  37. for cfg in self.config["PreProcess"]["transform_ops"]:
  38. tf_key = list(cfg.keys())[0]
  39. func = self._FUNC_MAP[tf_key]
  40. args = cfg.get(tf_key, {})
  41. op = func(self, **args) if args else func(self)
  42. if op:
  43. preprocessors.append(op)
  44. preprocessors.append(ToBatch())
  45. infer = self.create_static_infer()
  46. postprocessors = TableLabelDecode(
  47. model_name=self.config["Global"]["model_name"],
  48. merge_no_span_structure=self.config["PreProcess"]["transform_ops"][1][
  49. "TableLabelEncode"
  50. ]["merge_no_span_structure"],
  51. dict_character=self.config["PostProcess"]["character_dict"],
  52. )
  53. return preprocessors, infer, postprocessors
  54. def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
  55. """
  56. Process a batch of data through the preprocessing, inference, and postprocessing.
  57. Args:
  58. batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
  59. Returns:
  60. dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
  61. """
  62. batch_raw_imgs = self.preprocessors[0](imgs=batch_data.instances) # ReadImage
  63. ori_shapes = []
  64. for s in range(len(batch_raw_imgs)):
  65. ori_shapes.append([batch_raw_imgs[s].shape[1], batch_raw_imgs[s].shape[0]])
  66. batch_imgs = self.preprocessors[1](imgs=batch_raw_imgs) # ResizeByLong
  67. batch_imgs = self.preprocessors[2](imgs=batch_imgs) # Normalize
  68. pad_results = self.preprocessors[3](imgs=batch_imgs) # Pad
  69. pad_imgs = []
  70. padding_sizes = []
  71. for pad_img, padding_size in pad_results:
  72. pad_imgs.append(pad_img)
  73. padding_sizes.append(padding_size)
  74. batch_imgs = self.preprocessors[4](imgs=pad_imgs) # ToCHWImage
  75. x = self.preprocessors[5](imgs=batch_imgs) # ToBatch
  76. batch_preds = self.infer(x=x)
  77. table_result = self.postprocessors(
  78. pred=batch_preds,
  79. img_size=padding_sizes,
  80. ori_img_size=ori_shapes,
  81. )
  82. table_result_bbox = []
  83. table_result_structure = []
  84. table_result_structure_score = []
  85. for i in range(len(table_result)):
  86. table_result_bbox.append(table_result[i]["bbox"])
  87. table_result_structure.append(table_result[i]["structure"])
  88. table_result_structure_score.append(table_result[i]["structure_score"])
  89. final_result = {
  90. "input_path": batch_data.input_paths,
  91. "page_index": batch_data.page_indexes,
  92. "input_img": batch_raw_imgs,
  93. "bbox": table_result_bbox,
  94. "structure": table_result_structure,
  95. "structure_score": table_result_structure_score,
  96. }
  97. return final_result
  98. @register("DecodeImage")
  99. def build_readimg(self, channel_first=False, img_mode="BGR"):
  100. assert channel_first is False
  101. assert img_mode == "BGR"
  102. return ReadImage(format=img_mode)
  103. @register("TableLabelEncode")
  104. def foo(self, *args, **kwargs):
  105. return None
  106. @register("TableBoxEncode")
  107. def foo(self, *args, **kwargs):
  108. return None
  109. @register("ResizeTableImage")
  110. def build_resize_table(self, max_len=488, resize_bboxes=True):
  111. return ResizeByLong(target_long_edge=max_len)
  112. @register("NormalizeImage")
  113. def build_normalize(
  114. self,
  115. mean=[0.485, 0.456, 0.406],
  116. std=[0.229, 0.224, 0.225],
  117. scale=1 / 255,
  118. order="hwc",
  119. ):
  120. return Normalize(mean=mean, std=std)
  121. @register("PaddingTableImage")
  122. def build_padding(self, size=[488, 448], pad_value=0):
  123. return Pad(target_size=size[0], val=pad_value)
  124. @register("ToCHWImage")
  125. def build_to_chw(self):
  126. return ToCHWImage()
  127. @register("KeepKeys")
  128. def foo(self, *args, **kwargs):
  129. return None
  130. def _pack_res(self, single):
  131. keys = ["input_path", "bbox", "structure"]
  132. return TableRecResult({key: single[key] for key in keys})