pipeline.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional, Union
  15. import numpy as np
  16. from ....utils import logging
  17. from ....utils.deps import pipeline_requires_extra
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ...utils.hpi import HPIConfig
  21. from ...utils.pp_option import PaddlePredictorOption
  22. from ..base import BasePipeline
  23. from ..components import rotate_image
  24. from .result import DocPreprocessorResult
  25. @pipeline_requires_extra("ocr")
  26. class DocPreprocessorPipeline(BasePipeline):
  27. """Doc Preprocessor Pipeline"""
  28. entities = "doc_preprocessor"
  29. def __init__(
  30. self,
  31. config: Dict,
  32. device: Optional[str] = None,
  33. pp_option: Optional[PaddlePredictorOption] = None,
  34. use_hpip: bool = False,
  35. hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  36. ) -> None:
  37. """Initializes the doc preprocessor pipeline.
  38. Args:
  39. config (Dict): Configuration dictionary containing various settings.
  40. device (str, optional): Device to run the predictions on. Defaults to None.
  41. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  42. use_hpip (bool, optional): Whether to use the high-performance
  43. inference plugin (HPIP) by default. Defaults to False.
  44. hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
  45. The default high-performance inference configuration dictionary.
  46. Defaults to None.
  47. """
  48. super().__init__(
  49. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
  50. )
  51. self.use_doc_orientation_classify = config.get(
  52. "use_doc_orientation_classify", True
  53. )
  54. if self.use_doc_orientation_classify:
  55. doc_ori_classify_config = config.get("SubModules", {}).get(
  56. "DocOrientationClassify",
  57. {"model_config_error": "config error for doc_ori_classify_model!"},
  58. )
  59. self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
  60. self.use_doc_unwarping = config.get("use_doc_unwarping", True)
  61. if self.use_doc_unwarping:
  62. doc_unwarping_config = config.get("SubModules", {}).get(
  63. "DocUnwarping",
  64. {"model_config_error": "config error for doc_unwarping_model!"},
  65. )
  66. self.doc_unwarping_model = self.create_model(doc_unwarping_config)
  67. self.batch_sampler = ImageBatchSampler(batch_size=1)
  68. self.img_reader = ReadImage(format="BGR")
  69. def check_model_settings_valid(self, model_settings: Dict) -> bool:
  70. """
  71. Check if the the input params for model settings are valid based on the initialized models.
  72. Args:
  73. model_settings (Dict): A dictionary containing model settings.
  74. Returns:
  75. bool: True if all required models are initialized according to the model settings, False otherwise.
  76. """
  77. if (
  78. model_settings["use_doc_orientation_classify"]
  79. and not self.use_doc_orientation_classify
  80. ):
  81. logging.error(
  82. "Set use_doc_orientation_classify, but the model for doc orientation classify is not initialized."
  83. )
  84. return False
  85. if model_settings["use_doc_unwarping"] and not self.use_doc_unwarping:
  86. logging.error(
  87. "Set use_doc_unwarping, but the model for doc unwarping is not initialized."
  88. )
  89. return False
  90. return True
  91. def get_model_settings(
  92. self, use_doc_orientation_classify, use_doc_unwarping
  93. ) -> dict:
  94. """
  95. Retrieve the model settings dictionary based on input parameters.
  96. Args:
  97. use_doc_orientation_classify (bool, optional): Whether to use document orientation classification.
  98. use_doc_unwarping (bool, optional): Whether to use document unwarping.
  99. Returns:
  100. dict: A dictionary containing the model settings.
  101. """
  102. if use_doc_orientation_classify is None:
  103. use_doc_orientation_classify = self.use_doc_orientation_classify
  104. if use_doc_unwarping is None:
  105. use_doc_unwarping = self.use_doc_unwarping
  106. model_settings = {
  107. "use_doc_orientation_classify": use_doc_orientation_classify,
  108. "use_doc_unwarping": use_doc_unwarping,
  109. }
  110. return model_settings
  111. def predict(
  112. self,
  113. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  114. use_doc_orientation_classify: Optional[bool] = None,
  115. use_doc_unwarping: Optional[bool] = None,
  116. ) -> DocPreprocessorResult:
  117. """
  118. Predict the preprocessing result for the input image or images.
  119. Args:
  120. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): The input image(s) or path(s) to the images or pdfs.
  121. use_doc_orientation_classify (bool): Whether to use document orientation classification.
  122. use_doc_unwarping (bool): Whether to use document unwarping.
  123. **kwargs: Additional keyword arguments.
  124. Returns:
  125. DocPreprocessorResult: A generator yielding preprocessing results.
  126. """
  127. model_settings = self.get_model_settings(
  128. use_doc_orientation_classify, use_doc_unwarping
  129. )
  130. if not self.check_model_settings_valid(model_settings):
  131. yield {"error": "the input params for model settings are invalid!"}
  132. for img_id, batch_data in enumerate(self.batch_sampler(input)):
  133. image_array = self.img_reader(batch_data.instances)[0]
  134. if model_settings["use_doc_orientation_classify"]:
  135. pred = next(self.doc_ori_classify_model(image_array))
  136. angle = int(pred["label_names"][0])
  137. rot_img = rotate_image(image_array, angle)
  138. else:
  139. angle = -1
  140. rot_img = image_array
  141. if model_settings["use_doc_unwarping"]:
  142. output_img = next(self.doc_unwarping_model(rot_img))["doctr_img"]
  143. else:
  144. output_img = rot_img
  145. single_img_res = {
  146. "input_path": batch_data.input_paths[0],
  147. "page_index": batch_data.page_indexes[0],
  148. "input_img": image_array,
  149. "model_settings": model_settings,
  150. "angle": angle,
  151. "rot_img": rot_img,
  152. "output_img": output_img,
  153. }
  154. yield DocPreprocessorResult(single_img_res)