pipeline.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..base import BasePipeline
  15. from typing import Any, Dict, Optional
  16. from scipy.ndimage import rotate
  17. from .result import DocPreprocessorResult
  18. from ....utils import logging
  19. import numpy as np
  20. ########## [TODO] Update the import paths below later
  21. from ...components.transforms import ReadImage
  22. from ...utils.pp_option import PaddlePredictorOption
  23. class DocPreprocessorPipeline(BasePipeline):
  24. """Doc Preprocessor Pipeline"""
  25. entities = "doc_preprocessor"
  26. def __init__(
  27. self,
  28. config: Dict,
  29. device: str = None,
  30. pp_option: PaddlePredictorOption = None,
  31. use_hpip: bool = False,
  32. hpi_params: Optional[Dict[str, Any]] = None,
  33. ) -> None:
  34. """Initializes the doc preprocessor pipeline.
  35. Args:
  36. config (Dict): Configuration dictionary containing various settings.
  37. device (str, optional): Device to run the predictions on. Defaults to None.
  38. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  39. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  40. hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
  41. """
  42. super().__init__(
  43. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
  44. )
  45. self.use_doc_orientation_classify = True
  46. if "use_doc_orientation_classify" in config:
  47. self.use_doc_orientation_classify = config["use_doc_orientation_classify"]
  48. self.use_doc_unwarping = True
  49. if "use_doc_unwarping" in config:
  50. self.use_doc_unwarping = config["use_doc_unwarping"]
  51. if self.use_doc_orientation_classify:
  52. doc_ori_classify_config = config["SubModules"]["DocOrientationClassify"]
  53. self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
  54. if self.use_doc_unwarping:
  55. doc_unwarping_config = config["SubModules"]["DocUnwarping"]
  56. self.doc_unwarping_model = self.create_model(doc_unwarping_config)
  57. self.img_reader = ReadImage(format="BGR")
  58. def rotate_image(self, image_array: np.ndarray, rotate_angle: float) -> np.ndarray:
  59. """
  60. Rotate the given image array by the specified angle.
  61. Args:
  62. image_array (np.ndarray): The input image array to be rotated.
  63. rotate_angle (float): The angle in degrees by which to rotate the image.
  64. Returns:
  65. np.ndarray: The rotated image array.
  66. Raises:
  67. AssertionError: If rotate_angle is not in the range [0, 360).
  68. """
  69. assert (
  70. rotate_angle >= 0 and rotate_angle < 360
  71. ), "rotate_angle must in [0-360), but get {rotate_angle}."
  72. return rotate(image_array, rotate_angle, reshape=True)
  73. def check_input_params_valid(self, input_params: Dict) -> bool:
  74. """
  75. Check if the input parameters are valid based on the initialized models.
  76. Args:
  77. input_params (Dict): A dictionary containing input parameters.
  78. Returns:
  79. bool: True if all required models are initialized according to input parameters, False otherwise.
  80. """
  81. if (
  82. input_params["use_doc_orientation_classify"]
  83. and not self.use_doc_orientation_classify
  84. ):
  85. logging.error(
  86. "Set use_doc_orientation_classify, but the model for doc orientation classify is not initialized."
  87. )
  88. return False
  89. if input_params["use_doc_unwarping"] and not self.use_doc_unwarping:
  90. logging.error(
  91. "Set use_doc_unwarping, but the model for doc unwarping is not initialized."
  92. )
  93. return False
  94. return True
  95. def predict(
  96. self,
  97. input: str | list[str] | np.ndarray | list[np.ndarray],
  98. use_doc_orientation_classify: bool = True,
  99. use_doc_unwarping: bool = False,
  100. **kwargs
  101. ) -> DocPreprocessorResult:
  102. """
  103. Predict the preprocessing result for the input image or images.
  104. Args:
  105. input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) or path(s) to the images.
  106. use_doc_orientation_classify (bool): Whether to use document orientation classification.
  107. use_doc_unwarping (bool): Whether to use document unwarping.
  108. **kwargs: Additional keyword arguments.
  109. Returns:
  110. DocPreprocessorResult: A generator yielding preprocessing results.
  111. """
  112. if not isinstance(input, list):
  113. input_list = [input]
  114. else:
  115. input_list = input
  116. input_params = {
  117. "use_doc_orientation_classify": use_doc_orientation_classify,
  118. "use_doc_unwarping": use_doc_unwarping,
  119. }
  120. if not self.check_input_params_valid(input_params):
  121. yield {"error": "input params invalid"}
  122. img_id = 1
  123. for input in input_list:
  124. if isinstance(input, str):
  125. image_array = next(self.img_reader(input))[0]["img"]
  126. else:
  127. image_array = input
  128. assert len(image_array.shape) == 3
  129. if input_params["use_doc_orientation_classify"]:
  130. pred = next(self.doc_ori_classify_model(image_array))
  131. angle = int(pred["label_names"][0])
  132. rot_img = self.rotate_image(image_array, angle)
  133. else:
  134. angle = -1
  135. rot_img = image_array
  136. if input_params["use_doc_unwarping"]:
  137. output_img = next(self.doc_unwarping_model(rot_img))["doctr_img"]
  138. else:
  139. output_img = rot_img
  140. single_img_res = {
  141. "input_image": image_array,
  142. "input_params": input_params,
  143. "angle": angle,
  144. "rot_img": rot_img,
  145. "output_img": output_img,
  146. "img_id": img_id,
  147. }
  148. img_id += 1
  149. yield DocPreprocessorResult(single_img_res)