pipeline.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional, Union
  15. import numpy as np
  16. from ....utils import logging
  17. from ....utils.deps import pipeline_requires_extra
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ...utils.hpi import HPIConfig
  21. from ...utils.pp_option import PaddlePredictorOption
  22. from .._parallel import AutoParallelImageSimpleInferencePipeline
  23. from ..base import BasePipeline
  24. from ..components import rotate_image
  25. from .result import DocPreprocessorResult
  26. class _DocPreprocessorPipeline(BasePipeline):
  27. """Doc Preprocessor Pipeline"""
  28. def __init__(
  29. self,
  30. config: Dict,
  31. device: Optional[str] = None,
  32. pp_option: Optional[PaddlePredictorOption] = None,
  33. use_hpip: bool = False,
  34. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  35. ) -> None:
  36. """Initializes the doc preprocessor pipeline.
  37. Args:
  38. config (Dict): Configuration dictionary containing various settings.
  39. device (str, optional): Device to run the predictions on. Defaults to None.
  40. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  41. use_hpip (bool, optional): Whether to use the high-performance
  42. inference plugin (HPIP) by default. Defaults to False.
  43. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  44. The default high-performance inference configuration dictionary.
  45. Defaults to None.
  46. """
  47. super().__init__(
  48. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
  49. )
  50. self.use_doc_orientation_classify = config.get(
  51. "use_doc_orientation_classify", True
  52. )
  53. if self.use_doc_orientation_classify:
  54. doc_ori_classify_config = config.get("SubModules", {}).get(
  55. "DocOrientationClassify",
  56. {"model_config_error": "config error for doc_ori_classify_model!"},
  57. )
  58. self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
  59. self.use_doc_unwarping = config.get("use_doc_unwarping", True)
  60. if self.use_doc_unwarping:
  61. doc_unwarping_config = config.get("SubModules", {}).get(
  62. "DocUnwarping",
  63. {"model_config_error": "config error for doc_unwarping_model!"},
  64. )
  65. self.doc_unwarping_model = self.create_model(doc_unwarping_config)
  66. self.batch_sampler = ImageBatchSampler(batch_size=config.get("batch_size", 1))
  67. self.img_reader = ReadImage(format="BGR")
  68. def check_model_settings_valid(self, model_settings: Dict) -> bool:
  69. """
  70. Check if the the input params for model settings are valid based on the initialized models.
  71. Args:
  72. model_settings (Dict): A dictionary containing model settings.
  73. Returns:
  74. bool: True if all required models are initialized according to the model settings, False otherwise.
  75. """
  76. if (
  77. model_settings["use_doc_orientation_classify"]
  78. and not self.use_doc_orientation_classify
  79. ):
  80. logging.error(
  81. "Set use_doc_orientation_classify, but the model for doc orientation classify is not initialized."
  82. )
  83. return False
  84. if model_settings["use_doc_unwarping"] and not self.use_doc_unwarping:
  85. logging.error(
  86. "Set use_doc_unwarping, but the model for doc unwarping is not initialized."
  87. )
  88. return False
  89. return True
  90. def get_model_settings(
  91. self, use_doc_orientation_classify, use_doc_unwarping
  92. ) -> dict:
  93. """
  94. Retrieve the model settings dictionary based on input parameters.
  95. Args:
  96. use_doc_orientation_classify (bool, optional): Whether to use document orientation classification.
  97. use_doc_unwarping (bool, optional): Whether to use document unwarping.
  98. Returns:
  99. dict: A dictionary containing the model settings.
  100. """
  101. if use_doc_orientation_classify is None:
  102. use_doc_orientation_classify = self.use_doc_orientation_classify
  103. if use_doc_unwarping is None:
  104. use_doc_unwarping = self.use_doc_unwarping
  105. model_settings = {
  106. "use_doc_orientation_classify": use_doc_orientation_classify,
  107. "use_doc_unwarping": use_doc_unwarping,
  108. }
  109. return model_settings
  110. def predict(
  111. self,
  112. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  113. use_doc_orientation_classify: Optional[bool] = None,
  114. use_doc_unwarping: Optional[bool] = None,
  115. ) -> DocPreprocessorResult:
  116. """
  117. Predict the preprocessing result for the input image or images.
  118. Args:
  119. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or path(s) to the images or pdfs.
  120. use_doc_orientation_classify (bool): Whether to use document orientation classification.
  121. use_doc_unwarping (bool): Whether to use document unwarping.
  122. **kwargs: Additional keyword arguments.
  123. Returns:
  124. DocPreprocessorResult: A generator yielding preprocessing results.
  125. """
  126. model_settings = self.get_model_settings(
  127. use_doc_orientation_classify, use_doc_unwarping
  128. )
  129. if not self.check_model_settings_valid(model_settings):
  130. yield {"error": "the input params for model settings are invalid!"}
  131. for _, batch_data in enumerate(self.batch_sampler(input)):
  132. image_arrays = self.img_reader(batch_data.instances)
  133. if model_settings["use_doc_orientation_classify"]:
  134. preds = list(self.doc_ori_classify_model(image_arrays))
  135. angles = []
  136. rot_imgs = []
  137. for img, pred in zip(image_arrays, preds):
  138. angle = int(pred["label_names"][0])
  139. angles.append(angle)
  140. rot_img = rotate_image(img, angle)
  141. rot_imgs.append(rot_img)
  142. else:
  143. angles = [-1 for _ in range(len(image_arrays))]
  144. rot_imgs = image_arrays
  145. if model_settings["use_doc_unwarping"]:
  146. output_imgs = [
  147. item["doctr_img"][:, :, ::-1]
  148. for item in self.doc_unwarping_model(rot_imgs)
  149. ]
  150. else:
  151. output_imgs = rot_imgs
  152. for input_path, page_index, image_array, angle, rot_img, output_img in zip(
  153. batch_data.input_paths,
  154. batch_data.page_indexes,
  155. image_arrays,
  156. angles,
  157. rot_imgs,
  158. output_imgs,
  159. ):
  160. single_img_res = {
  161. "input_path": input_path,
  162. "page_index": page_index,
  163. "input_img": image_array,
  164. "model_settings": model_settings,
  165. "angle": angle,
  166. "rot_img": rot_img,
  167. "output_img": output_img,
  168. }
  169. yield DocPreprocessorResult(single_img_res)
  170. @pipeline_requires_extra("ocr")
  171. class DocPreprocessorPipeline(AutoParallelImageSimpleInferencePipeline):
  172. entities = "doc_preprocessor"
  173. @property
  174. def _pipeline_cls(self):
  175. return _DocPreprocessorPipeline
  176. def _get_batch_size(self, config):
  177. return config.get("batch_size", 1)