pipeline.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional, Union
  15. import numpy as np
  16. from ....utils import logging
  17. from ....utils.deps import pipeline_requires_extra
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ...utils.benchmark import benchmark
  21. from ...utils.hpi import HPIConfig
  22. from ...utils.pp_option import PaddlePredictorOption
  23. from .._parallel import AutoParallelImageSimpleInferencePipeline
  24. from ..base import BasePipeline
  25. from ..components import rotate_image
  26. from .result import DocPreprocessorResult
  27. @benchmark.time_methods
  28. class _DocPreprocessorPipeline(BasePipeline):
  29. """Doc Preprocessor Pipeline"""
  30. def __init__(
  31. self,
  32. config: Dict,
  33. device: Optional[str] = None,
  34. pp_option: Optional[PaddlePredictorOption] = None,
  35. use_hpip: bool = False,
  36. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  37. ) -> None:
  38. """Initializes the doc preprocessor pipeline.
  39. Args:
  40. config (Dict): Configuration dictionary containing various settings.
  41. device (str, optional): Device to run the predictions on. Defaults to None.
  42. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  43. use_hpip (bool, optional): Whether to use the high-performance
  44. inference plugin (HPIP) by default. Defaults to False.
  45. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  46. The default high-performance inference configuration dictionary.
  47. Defaults to None.
  48. """
  49. super().__init__(
  50. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
  51. )
  52. self.use_doc_orientation_classify = config.get(
  53. "use_doc_orientation_classify", True
  54. )
  55. if self.use_doc_orientation_classify:
  56. doc_ori_classify_config = config.get("SubModules", {}).get(
  57. "DocOrientationClassify",
  58. {"model_config_error": "config error for doc_ori_classify_model!"},
  59. )
  60. self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
  61. self.use_doc_unwarping = config.get("use_doc_unwarping", True)
  62. if self.use_doc_unwarping:
  63. doc_unwarping_config = config.get("SubModules", {}).get(
  64. "DocUnwarping",
  65. {"model_config_error": "config error for doc_unwarping_model!"},
  66. )
  67. self.doc_unwarping_model = self.create_model(doc_unwarping_config)
  68. self.batch_sampler = ImageBatchSampler(batch_size=config.get("batch_size", 1))
  69. self.img_reader = ReadImage(format="BGR")
  70. def check_model_settings_valid(self, model_settings: Dict) -> bool:
  71. """
  72. Check if the the input params for model settings are valid based on the initialized models.
  73. Args:
  74. model_settings (Dict): A dictionary containing model settings.
  75. Returns:
  76. bool: True if all required models are initialized according to the model settings, False otherwise.
  77. """
  78. if (
  79. model_settings["use_doc_orientation_classify"]
  80. and not self.use_doc_orientation_classify
  81. ):
  82. logging.error(
  83. "Set use_doc_orientation_classify, but the model for doc orientation classify is not initialized."
  84. )
  85. return False
  86. if model_settings["use_doc_unwarping"] and not self.use_doc_unwarping:
  87. logging.error(
  88. "Set use_doc_unwarping, but the model for doc unwarping is not initialized."
  89. )
  90. return False
  91. return True
  92. def get_model_settings(
  93. self, use_doc_orientation_classify, use_doc_unwarping
  94. ) -> dict:
  95. """
  96. Retrieve the model settings dictionary based on input parameters.
  97. Args:
  98. use_doc_orientation_classify (bool, optional): Whether to use document orientation classification.
  99. use_doc_unwarping (bool, optional): Whether to use document unwarping.
  100. Returns:
  101. dict: A dictionary containing the model settings.
  102. """
  103. if use_doc_orientation_classify is None:
  104. use_doc_orientation_classify = self.use_doc_orientation_classify
  105. if use_doc_unwarping is None:
  106. use_doc_unwarping = self.use_doc_unwarping
  107. model_settings = {
  108. "use_doc_orientation_classify": use_doc_orientation_classify,
  109. "use_doc_unwarping": use_doc_unwarping,
  110. }
  111. return model_settings
  112. def predict(
  113. self,
  114. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  115. use_doc_orientation_classify: Optional[bool] = None,
  116. use_doc_unwarping: Optional[bool] = None,
  117. ) -> DocPreprocessorResult:
  118. """
  119. Predict the preprocessing result for the input image or images.
  120. Args:
  121. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or path(s) to the images or pdfs.
  122. use_doc_orientation_classify (bool): Whether to use document orientation classification.
  123. use_doc_unwarping (bool): Whether to use document unwarping.
  124. **kwargs: Additional keyword arguments.
  125. Returns:
  126. DocPreprocessorResult: A generator yielding preprocessing results.
  127. """
  128. model_settings = self.get_model_settings(
  129. use_doc_orientation_classify, use_doc_unwarping
  130. )
  131. if not self.check_model_settings_valid(model_settings):
  132. yield {"error": "the input params for model settings are invalid!"}
  133. for _, batch_data in enumerate(self.batch_sampler(input)):
  134. image_arrays = self.img_reader(batch_data.instances)
  135. if model_settings["use_doc_orientation_classify"]:
  136. preds = list(self.doc_ori_classify_model(image_arrays))
  137. angles = []
  138. rot_imgs = []
  139. for img, pred in zip(image_arrays, preds):
  140. angle = int(pred["label_names"][0])
  141. angles.append(angle)
  142. rot_img = rotate_image(img, angle)
  143. rot_imgs.append(rot_img)
  144. else:
  145. angles = [-1 for _ in range(len(image_arrays))]
  146. rot_imgs = image_arrays
  147. if model_settings["use_doc_unwarping"]:
  148. output_imgs = [
  149. item["doctr_img"][:, :, ::-1]
  150. for item in self.doc_unwarping_model(rot_imgs)
  151. ]
  152. else:
  153. output_imgs = rot_imgs
  154. for input_path, page_index, image_array, angle, rot_img, output_img in zip(
  155. batch_data.input_paths,
  156. batch_data.page_indexes,
  157. image_arrays,
  158. angles,
  159. rot_imgs,
  160. output_imgs,
  161. ):
  162. single_img_res = {
  163. "input_path": input_path,
  164. "page_index": page_index,
  165. "input_img": image_array,
  166. "model_settings": model_settings,
  167. "angle": angle,
  168. "rot_img": rot_img,
  169. "output_img": output_img,
  170. }
  171. yield DocPreprocessorResult(single_img_res)
  172. @pipeline_requires_extra("ocr", alt="ocr-core")
  173. class DocPreprocessorPipeline(AutoParallelImageSimpleInferencePipeline):
  174. entities = "doc_preprocessor"
  175. @property
  176. def _pipeline_cls(self):
  177. return _DocPreprocessorPipeline
  178. def _get_batch_size(self, config):
  179. return config.get("batch_size", 1)