predictor.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Union, Dict, List, Tuple
  15. import numpy as np
  16. from ....utils.func_register import FuncRegister
  17. from ....modules.table_recognition.model_list import MODELS
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..common import (
  21. Resize,
  22. ResizeByLong,
  23. Normalize,
  24. ToCHWImage,
  25. ToBatch,
  26. )
  27. from ..base import BasePredictor
  28. from .processors import Pad, TableLabelDecode
  29. from .result import TableRecResult
  class TablePredictor(BasePredictor):
      """Predictor for table structure recognition models.

      The preprocessing pipeline is assembled from the model config's
      ``PreProcess.transform_ops`` list using the builder methods registered
      in ``_FUNC_MAP``. Each batch is run through static inference and then
      decoded by ``TableLabelDecode`` into 'bbox', 'structure' and
      'structure_score' outputs.
      """

      entities = MODELS

      # Registry mapping a transform-op name from the config to its builder.
      _FUNC_MAP = {}
      register = FuncRegister(_FUNC_MAP)

      def __init__(self, *args: Any, **kwargs: Any) -> None:
          """Initialize the base predictor, then build the pipeline.

          All arguments are forwarded unchanged to ``BasePredictor``.
          """
          super().__init__(*args, **kwargs)
          self.preprocessors, self.infer, self.postprocessors = self._build()

      def _build_batch_sampler(self) -> ImageBatchSampler:
          """Return the sampler that groups inputs into image batches."""
          return ImageBatchSampler()

      def _get_result_class(self) -> type:
          """Return the class used to wrap per-instance results."""
          return TableRecResult

      def _get_transform_op_args(self, op_name: str) -> Dict:
          """Return the arguments of the first ``op_name`` entry in ``transform_ops``.

          Args:
              op_name: Transform-op name as written in the config,
                  e.g. ``"TableLabelEncode"``.

          Returns:
              The op's argument mapping, or an empty dict if it has none.

          Raises:
              KeyError: If no transform op named ``op_name`` is configured.
          """
          for cfg in self.config["PreProcess"]["transform_ops"]:
              if op_name in cfg:
                  return cfg[op_name] or {}
          raise KeyError(
              f"transform op {op_name!r} not found in "
              "config['PreProcess']['transform_ops']"
          )

      def _build(self) -> Tuple:
          """Build the preprocessing ops, the inference engine and the decoder.

          Each entry of ``PreProcess.transform_ops`` is a single-key mapping
          ``{op_name: op_args}``; ``op_name`` selects a builder from
          ``_FUNC_MAP``. Builders that return ``None`` (the TableLabelEncode,
          TableBoxEncode and KeepKeys stubs below) are skipped, and
          ``ToBatch`` is always appended last.

          Returns:
              Tuple of (list of preprocessing ops, static inference object,
              ``TableLabelDecode`` postprocessor).

          Raises:
              KeyError: If an op has no registered builder, or if the
                  ``TableLabelEncode`` op (whose ``merge_no_span_structure``
                  setting the decoder needs) is not configured.
          """
          preprocessors = []
          for cfg in self.config["PreProcess"]["transform_ops"]:
              tf_key = next(iter(cfg))
              func = self._FUNC_MAP[tf_key]
              args = cfg[tf_key]
              # Ops may be listed without arguments (a None/empty value).
              op = func(self, **args) if args else func(self)
              if op is not None:
                  preprocessors.append(op)
          preprocessors.append(ToBatch())

          infer = self.create_static_infer()

          # Look the op up by name rather than assuming it is the second entry.
          label_encode_args = self._get_transform_op_args("TableLabelEncode")
          postprocessors = TableLabelDecode(
              model_name=self.config["Global"]["model_name"],
              merge_no_span_structure=label_encode_args["merge_no_span_structure"],
              dict_character=self.config["PostProcess"]["character_dict"],
          )
          return preprocessors, infer, postprocessors

      def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]:
          """Run preprocessing, inference and postprocessing on one batch.

          Args:
              batch_data: The batch produced by the batch sampler. It is used
                  as an object exposing ``instances`` (the raw inputs, e.g.
                  image paths or arrays), ``input_paths`` and
                  ``page_indexes``.

          Returns:
              dict: Lists keyed by 'input_path', 'page_index', 'input_img'
              (the decoded images), 'bbox', 'structure' and
              'structure_score'; element ``i`` of each list belongs to the
              ``i``-th instance of the batch.
          """
          # The fixed indices below rely on the op order produced by ``_build``
          # for the expected config: ReadImage, ResizeByLong, Normalize, Pad,
          # ToCHWImage, ToBatch.
          batch_raw_imgs = self.preprocessors[0](imgs=batch_data.instances)  # ReadImage
          # [shape[1], shape[0]], i.e. [width, height] assuming HWC images.
          ori_shapes = [[img.shape[1], img.shape[0]] for img in batch_raw_imgs]

          batch_imgs = self.preprocessors[1](imgs=batch_raw_imgs)  # ResizeByLong
          batch_imgs = self.preprocessors[2](imgs=batch_imgs)  # Normalize

          # Pad yields one (padded_image, padding_size) pair per image.
          pad_imgs = []
          padding_sizes = []
          for pad_img, padding_size in self.preprocessors[3](imgs=batch_imgs):  # Pad
              pad_imgs.append(pad_img)
              padding_sizes.append(padding_size)

          batch_imgs = self.preprocessors[4](imgs=pad_imgs)  # ToCHWImage
          x = self.preprocessors[5](imgs=batch_imgs)  # ToBatch
          batch_preds = self.infer(x=x)

          table_result = self.postprocessors(
              pred=batch_preds,
              img_size=padding_sizes,
              ori_img_size=ori_shapes,
          )
          return {
              "input_path": batch_data.input_paths,
              "page_index": batch_data.page_indexes,
              "input_img": batch_raw_imgs,
              "bbox": [res["bbox"] for res in table_result],
              "structure": [res["structure"] for res in table_result],
              "structure_score": [res["structure_score"] for res in table_result],
          }

      @register("DecodeImage")
      def build_readimg(self, channel_first=False, img_mode="BGR"):
          """Build the image reader; only ``channel_first=False`` / ``"BGR"`` is supported.

          Raises:
              ValueError: If any other configuration is requested.
          """
          # Explicit checks instead of ``assert`` so they survive ``python -O``.
          if channel_first is not False:
              raise ValueError(
                  f"DecodeImage: channel_first must be False, got {channel_first!r}"
              )
          if img_mode != "BGR":
              raise ValueError(f"DecodeImage: img_mode must be 'BGR', got {img_mode!r}")
          return ReadImage(format=img_mode)

      @register("TableLabelEncode")
      def build_table_label_encode(self, *args, **kwargs):
          """Return ``None`` so ``_build`` adds no op for TableLabelEncode.

          Its ``merge_no_span_structure`` argument is still read by ``_build``
          when constructing the decoder.
          """
          return None

      @register("TableBoxEncode")
      def build_table_box_encode(self, *args, **kwargs):
          """Return ``None`` so ``_build`` adds no op for TableBoxEncode."""
          return None

      @register("ResizeTableImage")
      def build_resize_table(self, max_len=488, resize_bboxes=True):
          """Build a resize op that scales the longer image edge to ``max_len``.

          ``resize_bboxes`` is accepted for config compatibility but unused.
          """
          return ResizeByLong(target_long_edge=max_len)

      @register("NormalizeImage")
      def build_normalize(
          self,
          mean=None,
          std=None,
          scale=1 / 255,
          order="hwc",
      ):
          """Build the normalization op.

          ``mean`` / ``std`` default to ``[0.485, 0.456, 0.406]`` /
          ``[0.229, 0.224, 0.225]``. ``scale`` and ``order`` are accepted for
          config compatibility but are not forwarded, so ``Normalize``'s own
          default scale applies (NOTE(review): presumably 1/255 — confirm).
          """
          # None sentinels avoid sharing one mutable default list across calls.
          if mean is None:
              mean = [0.485, 0.456, 0.406]
          if std is None:
              std = [0.229, 0.224, 0.225]
          return Normalize(mean=mean, std=std)

      @register("PaddingTableImage")
      def build_padding(self, size=None, pad_value=0):
          """Build the padding op.

          Only ``size[0]`` is used as the pad target; ``size`` defaults to
          ``[488, 448]``.
          """
          if size is None:
              size = [488, 448]
          return Pad(target_size=size[0], val=pad_value)

      @register("ToCHWImage")
      def build_to_chw(self):
          """Build the op that transposes images to channel-first layout."""
          return ToCHWImage()

      @register("KeepKeys")
      def build_keep_keys(self, *args, **kwargs):
          """Return ``None`` so ``_build`` adds no op for KeepKeys."""
          return None

      def _pack_res(self, single):
          """Wrap one instance's outputs in a ``TableRecResult``.

          Only 'input_path', 'bbox' and 'structure' are kept.
          NOTE(review): 'page_index', 'input_img' and 'structure_score' from
          ``process`` are dropped here; confirm whether this helper is still
          used.
          """
          keys = ["input_path", "bbox", "structure"]
          return TableRecResult({key: single[key] for key in keys})