pipeline.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional
  15. import numpy as np
  16. from scipy.ndimage import rotate
  17. from ...common.reader import ReadImage
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...utils.pp_option import PaddlePredictorOption
  20. from ..base import BasePipeline
  21. from ..components import CropByPolys, SortQuadBoxes, SortPolyBoxes
  22. from .result import OCRResult
  23. from ..doc_preprocessor.result import DocPreprocessorResult
  24. from ....utils import logging
  25. class OCRPipeline(BasePipeline):
  26. """OCR Pipeline"""
  27. entities = "OCR"
  28. def __init__(
  29. self,
  30. config: Dict,
  31. device: Optional[str] = None,
  32. pp_option: Optional[PaddlePredictorOption] = None,
  33. use_hpip: bool = False,
  34. hpi_params: Optional[Dict[str, Any]] = None,
  35. ) -> None:
  36. """
  37. Initializes the class with given configurations and options.
  38. Args:
  39. config (Dict): Configuration dictionary containing various settings.
  40. device (str, optional): Device to run the predictions on. Defaults to None.
  41. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  42. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  43. hpi_params (Optional[Dict[str, Any]], optional): HPIP parameters. Defaults to None.
  44. """
  45. super().__init__(
  46. device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_params=hpi_params
  47. )
  48. self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
  49. if self.use_doc_preprocessor:
  50. doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
  51. self.doc_preprocessor_pipeline = self.create_pipeline(
  52. doc_preprocessor_config
  53. )
  54. self.use_textline_orientation = config.get("use_textline_orientation", True)
  55. if self.use_textline_orientation:
  56. textline_orientation_config = config["SubModules"]["TextLineOrientation"]
  57. # TODO: add batch_size
  58. # batch_size = textline_orientation_config.get("batch_size", 1)
  59. # self.textline_orientation_model = self.create_model(
  60. # textline_orientation_config, batch_size=batch_size
  61. # )
  62. self.textline_orientation_model = self.create_model(
  63. textline_orientation_config
  64. )
  65. text_det_config = config["SubModules"]["TextDetection"]
  66. self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
  67. self.text_det_limit_type = text_det_config.get("limit_type", "max")
  68. self.text_det_thresh = text_det_config.get("thresh", 0.3)
  69. self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
  70. self.text_det_max_candidates = text_det_config.get("max_candidates", 1000)
  71. self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
  72. self.text_det_use_dilation = text_det_config.get("use_dilation", False)
  73. self.text_det_model = self.create_model(
  74. text_det_config,
  75. limit_side_len=self.text_det_limit_side_len,
  76. limit_type=self.text_det_limit_type,
  77. thresh=self.text_det_thresh,
  78. box_thresh=self.text_det_box_thresh,
  79. max_candidates=self.text_det_max_candidates,
  80. unclip_ratio=self.text_det_unclip_ratio,
  81. use_dilation=self.text_det_use_dilation,
  82. )
  83. text_rec_config = config["SubModules"]["TextRecognition"]
  84. # TODO: add batch_size
  85. # batch_size = text_rec_config.get("batch_size", 1)
  86. # self.text_rec_model = self.create_model(text_rec_config,
  87. # batch_size=batch_size)
  88. self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
  89. self.text_rec_model = self.create_model(text_rec_config)
  90. self.text_type = config["text_type"]
  91. if self.text_type == "general":
  92. self._sort_boxes = SortQuadBoxes()
  93. self._crop_by_polys = CropByPolys(det_box_type="quad")
  94. elif self.text_type == "seal":
  95. self._sort_boxes = SortPolyBoxes()
  96. self._crop_by_polys = CropByPolys(det_box_type="poly")
  97. else:
  98. raise ValueError("Unsupported text type {}".format(self.text_type))
  99. self.batch_sampler = ImageBatchSampler(batch_size=1)
  100. self.img_reader = ReadImage(format="BGR")
  101. def rotate_image(
  102. self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
  103. ) -> List[np.ndarray]:
  104. """
  105. Rotate the given image arrays by their corresponding angles.
  106. 0 corresponds to 0 degrees, 1 corresponds to 180 degrees.
  107. Args:
  108. image_array_list (List[np.ndarray]): A list of input image arrays to be rotated.
  109. rotate_angle_list (List[int]): A list of rotation indicators (0 or 1).
  110. 0 means rotate by 0 degrees
  111. 1 means rotate by 180 degrees
  112. Returns:
  113. List[np.ndarray]: A list of rotated image arrays.
  114. Raises:
  115. AssertionError: If any rotate_angle is not 0 or 1.
  116. AssertionError: If the lengths of input lists don't match.
  117. """
  118. assert len(image_array_list) == len(
  119. rotate_angle_list
  120. ), f"Length of image_array_list ({len(image_array_list)}) must match length of rotate_angle_list ({len(rotate_angle_list)})"
  121. for angle in rotate_angle_list:
  122. assert angle in [0, 1], f"rotate_angle must be 0 or 1, now it's {angle}"
  123. rotated_images = []
  124. for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
  125. # Convert 0/1 indicator to actual rotation angle
  126. rotate_angle = rotate_indicator * 180
  127. rotated_image = rotate(image_array, rotate_angle, reshape=True)
  128. rotated_images.append(rotated_image)
  129. return rotated_images
  130. def check_model_settings_valid(self, model_settings: Dict) -> bool:
  131. """
  132. Check if the input parameters are valid based on the initialized models.
  133. Args:
  134. model_info_params(Dict): A dictionary containing input parameters.
  135. Returns:
  136. bool: True if all required models are initialized according to input parameters, False otherwise.
  137. """
  138. if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
  139. logging.error(
  140. "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
  141. )
  142. return False
  143. if (
  144. model_settings["use_textline_orientation"]
  145. and not self.use_textline_orientation
  146. ):
  147. logging.error(
  148. "Set use_textline_orientation, but the models for use_textline_orientation are not initialized."
  149. )
  150. return False
  151. return True
  152. def predict_doc_preprocessor_res(
  153. self, image_array: np.ndarray, model_settings: dict
  154. ) -> tuple[DocPreprocessorResult, np.ndarray]:
  155. """
  156. Preprocess the document image based on input parameters.
  157. Args:
  158. image_array (np.ndarray): The input image array.
  159. model_settings (dict): Dictionary containing preprocessing parameters.
  160. Returns:
  161. tuple[DocPreprocessorResult, np.ndarray]: A tuple containing the preprocessing
  162. result dictionary and the processed image array.
  163. """
  164. if model_settings["use_doc_preprocessor"]:
  165. use_doc_orientation_classify = model_settings[
  166. "use_doc_orientation_classify"
  167. ]
  168. use_doc_unwarping = model_settings["use_doc_unwarping"]
  169. doc_preprocessor_res = next(
  170. self.doc_preprocessor_pipeline(
  171. image_array,
  172. use_doc_orientation_classify=use_doc_orientation_classify,
  173. use_doc_unwarping=use_doc_unwarping,
  174. )
  175. )
  176. else:
  177. doc_preprocessor_res = {"output_img": image_array}
  178. return doc_preprocessor_res
  179. def get_model_settings(
  180. self,
  181. use_doc_orientation_classify: Optional[bool],
  182. use_doc_unwarping: Optional[bool],
  183. use_textline_orientation: Optional[bool],
  184. ) -> dict:
  185. """
  186. Get the model settings based on the provided parameters or default values.
  187. Args:
  188. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  189. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  190. use_textline_orientation (Optional[bool]): Whether to use textline orientation.
  191. Returns:
  192. dict: A dictionary containing the model settings.
  193. """
  194. if use_doc_orientation_classify is None:
  195. use_doc_orientation_classify = self.use_doc_orientation_classify
  196. if use_doc_unwarping is None:
  197. use_doc_unwarping = self.use_doc_unwarping
  198. if use_textline_orientation is None:
  199. use_textline_orientation = self.use_textline_orientation
  200. return dict(
  201. use_doc_orientation_classify=use_doc_orientation_classify,
  202. use_doc_unwarping=use_doc_unwarping,
  203. use_textline_orientation=use_textline_orientation,
  204. )
  205. def get_text_det_params(
  206. self,
  207. text_det_limit_side_len: Optional[int] = None,
  208. text_det_limit_type: Optional[str] = None,
  209. text_det_thresh: Optional[float] = None,
  210. text_det_box_thresh: Optional[float] = None,
  211. text_det_max_candidates: Optional[int] = None,
  212. text_det_unclip_ratio: Optional[float] = None,
  213. text_det_use_dilation: Optional[bool] = None,
  214. ) -> dict:
  215. """
  216. Get text detection parameters.
  217. If a parameter is None, its default value from the instance will be used.
  218. Args:
  219. text_det_limit_side_len (Optional[int]): The maximum side length of the text box.
  220. text_det_limit_type (Optional[str]): The type of limit to apply to the text box.
  221. text_det_thresh (Optional[float]): The threshold for text detection.
  222. text_det_box_thresh (Optional[float]): The threshold for the bounding box.
  223. text_det_max_candidates (Optional[int]): The maximum number of candidate text boxes.
  224. text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.
  225. text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
  226. Returns:
  227. dict: A dictionary containing the text detection parameters.
  228. """
  229. if text_det_limit_side_len is None:
  230. text_det_limit_side_len = self.text_det_limit_side_len
  231. if text_det_limit_type is None:
  232. text_det_limit_type = self.text_det_limit_type
  233. if text_det_thresh is None:
  234. text_det_thresh = self.text_det_thresh
  235. if text_det_box_thresh is None:
  236. text_det_box_thresh = self.text_det_box_thresh
  237. if text_det_max_candidates is None:
  238. text_det_max_candidates = self.text_det_max_candidates
  239. if text_det_unclip_ratio is None:
  240. text_det_unclip_ratio = self.text_det_unclip_ratio
  241. if text_det_use_dilation is None:
  242. text_det_use_dilation = self.text_det_use_dilation
  243. return dict(
  244. limit_side_len=text_det_limit_side_len,
  245. limit_type=text_det_limit_type,
  246. thresh=text_det_thresh,
  247. box_thresh=text_det_box_thresh,
  248. max_candidates=text_det_max_candidates,
  249. unclip_ratio=text_det_unclip_ratio,
  250. use_dilation=text_det_use_dilation,
  251. )
  252. def predict(
  253. self,
  254. input: str | list[str] | np.ndarray | list[np.ndarray],
  255. use_doc_orientation_classify: Optional[bool] = None,
  256. use_doc_unwarping: Optional[bool] = None,
  257. use_textline_orientation: Optional[bool] = None,
  258. text_det_limit_side_len: Optional[int] = None,
  259. text_det_limit_type: Optional[str] = None,
  260. text_det_thresh: Optional[float] = None,
  261. text_det_box_thresh: Optional[float] = None,
  262. text_det_max_candidates: Optional[int] = None,
  263. text_det_unclip_ratio: Optional[float] = None,
  264. text_det_use_dilation: Optional[bool] = None,
  265. text_rec_score_thresh: Optional[float] = None,
  266. ) -> OCRResult:
  267. """
  268. Predict OCR results based on input images or arrays with optional preprocessing steps.
  269. Args:
  270. input (str | list[str] | np.ndarray | list[np.ndarray]): Input image of pdf path(s) or numpy array(s).
  271. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  272. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  273. use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
  274. text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
  275. text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
  276. text_det_thresh (Optional[float]): Threshold for text detection.
  277. text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
  278. text_det_max_candidates (Optional[int]): Maximum number of text detection candidates.
  279. text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
  280. text_det_use_dilation (Optional[bool]): Whether to use dilation in text detection.
  281. text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
  282. Returns:
  283. OCRResult: Generator yielding OCR results for each input image.
  284. """
  285. model_settings = self.get_model_settings(
  286. use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
  287. )
  288. if (
  289. model_settings["use_doc_orientation_classify"]
  290. or model_settings["use_doc_unwarping"]
  291. ):
  292. model_settings["use_doc_preprocessor"] = True
  293. else:
  294. model_settings["use_doc_preprocessor"] = False
  295. if not self.check_model_settings_valid(model_settings):
  296. yield {"error": "the input params for model settings are invalid!"}
  297. text_det_params = self.get_text_det_params(
  298. text_det_limit_side_len,
  299. text_det_limit_type,
  300. text_det_thresh,
  301. text_det_box_thresh,
  302. text_det_max_candidates,
  303. text_det_unclip_ratio,
  304. text_det_use_dilation,
  305. )
  306. if text_rec_score_thresh is None:
  307. text_rec_score_thresh = self.text_rec_score_thresh
  308. for img_id, batch_data in enumerate(self.batch_sampler(input)):
  309. if not isinstance(batch_data[0], str):
  310. # TODO: add support input_pth for ndarray and pdf
  311. input_path = f"{img_id}"
  312. else:
  313. input_path = batch_data[0]
  314. image_array = self.img_reader(batch_data)[0]
  315. doc_preprocessor_res = self.predict_doc_preprocessor_res(
  316. image_array, model_settings
  317. )
  318. doc_preprocessor_image = doc_preprocessor_res["output_img"]
  319. det_res = next(
  320. self.text_det_model(doc_preprocessor_image, **text_det_params)
  321. )
  322. dt_polys = det_res["dt_polys"]
  323. dt_scores = det_res["dt_scores"]
  324. dt_polys = self._sort_boxes(dt_polys)
  325. single_img_res = {
  326. "input_path": batch_data[0],
  327. "doc_preprocessor_res": doc_preprocessor_res,
  328. "dt_polys": dt_polys,
  329. "model_settings": model_settings,
  330. "text_det_params": text_det_params,
  331. "text_type": self.text_type,
  332. }
  333. single_img_res["rec_texts"] = []
  334. single_img_res["rec_scores"] = []
  335. single_img_res["rec_boxes"] = []
  336. if len(dt_polys) > 0:
  337. all_subs_of_img = list(
  338. self._crop_by_polys(doc_preprocessor_image, dt_polys)
  339. )
  340. # use textline orientation model
  341. if model_settings["use_textline_orientation"]:
  342. angles = [
  343. textline_angle_info["class_ids"][0]
  344. for textline_angle_info in self.textline_orientation_model(
  345. all_subs_of_img
  346. )
  347. ]
  348. single_img_res["textline_orientation_angle"] = angles
  349. all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
  350. rno = -1
  351. for rec_res in self.text_rec_model(all_subs_of_img):
  352. rno += 1
  353. if rec_res["rec_score"] >= text_rec_score_thresh:
  354. single_img_res["rec_texts"].append(rec_res["rec_text"])
  355. single_img_res["rec_scores"].append(rec_res["rec_score"])
  356. single_img_res["rec_boxes"].append(dt_polys[rno])
  357. yield OCRResult(single_img_res)