pipeline.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional, Union
  15. import numpy as np
  16. from scipy.ndimage import rotate
  17. from ...common.reader import ReadImage
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...utils.pp_option import PaddlePredictorOption
  20. from ...utils.hpi import HPIConfig
  21. from ..base import BasePipeline
  22. from ..components import (
  23. CropByPolys,
  24. SortQuadBoxes,
  25. SortPolyBoxes,
  26. convert_points_to_boxes,
  27. )
  28. from .result import OCRResult
  29. from ..doc_preprocessor.result import DocPreprocessorResult
  30. from ....utils import logging
  31. class OCRPipeline(BasePipeline):
  32. """OCR Pipeline"""
  33. entities = "OCR"
  def __init__(
      self,
      config: Dict,
      device: Optional[str] = None,
      pp_option: Optional[PaddlePredictorOption] = None,
      use_hpip: bool = False,
      hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
  ) -> None:
      """
      Build the OCR pipeline and the sub-pipeline / sub-models it is configured for.

      Args:
          config (Dict): Pipeline configuration. Keys read here are
              "use_doc_preprocessor", "use_textline_orientation", "text_type",
              "SubPipelines" and "SubModules".
          device (str, optional): Device to run the predictions on. Defaults to None.
          pp_option (PaddlePredictorOption, optional): PaddlePredictor options.
              Defaults to None.
          use_hpip (bool, optional): Whether to use the high-performance
              inference plugin (HPIP) by default. Defaults to False.
          hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
              The default high-performance inference configuration dictionary.
              Defaults to None.

      Raises:
          KeyError: If ``config`` has no "text_type" entry.
          ValueError: If "text_type" is neither "general" nor "seal".
      """
      super().__init__(
          device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
      )

      # Document preprocessing sub-pipeline; enabled unless the config turns it off.
      self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
      if self.use_doc_preprocessor:
          sub_pipelines = config.get("SubPipelines", {})
          self.doc_preprocessor_pipeline = self.create_pipeline(
              sub_pipelines.get(
                  "DocPreprocessor",
                  {
                      "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
                  },
              )
          )

      # Text-line orientation model; enabled unless the config turns it off.
      self.use_textline_orientation = config.get("use_textline_orientation", True)
      if self.use_textline_orientation:
          orientation_config = config.get("SubModules", {}).get(
              "TextLineOrientation",
              {"model_config_error": "config error for textline_orientation_model!"},
          )
          self.textline_orientation_model = self.create_model(orientation_config)

      text_det_config = config.get("SubModules", {}).get(
          "TextDetection", {"model_config_error": "config error for text_det_model!"}
      )

      # Each supported text type has its own detection defaults, box sorter
      # and crop mode; pick them first, then apply them in one place.
      self.text_type = config["text_type"]
      if self.text_type == "general":
          default_side_len, default_limit_type = 960, "max"
          default_thresh, default_unclip_ratio = 0.3, 2.0
          sorter_cls, box_kind = SortQuadBoxes, "quad"
      elif self.text_type == "seal":
          default_side_len, default_limit_type = 736, "min"
          default_thresh, default_unclip_ratio = 0.2, 0.5
          sorter_cls, box_kind = SortPolyBoxes, "poly"
      else:
          raise ValueError("Unsupported text type {}".format(self.text_type))

      self.text_det_limit_side_len = text_det_config.get(
          "limit_side_len", default_side_len
      )
      self.text_det_limit_type = text_det_config.get("limit_type", default_limit_type)
      self.text_det_thresh = text_det_config.get("thresh", default_thresh)
      self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
      self.text_det_unclip_ratio = text_det_config.get(
          "unclip_ratio", default_unclip_ratio
      )
      self.input_shape = text_det_config.get("input_shape", None)
      self._sort_boxes = sorter_cls()
      self._crop_by_polys = CropByPolys(det_box_type=box_kind)

      self.text_det_model = self.create_model(
          text_det_config,
          limit_side_len=self.text_det_limit_side_len,
          limit_type=self.text_det_limit_type,
          thresh=self.text_det_thresh,
          box_thresh=self.text_det_box_thresh,
          unclip_ratio=self.text_det_unclip_ratio,
          input_shape=self.input_shape,
      )

      text_rec_config = config.get("SubModules", {}).get(
          "TextRecognition",
          {"model_config_error": "config error for text_rec_model!"},
      )
      self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
      # NOTE: ``self.input_shape`` is reused here, so once construction is done
      # it holds the recognition model's value, not the detection model's.
      self.input_shape = text_rec_config.get("input_shape", None)
      self.text_rec_model = self.create_model(
          text_rec_config, input_shape=self.input_shape
      )

      # Inputs are sampled one image at a time and read with format="BGR".
      self.batch_sampler = ImageBatchSampler(batch_size=1)
      self.img_reader = ReadImage(format="BGR")
  def rotate_image(
      self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
  ) -> List[np.ndarray]:
      """
      Rotate each image by 0 or 180 degrees according to its indicator.

      A 180-degree rotation is an exact flip of the first two axes, so it is
      done with array slicing instead of a general-purpose interpolating
      rotation: no resampling is involved, pixel values are preserved exactly
      and the dtype is unchanged.

      Args:
          image_array_list (List[np.ndarray]): Images to rotate; the first two
              axes are treated as the image plane.
          rotate_angle_list (List[int]): One indicator per image:
              0 means rotate by 0 degrees,
              1 means rotate by 180 degrees.

      Returns:
          List[np.ndarray]: Newly allocated, C-contiguous arrays in input order
          (the inputs are never returned or modified).

      Raises:
          AssertionError: If the lengths of the input lists don't match.
          AssertionError: If any rotate_angle is not 0 or 1.
      """
      # Raised explicitly (not via ``assert``) so the checks survive ``python -O``
      # while keeping the exception type callers already expect.
      if len(image_array_list) != len(rotate_angle_list):
          raise AssertionError(
              f"Length of image_array_list ({len(image_array_list)}) must match length of rotate_angle_list ({len(rotate_angle_list)})"
          )
      for angle in rotate_angle_list:
          if angle not in [0, 1]:
              raise AssertionError(f"rotate_angle must be 0 or 1, now it's {angle}")

      rotated_images = []
      for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
          image_array = np.asarray(image_array)
          if rotate_indicator:
              # Reversing rows and columns is a 180-degree turn of the image plane.
              source = image_array[::-1, ::-1]
          else:
              source = image_array
          # np.array(...) always copies; order="C" avoids handing a
          # negative-stride view to downstream consumers.
          rotated_images.append(np.array(source, order="C"))
      return rotated_images
  def check_model_settings_valid(self, model_settings: Dict) -> bool:
      """
      Verify that every feature requested in ``model_settings`` was initialized.

      Args:
          model_settings (Dict): Settings with the boolean entries
              "use_doc_preprocessor" and "use_textline_orientation".

      Returns:
          bool: True if each requested feature is enabled on this instance;
          False (after logging an error) at the first one that is not.
      """
      # (settings key / instance flag, message logged when requested but missing)
      requirements = (
          (
              "use_doc_preprocessor",
              "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized.",
          ),
          (
              "use_textline_orientation",
              "Set use_textline_orientation, but the models for use_textline_orientation are not initialized.",
          ),
      )
      for feature, message in requirements:
          # The instance flag shares its name with the settings key.
          if model_settings[feature] and not getattr(self, feature):
              logging.error(message)
              return False
      return True
  def get_model_settings(
      self,
      use_doc_orientation_classify: Optional[bool],
      use_doc_unwarping: Optional[bool],
      use_textline_orientation: Optional[bool],
  ) -> dict:
      """
      Resolve per-call switches against the defaults stored on the instance.

      Args:
          use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
          use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
          use_textline_orientation (Optional[bool]): Whether to use textline orientation.

      Returns:
          dict: ``{"use_doc_preprocessor": ..., "use_textline_orientation": ...}``.
          The doc preprocessor falls back to the instance default only when both
          document flags are None; otherwise it is on exactly when at least one
          of them is the literal ``True``.
      """
      doc_flags = (use_doc_orientation_classify, use_doc_unwarping)
      if all(flag is None for flag in doc_flags):
          doc_preprocessor_enabled = self.use_doc_preprocessor
      else:
          # Identity comparison on purpose: only the literal True enables it.
          doc_preprocessor_enabled = any(flag is True for flag in doc_flags)

      textline_enabled = (
          self.use_textline_orientation
          if use_textline_orientation is None
          else use_textline_orientation
      )
      return {
          "use_doc_preprocessor": doc_preprocessor_enabled,
          "use_textline_orientation": textline_enabled,
      }
  def get_text_det_params(
      self,
      text_det_limit_side_len: Optional[int] = None,
      text_det_limit_type: Optional[str] = None,
      text_det_thresh: Optional[float] = None,
      text_det_box_thresh: Optional[float] = None,
      text_det_unclip_ratio: Optional[float] = None,
  ) -> dict:
      """
      Assemble the keyword arguments passed to the text detection model.

      Every argument left as None is replaced by the instance attribute of the
      same name (``self.text_det_<param>``); any other value, including falsy
      ones such as 0, is used as given.

      Args:
          text_det_limit_side_len (Optional[int]): Side-length limit for text detection.
          text_det_limit_type (Optional[str]): Type of side-length limit to apply.
          text_det_thresh (Optional[float]): The threshold for text detection.
          text_det_box_thresh (Optional[float]): The threshold for the bounding box.
          text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.

      Returns:
          dict: Parameters keyed by "limit_side_len", "limit_type", "thresh",
          "box_thresh" and "unclip_ratio", in that order.
      """
      overrides = {
          "limit_side_len": text_det_limit_side_len,
          "limit_type": text_det_limit_type,
          "thresh": text_det_thresh,
          "box_thresh": text_det_box_thresh,
          "unclip_ratio": text_det_unclip_ratio,
      }
      # The instance default is only looked up when the caller gave no value.
      return {
          name: getattr(self, "text_det_" + name) if value is None else value
          for name, value in overrides.items()
      }
  def predict(
      self,
      input: Union[str, List[str], np.ndarray, List[np.ndarray]],
      use_doc_orientation_classify: Optional[bool] = None,
      use_doc_unwarping: Optional[bool] = None,
      use_textline_orientation: Optional[bool] = None,
      text_det_limit_side_len: Optional[int] = None,
      text_det_limit_type: Optional[str] = None,
      text_det_thresh: Optional[float] = None,
      text_det_box_thresh: Optional[float] = None,
      text_det_unclip_ratio: Optional[float] = None,
      text_rec_score_thresh: Optional[float] = None,
  ) -> OCRResult:
      """
      Run OCR on the input and yield one result per image.

      This is a generator function: nothing runs until it is iterated.

      Args:
          input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image or pdf path(s), or numpy array(s).
          use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
          use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
          use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
          text_det_limit_side_len (Optional[int]): Side-length limit for text detection.
          text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
          text_det_thresh (Optional[float]): Threshold for text detection.
          text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
          text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
          text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
              Arguments left as None fall back to the values stored on the instance.

      Yields:
          OCRResult: The OCR result of each input image. If the requested model
          settings need a component that was not initialized, a single
          ``{"error": ...}`` dict is yielded instead and iteration ends.
      """
      model_settings = self.get_model_settings(
          use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
      )
      if not self.check_model_settings_valid(model_settings):
          yield {"error": "the input params for model settings are invalid!"}
          # Stop here: continuing would call sub-models that were never created.
          return

      text_det_params = self.get_text_det_params(
          text_det_limit_side_len,
          text_det_limit_type,
          text_det_thresh,
          text_det_box_thresh,
          text_det_unclip_ratio,
      )
      if text_rec_score_thresh is None:
          text_rec_score_thresh = self.text_rec_score_thresh

      for batch_data in self.batch_sampler(input):
          # The sampler is built with batch_size=1, so only element 0 is used.
          image_array = self.img_reader(batch_data.instances)[0]

          if model_settings["use_doc_preprocessor"]:
              doc_preprocessor_res = next(
                  self.doc_preprocessor_pipeline(
                      image_array,
                      use_doc_orientation_classify=use_doc_orientation_classify,
                      use_doc_unwarping=use_doc_unwarping,
                  )
              )
          else:
              doc_preprocessor_res = {"output_img": image_array}
          doc_preprocessor_image = doc_preprocessor_res["output_img"]

          det_res = next(
              self.text_det_model(doc_preprocessor_image, **text_det_params)
          )
          dt_polys = self._sort_boxes(det_res["dt_polys"])

          single_img_res = {
              "input_path": batch_data.input_paths[0],
              "page_index": batch_data.page_indexes[0],
              "doc_preprocessor_res": doc_preprocessor_res,
              "dt_polys": dt_polys,
              "model_settings": model_settings,
              "text_det_params": text_det_params,
              "text_type": self.text_type,
              "text_rec_score_thresh": text_rec_score_thresh,
              "rec_texts": [],
              "rec_scores": [],
              "rec_polys": [],
          }

          # Explicit length test: dt_polys may be an array, whose truth value
          # is not a reliable emptiness check.
          if len(dt_polys) > 0:
              all_subs_of_img = list(
                  self._crop_by_polys(doc_preprocessor_image, dt_polys)
              )

              if model_settings["use_textline_orientation"]:
                  angles = [
                      int(textline_angle_info["class_ids"][0])
                      for textline_angle_info in self.textline_orientation_model(
                          all_subs_of_img
                      )
                  ]
                  all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
              else:
                  # -1 marks "orientation not predicted" for every crop.
                  angles = [-1] * len(all_subs_of_img)
              single_img_res["textline_orientation_angles"] = angles

              # Recognize crops in order of increasing width/height ratio
              # (presumably so similarly shaped crops are processed together),
              # then put each result back at its crop's original position.
              aspect_ratios = [
                  sub_img.shape[1] / float(sub_img.shape[0])
                  for sub_img in all_subs_of_img
              ]
              rec_order = sorted(
                  range(len(all_subs_of_img)), key=aspect_ratios.__getitem__
              )
              sorted_subs_of_img = [all_subs_of_img[i] for i in rec_order]
              rec_results = [None] * len(all_subs_of_img)
              for sub_img_id, rec_res in zip(
                  rec_order, self.text_rec_model(sorted_subs_of_img)
              ):
                  rec_results[sub_img_id] = rec_res

              # Keep only recognitions that reach the score threshold; crop i
              # corresponds to polygon i.
              for poly, rec_res in zip(dt_polys, rec_results):
                  if rec_res["rec_score"] >= text_rec_score_thresh:
                      single_img_res["rec_texts"].append(rec_res["rec_text"])
                      single_img_res["rec_scores"].append(rec_res["rec_score"])
                      single_img_res["rec_polys"].append(poly)

          if self.text_type == "general":
              single_img_res["rec_boxes"] = convert_points_to_boxes(
                  single_img_res["rec_polys"]
              )
          else:
              single_img_res["rec_boxes"] = np.array([])

          yield OCRResult(single_img_res)