predictor.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple
  15. import numpy as np
  16. from ....utils.func_register import FuncRegister
  17. from ....modules.table_recognition.model_list import MODELS
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..common import (
  21. Resize,
  22. ResizeByLong,
  23. Normalize,
  24. ToCHWImage,
  25. ToBatch,
  26. StaticInfer,
  27. )
  28. from ..base import BasicPredictor
  29. from .processors import Pad, TableLabelDecode
  30. from .result import TableRecResult
  31. class TablePredictor(BasicPredictor):
  32. entities = MODELS
  33. _FUNC_MAP = {}
  34. register = FuncRegister(_FUNC_MAP)
  35. def __init__(self, *args: List, **kwargs: Dict) -> None:
  36. super().__init__(*args, **kwargs)
  37. self.preprocessors, self.infer, self.postprocessors = self._build()
  38. def _build_batch_sampler(self) -> ImageBatchSampler:
  39. return ImageBatchSampler()
  40. def _get_result_class(self) -> type:
  41. return TableRecResult
  42. def _build(self) -> Tuple:
  43. preprocessors = []
  44. for cfg in self.config["PreProcess"]["transform_ops"]:
  45. tf_key = list(cfg.keys())[0]
  46. func = self._FUNC_MAP[tf_key]
  47. args = cfg.get(tf_key, {})
  48. op = func(self, **args) if args else func(self)
  49. if op:
  50. preprocessors.append(op)
  51. preprocessors.append(ToBatch())
  52. infer = StaticInfer(
  53. model_dir=self.model_dir,
  54. model_prefix=self.MODEL_FILE_PREFIX,
  55. option=self.pp_option,
  56. )
  57. postprocessors = TableLabelDecode(
  58. model_name=self.config["Global"]["model_name"],
  59. merge_no_span_structure=self.config["PreProcess"]["transform_ops"][1][
  60. "TableLabelEncode"
  61. ]["merge_no_span_structure"],
  62. dict_character=self.config["PostProcess"]["character_dict"],
  63. )
  64. return preprocessors, infer, postprocessors
  65. def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
  66. """
  67. Process a batch of data through the preprocessing, inference, and postprocessing.
  68. Args:
  69. batch_data (List[Union[str, np.ndarray], ...]): A batch of input data (e.g., image file paths).
  70. Returns:
  71. dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
  72. """
  73. batch_raw_imgs = self.preprocessors[0](imgs=batch_data) # ReadImage
  74. ori_shapes = []
  75. for s in range(len(batch_raw_imgs)):
  76. ori_shapes.append([batch_raw_imgs[s].shape[1], batch_raw_imgs[s].shape[0]])
  77. batch_imgs = self.preprocessors[1](imgs=batch_raw_imgs) # ResizeByLong
  78. batch_imgs = self.preprocessors[2](imgs=batch_imgs) # Normalize
  79. pad_results = self.preprocessors[3](imgs=batch_imgs) # Pad
  80. pad_imgs = []
  81. padding_sizes = []
  82. for pad_img, padding_size in pad_results:
  83. pad_imgs.append(pad_img)
  84. padding_sizes.append(padding_size)
  85. batch_imgs = self.preprocessors[4](imgs=pad_imgs) # ToCHWImage
  86. x = self.preprocessors[5](imgs=batch_imgs) # ToBatch
  87. batch_preds = self.infer(x=x)
  88. table_result = self.postprocessors(
  89. pred=batch_preds,
  90. img_size=padding_sizes,
  91. ori_img_size=ori_shapes,
  92. )
  93. table_result_bbox = []
  94. table_result_structure = []
  95. table_result_structure_score = []
  96. for i in range(len(table_result)):
  97. table_result_bbox.append(table_result[i]["bbox"])
  98. table_result_structure.append(table_result[i]["structure"])
  99. table_result_structure_score.append(table_result[i]["structure_score"])
  100. final_result = {
  101. "input_path": batch_data,
  102. "input_img": batch_raw_imgs,
  103. "bbox": table_result_bbox,
  104. "structure": table_result_structure,
  105. "structure_score": table_result_structure_score,
  106. }
  107. return final_result
  108. @register("DecodeImage")
  109. def build_readimg(self, channel_first=False, img_mode="BGR"):
  110. assert channel_first is False
  111. assert img_mode == "BGR"
  112. return ReadImage(format=img_mode)
  113. @register("TableLabelEncode")
  114. def foo(self, *args, **kwargs):
  115. return None
  116. @register("TableBoxEncode")
  117. def foo(self, *args, **kwargs):
  118. return None
  119. @register("ResizeTableImage")
  120. def build_resize_table(self, max_len=488, resize_bboxes=True):
  121. return ResizeByLong(target_long_edge=max_len)
  122. @register("NormalizeImage")
  123. def build_normalize(
  124. self,
  125. mean=[0.485, 0.456, 0.406],
  126. std=[0.229, 0.224, 0.225],
  127. scale=1 / 255,
  128. order="hwc",
  129. ):
  130. return Normalize(mean=mean, std=std)
  131. @register("PaddingTableImage")
  132. def build_padding(self, size=[488, 448], pad_value=0):
  133. return Pad(target_size=size[0], val=pad_value)
  134. @register("ToCHWImage")
  135. def build_to_chw(self):
  136. return ToCHWImage()
  137. @register("KeepKeys")
  138. def foo(self, *args, **kwargs):
  139. return None
  140. def _pack_res(self, single):
  141. keys = ["input_path", "bbox", "structure"]
  142. return TableRecResult({key: single[key] for key in keys})