pipeline.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional, Union
  15. import numpy as np
  16. from scipy.ndimage import rotate
  17. from ...common.reader import ReadImage
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...utils.pp_option import PaddlePredictorOption
  20. from ..base import BasePipeline
  21. from ..components import (
  22. CropByPolys,
  23. SortQuadBoxes,
  24. SortPolyBoxes,
  25. convert_points_to_boxes,
  26. )
  27. from .result import OCRResult
  28. from ..doc_preprocessor.result import DocPreprocessorResult
  29. from ....utils import logging
  30. class OCRPipeline(BasePipeline):
  31. """OCR Pipeline"""
  32. entities = "OCR"
  33. def __init__(
  34. self,
  35. config: Dict,
  36. device: Optional[str] = None,
  37. pp_option: Optional[PaddlePredictorOption] = None,
  38. use_hpip: bool = False,
  39. ) -> None:
  40. """
  41. Initializes the class with given configurations and options.
  42. Args:
  43. config (Dict): Configuration dictionary containing various settings.
  44. device (str, optional): Device to run the predictions on. Defaults to None.
  45. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  46. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  47. """
  48. super().__init__(device=device, pp_option=pp_option, use_hpip=use_hpip)
  49. self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
  50. if self.use_doc_preprocessor:
  51. doc_preprocessor_config = config.get("SubPipelines", {}).get(
  52. "DocPreprocessor",
  53. {
  54. "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
  55. },
  56. )
  57. self.doc_preprocessor_pipeline = self.create_pipeline(
  58. doc_preprocessor_config
  59. )
  60. self.use_textline_orientation = config.get("use_textline_orientation", True)
  61. if self.use_textline_orientation:
  62. textline_orientation_config = config.get("SubModules", {}).get(
  63. "TextLineOrientation",
  64. {"model_config_error": "config error for textline_orientation_model!"},
  65. )
  66. self.textline_orientation_model = self.create_model(
  67. textline_orientation_config
  68. )
  69. text_det_config = config.get("SubModules", {}).get(
  70. "TextDetection", {"model_config_error": "config error for text_det_model!"}
  71. )
  72. self.text_type = config["text_type"]
  73. if self.text_type == "general":
  74. self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
  75. self.text_det_limit_type = text_det_config.get("limit_type", "max")
  76. self.text_det_thresh = text_det_config.get("thresh", 0.3)
  77. self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
  78. self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
  79. self._sort_boxes = SortQuadBoxes()
  80. self._crop_by_polys = CropByPolys(det_box_type="quad")
  81. elif self.text_type == "seal":
  82. self.text_det_limit_side_len = text_det_config.get("limit_side_len", 736)
  83. self.text_det_limit_type = text_det_config.get("limit_type", "min")
  84. self.text_det_thresh = text_det_config.get("thresh", 0.2)
  85. self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
  86. self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 0.5)
  87. self._sort_boxes = SortPolyBoxes()
  88. self._crop_by_polys = CropByPolys(det_box_type="poly")
  89. else:
  90. raise ValueError("Unsupported text type {}".format(self.text_type))
  91. self.text_det_model = self.create_model(
  92. text_det_config,
  93. limit_side_len=self.text_det_limit_side_len,
  94. limit_type=self.text_det_limit_type,
  95. thresh=self.text_det_thresh,
  96. box_thresh=self.text_det_box_thresh,
  97. unclip_ratio=self.text_det_unclip_ratio,
  98. )
  99. text_rec_config = config.get("SubModules", {}).get(
  100. "TextRecognition",
  101. {"model_config_error": "config error for text_rec_model!"},
  102. )
  103. self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
  104. self.text_rec_model = self.create_model(text_rec_config)
  105. self.batch_sampler = ImageBatchSampler(batch_size=1)
  106. self.img_reader = ReadImage(format="BGR")
  107. def rotate_image(
  108. self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
  109. ) -> List[np.ndarray]:
  110. """
  111. Rotate the given image arrays by their corresponding angles.
  112. 0 corresponds to 0 degrees, 1 corresponds to 180 degrees.
  113. Args:
  114. image_array_list (List[np.ndarray]): A list of input image arrays to be rotated.
  115. rotate_angle_list (List[int]): A list of rotation indicators (0 or 1).
  116. 0 means rotate by 0 degrees
  117. 1 means rotate by 180 degrees
  118. Returns:
  119. List[np.ndarray]: A list of rotated image arrays.
  120. Raises:
  121. AssertionError: If any rotate_angle is not 0 or 1.
  122. AssertionError: If the lengths of input lists don't match.
  123. """
  124. assert len(image_array_list) == len(
  125. rotate_angle_list
  126. ), f"Length of image_array_list ({len(image_array_list)}) must match length of rotate_angle_list ({len(rotate_angle_list)})"
  127. for angle in rotate_angle_list:
  128. assert angle in [0, 1], f"rotate_angle must be 0 or 1, now it's {angle}"
  129. rotated_images = []
  130. for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
  131. # Convert 0/1 indicator to actual rotation angle
  132. rotate_angle = rotate_indicator * 180
  133. rotated_image = rotate(image_array, rotate_angle, reshape=True)
  134. rotated_images.append(rotated_image)
  135. return rotated_images
  136. def check_model_settings_valid(self, model_settings: Dict) -> bool:
  137. """
  138. Check if the input parameters are valid based on the initialized models.
  139. Args:
  140. model_info_params(Dict): A dictionary containing input parameters.
  141. Returns:
  142. bool: True if all required models are initialized according to input parameters, False otherwise.
  143. """
  144. if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
  145. logging.error(
  146. "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
  147. )
  148. return False
  149. if (
  150. model_settings["use_textline_orientation"]
  151. and not self.use_textline_orientation
  152. ):
  153. logging.error(
  154. "Set use_textline_orientation, but the models for use_textline_orientation are not initialized."
  155. )
  156. return False
  157. return True
  158. def get_model_settings(
  159. self,
  160. use_doc_orientation_classify: Optional[bool],
  161. use_doc_unwarping: Optional[bool],
  162. use_textline_orientation: Optional[bool],
  163. ) -> dict:
  164. """
  165. Get the model settings based on the provided parameters or default values.
  166. Args:
  167. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  168. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  169. use_textline_orientation (Optional[bool]): Whether to use textline orientation.
  170. Returns:
  171. dict: A dictionary containing the model settings.
  172. """
  173. if use_doc_orientation_classify is None and use_doc_unwarping is None:
  174. use_doc_preprocessor = self.use_doc_preprocessor
  175. else:
  176. if use_doc_orientation_classify is True or use_doc_unwarping is True:
  177. use_doc_preprocessor = True
  178. else:
  179. use_doc_preprocessor = False
  180. if use_textline_orientation is None:
  181. use_textline_orientation = self.use_textline_orientation
  182. return dict(
  183. use_doc_preprocessor=use_doc_preprocessor,
  184. use_textline_orientation=use_textline_orientation,
  185. )
  186. def get_text_det_params(
  187. self,
  188. text_det_limit_side_len: Optional[int] = None,
  189. text_det_limit_type: Optional[str] = None,
  190. text_det_thresh: Optional[float] = None,
  191. text_det_box_thresh: Optional[float] = None,
  192. text_det_unclip_ratio: Optional[float] = None,
  193. ) -> dict:
  194. """
  195. Get text detection parameters.
  196. If a parameter is None, its default value from the instance will be used.
  197. Args:
  198. text_det_limit_side_len (Optional[int]): The maximum side length of the text box.
  199. text_det_limit_type (Optional[str]): The type of limit to apply to the text box.
  200. text_det_thresh (Optional[float]): The threshold for text detection.
  201. text_det_box_thresh (Optional[float]): The threshold for the bounding box.
  202. text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.
  203. Returns:
  204. dict: A dictionary containing the text detection parameters.
  205. """
  206. if text_det_limit_side_len is None:
  207. text_det_limit_side_len = self.text_det_limit_side_len
  208. if text_det_limit_type is None:
  209. text_det_limit_type = self.text_det_limit_type
  210. if text_det_thresh is None:
  211. text_det_thresh = self.text_det_thresh
  212. if text_det_box_thresh is None:
  213. text_det_box_thresh = self.text_det_box_thresh
  214. if text_det_unclip_ratio is None:
  215. text_det_unclip_ratio = self.text_det_unclip_ratio
  216. return dict(
  217. limit_side_len=text_det_limit_side_len,
  218. limit_type=text_det_limit_type,
  219. thresh=text_det_thresh,
  220. box_thresh=text_det_box_thresh,
  221. unclip_ratio=text_det_unclip_ratio,
  222. )
  223. def predict(
  224. self,
  225. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  226. use_doc_orientation_classify: Optional[bool] = None,
  227. use_doc_unwarping: Optional[bool] = None,
  228. use_textline_orientation: Optional[bool] = None,
  229. text_det_limit_side_len: Optional[int] = None,
  230. text_det_limit_type: Optional[str] = None,
  231. text_det_thresh: Optional[float] = None,
  232. text_det_box_thresh: Optional[float] = None,
  233. text_det_unclip_ratio: Optional[float] = None,
  234. text_rec_score_thresh: Optional[float] = None,
  235. ) -> OCRResult:
  236. """
  237. Predict OCR results based on input images or arrays with optional preprocessing steps.
  238. Args:
  239. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image of pdf path(s) or numpy array(s).
  240. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  241. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  242. use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
  243. text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
  244. text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
  245. text_det_thresh (Optional[float]): Threshold for text detection.
  246. text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
  247. text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
  248. text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
  249. Returns:
  250. OCRResult: Generator yielding OCR results for each input image.
  251. """
  252. model_settings = self.get_model_settings(
  253. use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
  254. )
  255. if not self.check_model_settings_valid(model_settings):
  256. yield {"error": "the input params for model settings are invalid!"}
  257. text_det_params = self.get_text_det_params(
  258. text_det_limit_side_len,
  259. text_det_limit_type,
  260. text_det_thresh,
  261. text_det_box_thresh,
  262. text_det_unclip_ratio,
  263. )
  264. if text_rec_score_thresh is None:
  265. text_rec_score_thresh = self.text_rec_score_thresh
  266. for img_id, batch_data in enumerate(self.batch_sampler(input)):
  267. image_array = self.img_reader(batch_data.instances)[0]
  268. if model_settings["use_doc_preprocessor"]:
  269. doc_preprocessor_res = next(
  270. self.doc_preprocessor_pipeline(
  271. image_array,
  272. use_doc_orientation_classify=use_doc_orientation_classify,
  273. use_doc_unwarping=use_doc_unwarping,
  274. )
  275. )
  276. else:
  277. doc_preprocessor_res = {"output_img": image_array}
  278. doc_preprocessor_image = doc_preprocessor_res["output_img"]
  279. det_res = next(
  280. self.text_det_model(doc_preprocessor_image, **text_det_params)
  281. )
  282. dt_polys = det_res["dt_polys"]
  283. dt_scores = det_res["dt_scores"]
  284. dt_polys = self._sort_boxes(dt_polys)
  285. single_img_res = {
  286. "input_path": batch_data.input_paths[0],
  287. "page_index": batch_data.page_indexes[0],
  288. "doc_preprocessor_res": doc_preprocessor_res,
  289. "dt_polys": dt_polys,
  290. "model_settings": model_settings,
  291. "text_det_params": text_det_params,
  292. "text_type": self.text_type,
  293. "text_rec_score_thresh": text_rec_score_thresh,
  294. }
  295. single_img_res["rec_texts"] = []
  296. single_img_res["rec_scores"] = []
  297. single_img_res["rec_polys"] = []
  298. if len(dt_polys) > 0:
  299. all_subs_of_img = list(
  300. self._crop_by_polys(doc_preprocessor_image, dt_polys)
  301. )
  302. # use textline orientation model
  303. if model_settings["use_textline_orientation"]:
  304. angles = [
  305. int(textline_angle_info["class_ids"][0])
  306. for textline_angle_info in self.textline_orientation_model(
  307. all_subs_of_img
  308. )
  309. ]
  310. all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
  311. else:
  312. angles = [-1] * len(all_subs_of_img)
  313. single_img_res["textline_orientation_angles"] = angles
  314. sub_img_info_list = [
  315. {
  316. "sub_img_id": img_id,
  317. "sub_img_ratio": sub_img.shape[1] / float(sub_img.shape[0]),
  318. }
  319. for img_id, sub_img in enumerate(all_subs_of_img)
  320. ]
  321. sorted_subs_info = sorted(
  322. sub_img_info_list, key=lambda x: x["sub_img_ratio"]
  323. )
  324. sorted_subs_of_img = [
  325. all_subs_of_img[x["sub_img_id"]] for x in sorted_subs_info
  326. ]
  327. for idx, rec_res in enumerate(self.text_rec_model(sorted_subs_of_img)):
  328. sub_img_id = sorted_subs_info[idx]["sub_img_id"]
  329. sub_img_info_list[sub_img_id]["rec_res"] = rec_res
  330. for sno in range(len(sub_img_info_list)):
  331. rec_res = sub_img_info_list[sno]["rec_res"]
  332. if rec_res["rec_score"] >= text_rec_score_thresh:
  333. single_img_res["rec_texts"].append(rec_res["rec_text"])
  334. single_img_res["rec_scores"].append(rec_res["rec_score"])
  335. single_img_res["rec_polys"].append(dt_polys[sno])
  336. if self.text_type == "general":
  337. rec_boxes = convert_points_to_boxes(single_img_res["rec_polys"])
  338. single_img_res["rec_boxes"] = rec_boxes
  339. else:
  340. single_img_res["rec_boxes"] = np.array([])
  341. yield OCRResult(single_img_res)