pipeline.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List, Optional, Union
  15. import numpy as np
  16. from scipy.ndimage import rotate
  17. from ...common.reader import ReadImage
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...utils.pp_option import PaddlePredictorOption
  20. from ..base import BasePipeline
  21. from ..components import (
  22. CropByPolys,
  23. SortQuadBoxes,
  24. SortPolyBoxes,
  25. convert_points_to_boxes,
  26. )
  27. from .result import OCRResult
  28. from ..doc_preprocessor.result import DocPreprocessorResult
  29. from ....utils import logging
  30. class OCRPipeline(BasePipeline):
  31. """OCR Pipeline"""
  32. entities = "OCR"
  33. def __init__(
  34. self,
  35. config: Dict,
  36. device: Optional[str] = None,
  37. pp_option: Optional[PaddlePredictorOption] = None,
  38. use_hpip: bool = False,
  39. ) -> None:
  40. """
  41. Initializes the class with given configurations and options.
  42. Args:
  43. config (Dict): Configuration dictionary containing various settings.
  44. device (str, optional): Device to run the predictions on. Defaults to None.
  45. pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
  46. use_hpip (bool, optional): Whether to use high-performance inference (hpip) for prediction. Defaults to False.
  47. """
  48. super().__init__(device=device, pp_option=pp_option, use_hpip=use_hpip)
  49. self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
  50. if self.use_doc_preprocessor:
  51. doc_preprocessor_config = config.get("SubPipelines", {}).get(
  52. "DocPreprocessor",
  53. {
  54. "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
  55. },
  56. )
  57. self.doc_preprocessor_pipeline = self.create_pipeline(
  58. doc_preprocessor_config
  59. )
  60. self.use_textline_orientation = config.get("use_textline_orientation", True)
  61. if self.use_textline_orientation:
  62. textline_orientation_config = config.get("SubModules", {}).get(
  63. "TextLineOrientation",
  64. {"model_config_error": "config error for textline_orientation_model!"},
  65. )
  66. # TODO: add batch_size
  67. # batch_size = textline_orientation_config.get("batch_size", 1)
  68. # self.textline_orientation_model = self.create_model(
  69. # textline_orientation_config, batch_size=batch_size
  70. # )
  71. self.textline_orientation_model = self.create_model(
  72. textline_orientation_config
  73. )
  74. text_det_config = config.get("SubModules", {}).get(
  75. "TextDetection", {"model_config_error": "config error for text_det_model!"}
  76. )
  77. self.text_type = config["text_type"]
  78. if self.text_type == "general":
  79. self.text_det_limit_side_len = text_det_config.get("limit_side_len", 960)
  80. self.text_det_limit_type = text_det_config.get("limit_type", "max")
  81. self.text_det_thresh = text_det_config.get("thresh", 0.3)
  82. self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
  83. self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 2.0)
  84. self._sort_boxes = SortQuadBoxes()
  85. self._crop_by_polys = CropByPolys(det_box_type="quad")
  86. elif self.text_type == "seal":
  87. self.text_det_limit_side_len = text_det_config.get("limit_side_len", 736)
  88. self.text_det_limit_type = text_det_config.get("limit_type", "min")
  89. self.text_det_thresh = text_det_config.get("thresh", 0.2)
  90. self.text_det_box_thresh = text_det_config.get("box_thresh", 0.6)
  91. self.text_det_unclip_ratio = text_det_config.get("unclip_ratio", 0.5)
  92. self._sort_boxes = SortPolyBoxes()
  93. self._crop_by_polys = CropByPolys(det_box_type="poly")
  94. else:
  95. raise ValueError("Unsupported text type {}".format(self.text_type))
  96. self.text_det_model = self.create_model(
  97. text_det_config,
  98. limit_side_len=self.text_det_limit_side_len,
  99. limit_type=self.text_det_limit_type,
  100. thresh=self.text_det_thresh,
  101. box_thresh=self.text_det_box_thresh,
  102. unclip_ratio=self.text_det_unclip_ratio,
  103. )
  104. text_rec_config = config.get("SubModules", {}).get(
  105. "TextRecognition",
  106. {"model_config_error": "config error for text_rec_model!"},
  107. )
  108. # TODO: add batch_size
  109. # batch_size = text_rec_config.get("batch_size", 1)
  110. # self.text_rec_model = self.create_model(text_rec_config,
  111. # batch_size=batch_size)
  112. self.text_rec_score_thresh = text_rec_config.get("score_thresh", 0)
  113. self.text_rec_model = self.create_model(text_rec_config)
  114. self.batch_sampler = ImageBatchSampler(batch_size=1)
  115. self.img_reader = ReadImage(format="BGR")
  116. def rotate_image(
  117. self, image_array_list: List[np.ndarray], rotate_angle_list: List[int]
  118. ) -> List[np.ndarray]:
  119. """
  120. Rotate the given image arrays by their corresponding angles.
  121. 0 corresponds to 0 degrees, 1 corresponds to 180 degrees.
  122. Args:
  123. image_array_list (List[np.ndarray]): A list of input image arrays to be rotated.
  124. rotate_angle_list (List[int]): A list of rotation indicators (0 or 1).
  125. 0 means rotate by 0 degrees
  126. 1 means rotate by 180 degrees
  127. Returns:
  128. List[np.ndarray]: A list of rotated image arrays.
  129. Raises:
  130. AssertionError: If any rotate_angle is not 0 or 1.
  131. AssertionError: If the lengths of input lists don't match.
  132. """
  133. assert len(image_array_list) == len(
  134. rotate_angle_list
  135. ), f"Length of image_array_list ({len(image_array_list)}) must match length of rotate_angle_list ({len(rotate_angle_list)})"
  136. for angle in rotate_angle_list:
  137. assert angle in [0, 1], f"rotate_angle must be 0 or 1, now it's {angle}"
  138. rotated_images = []
  139. for image_array, rotate_indicator in zip(image_array_list, rotate_angle_list):
  140. # Convert 0/1 indicator to actual rotation angle
  141. rotate_angle = rotate_indicator * 180
  142. rotated_image = rotate(image_array, rotate_angle, reshape=True)
  143. rotated_images.append(rotated_image)
  144. return rotated_images
  145. def check_model_settings_valid(self, model_settings: Dict) -> bool:
  146. """
  147. Check if the input parameters are valid based on the initialized models.
  148. Args:
  149. model_info_params(Dict): A dictionary containing input parameters.
  150. Returns:
  151. bool: True if all required models are initialized according to input parameters, False otherwise.
  152. """
  153. if model_settings["use_doc_preprocessor"] and not self.use_doc_preprocessor:
  154. logging.error(
  155. "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized."
  156. )
  157. return False
  158. if (
  159. model_settings["use_textline_orientation"]
  160. and not self.use_textline_orientation
  161. ):
  162. logging.error(
  163. "Set use_textline_orientation, but the models for use_textline_orientation are not initialized."
  164. )
  165. return False
  166. return True
  167. def get_model_settings(
  168. self,
  169. use_doc_orientation_classify: Optional[bool],
  170. use_doc_unwarping: Optional[bool],
  171. use_textline_orientation: Optional[bool],
  172. ) -> dict:
  173. """
  174. Get the model settings based on the provided parameters or default values.
  175. Args:
  176. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  177. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  178. use_textline_orientation (Optional[bool]): Whether to use textline orientation.
  179. Returns:
  180. dict: A dictionary containing the model settings.
  181. """
  182. if use_doc_orientation_classify is None and use_doc_unwarping is None:
  183. use_doc_preprocessor = self.use_doc_preprocessor
  184. else:
  185. if use_doc_orientation_classify is True or use_doc_unwarping is True:
  186. use_doc_preprocessor = True
  187. else:
  188. use_doc_preprocessor = False
  189. if use_textline_orientation is None:
  190. use_textline_orientation = self.use_textline_orientation
  191. return dict(
  192. use_doc_preprocessor=use_doc_preprocessor,
  193. use_textline_orientation=use_textline_orientation,
  194. )
  195. def get_text_det_params(
  196. self,
  197. text_det_limit_side_len: Optional[int] = None,
  198. text_det_limit_type: Optional[str] = None,
  199. text_det_thresh: Optional[float] = None,
  200. text_det_box_thresh: Optional[float] = None,
  201. text_det_unclip_ratio: Optional[float] = None,
  202. ) -> dict:
  203. """
  204. Get text detection parameters.
  205. If a parameter is None, its default value from the instance will be used.
  206. Args:
  207. text_det_limit_side_len (Optional[int]): The maximum side length of the text box.
  208. text_det_limit_type (Optional[str]): The type of limit to apply to the text box.
  209. text_det_thresh (Optional[float]): The threshold for text detection.
  210. text_det_box_thresh (Optional[float]): The threshold for the bounding box.
  211. text_det_unclip_ratio (Optional[float]): The ratio for unclipping the text box.
  212. Returns:
  213. dict: A dictionary containing the text detection parameters.
  214. """
  215. if text_det_limit_side_len is None:
  216. text_det_limit_side_len = self.text_det_limit_side_len
  217. if text_det_limit_type is None:
  218. text_det_limit_type = self.text_det_limit_type
  219. if text_det_thresh is None:
  220. text_det_thresh = self.text_det_thresh
  221. if text_det_box_thresh is None:
  222. text_det_box_thresh = self.text_det_box_thresh
  223. if text_det_unclip_ratio is None:
  224. text_det_unclip_ratio = self.text_det_unclip_ratio
  225. return dict(
  226. limit_side_len=text_det_limit_side_len,
  227. limit_type=text_det_limit_type,
  228. thresh=text_det_thresh,
  229. box_thresh=text_det_box_thresh,
  230. unclip_ratio=text_det_unclip_ratio,
  231. )
  232. def predict(
  233. self,
  234. input: Union[str, List[str], np.ndarray, List[np.ndarray]],
  235. use_doc_orientation_classify: Optional[bool] = None,
  236. use_doc_unwarping: Optional[bool] = None,
  237. use_textline_orientation: Optional[bool] = None,
  238. text_det_limit_side_len: Optional[int] = None,
  239. text_det_limit_type: Optional[str] = None,
  240. text_det_thresh: Optional[float] = None,
  241. text_det_box_thresh: Optional[float] = None,
  242. text_det_unclip_ratio: Optional[float] = None,
  243. text_rec_score_thresh: Optional[float] = None,
  244. ) -> OCRResult:
  245. """
  246. Predict OCR results based on input images or arrays with optional preprocessing steps.
  247. Args:
  248. input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image of pdf path(s) or numpy array(s).
  249. use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
  250. use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
  251. use_textline_orientation (Optional[bool]): Whether to use textline orientation prediction.
  252. text_det_limit_side_len (Optional[int]): Maximum side length for text detection.
  253. text_det_limit_type (Optional[str]): Type of limit to apply for text detection.
  254. text_det_thresh (Optional[float]): Threshold for text detection.
  255. text_det_box_thresh (Optional[float]): Threshold for text detection boxes.
  256. text_det_unclip_ratio (Optional[float]): Ratio for unclipping text detection boxes.
  257. text_rec_score_thresh (Optional[float]): Score threshold for text recognition.
  258. Returns:
  259. OCRResult: Generator yielding OCR results for each input image.
  260. """
  261. model_settings = self.get_model_settings(
  262. use_doc_orientation_classify, use_doc_unwarping, use_textline_orientation
  263. )
  264. if not self.check_model_settings_valid(model_settings):
  265. yield {"error": "the input params for model settings are invalid!"}
  266. text_det_params = self.get_text_det_params(
  267. text_det_limit_side_len,
  268. text_det_limit_type,
  269. text_det_thresh,
  270. text_det_box_thresh,
  271. text_det_unclip_ratio,
  272. )
  273. if text_rec_score_thresh is None:
  274. text_rec_score_thresh = self.text_rec_score_thresh
  275. for img_id, batch_data in enumerate(self.batch_sampler(input)):
  276. if not isinstance(batch_data[0], str):
  277. # TODO: add support input_pth for ndarray and pdf
  278. input_path = f"{img_id}.jpg"
  279. else:
  280. input_path = batch_data[0]
  281. image_array = self.img_reader(batch_data)[0]
  282. if model_settings["use_doc_preprocessor"]:
  283. doc_preprocessor_res = next(
  284. self.doc_preprocessor_pipeline(
  285. image_array,
  286. use_doc_orientation_classify=use_doc_orientation_classify,
  287. use_doc_unwarping=use_doc_unwarping,
  288. )
  289. )
  290. else:
  291. doc_preprocessor_res = {"output_img": image_array}
  292. doc_preprocessor_image = doc_preprocessor_res["output_img"]
  293. det_res = next(
  294. self.text_det_model(doc_preprocessor_image, **text_det_params)
  295. )
  296. dt_polys = det_res["dt_polys"]
  297. dt_scores = det_res["dt_scores"]
  298. dt_polys = self._sort_boxes(dt_polys)
  299. single_img_res = {
  300. "input_path": input_path,
  301. "doc_preprocessor_res": doc_preprocessor_res,
  302. "dt_polys": dt_polys,
  303. "model_settings": model_settings,
  304. "text_det_params": text_det_params,
  305. "text_type": self.text_type,
  306. "text_rec_score_thresh": text_rec_score_thresh,
  307. }
  308. single_img_res["rec_texts"] = []
  309. single_img_res["rec_scores"] = []
  310. single_img_res["rec_polys"] = []
  311. if len(dt_polys) > 0:
  312. all_subs_of_img = list(
  313. self._crop_by_polys(doc_preprocessor_image, dt_polys)
  314. )
  315. # use textline orientation model
  316. if model_settings["use_textline_orientation"]:
  317. angles = [
  318. int(textline_angle_info["class_ids"][0])
  319. for textline_angle_info in self.textline_orientation_model(
  320. all_subs_of_img
  321. )
  322. ]
  323. all_subs_of_img = self.rotate_image(all_subs_of_img, angles)
  324. else:
  325. angles = [-1] * len(all_subs_of_img)
  326. single_img_res["textline_orientation_angles"] = angles
  327. sub_img_info_list = [
  328. {
  329. "sub_img_id": img_id,
  330. "sub_img_ratio": sub_img.shape[1] / float(sub_img.shape[0]),
  331. }
  332. for img_id, sub_img in enumerate(all_subs_of_img)
  333. ]
  334. sorted_subs_info = sorted(
  335. sub_img_info_list, key=lambda x: x["sub_img_ratio"]
  336. )
  337. sorted_subs_of_img = [
  338. all_subs_of_img[x["sub_img_id"]] for x in sorted_subs_info
  339. ]
  340. for idx, rec_res in enumerate(self.text_rec_model(sorted_subs_of_img)):
  341. sub_img_id = sorted_subs_info[idx]["sub_img_id"]
  342. sub_img_info_list[sub_img_id]["rec_res"] = rec_res
  343. for sno in range(len(sub_img_info_list)):
  344. rec_res = sub_img_info_list[sno]["rec_res"]
  345. if rec_res["rec_score"] >= text_rec_score_thresh:
  346. single_img_res["rec_texts"].append(rec_res["rec_text"])
  347. single_img_res["rec_scores"].append(rec_res["rec_score"])
  348. single_img_res["rec_polys"].append(dt_polys[sno])
  349. if self.text_type == "general":
  350. rec_boxes = convert_points_to_boxes(single_img_res["rec_polys"])
  351. single_img_res["rec_boxes"] = rec_boxes
  352. else:
  353. single_img_res["rec_boxes"] = np.array([])
  354. yield OCRResult(single_img_res)