layout_parsing.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ...results import *
  16. from ...components import *
  17. from ..ocr import OCRPipeline
  18. from ....utils import logging
  19. from ..ppchatocrv3.utils import *
  20. from ..table_recognition import _TableRecPipeline
  21. from ..table_recognition.utils import convert_4point2rect, get_ori_coordinate_for_table
  22. class LayoutParsingPipeline(_TableRecPipeline):
  23. """Layout Analysis Pileline"""
  24. entities = "layout_parsing"
  25. def __init__(
  26. self,
  27. layout_model,
  28. text_det_model,
  29. text_rec_model,
  30. table_model,
  31. formula_rec_model,
  32. doc_image_ori_cls_model=None,
  33. doc_image_unwarp_model=None,
  34. seal_text_det_model=None,
  35. layout_batch_size=1,
  36. text_det_batch_size=1,
  37. text_rec_batch_size=1,
  38. table_batch_size=1,
  39. doc_image_ori_cls_batch_size=1,
  40. doc_image_unwarp_batch_size=1,
  41. seal_text_det_batch_size=1,
  42. formula_rec_batch_size=1,
  43. recovery=True,
  44. device=None,
  45. predictor_kwargs=None,
  46. ):
  47. super().__init__(
  48. device,
  49. predictor_kwargs,
  50. )
  51. self._build_predictor(
  52. layout_model=layout_model,
  53. text_det_model=text_det_model,
  54. text_rec_model=text_rec_model,
  55. table_model=table_model,
  56. doc_image_ori_cls_model=doc_image_ori_cls_model,
  57. doc_image_unwarp_model=doc_image_unwarp_model,
  58. seal_text_det_model=seal_text_det_model,
  59. formula_rec_model=formula_rec_model,
  60. )
  61. self.set_predictor(
  62. layout_batch_size=layout_batch_size,
  63. text_det_batch_size=text_det_batch_size,
  64. text_rec_batch_size=text_rec_batch_size,
  65. table_batch_size=table_batch_size,
  66. doc_image_ori_cls_batch_size=doc_image_ori_cls_batch_size,
  67. doc_image_unwarp_batch_size=doc_image_unwarp_batch_size,
  68. seal_text_det_batch_size=seal_text_det_batch_size,
  69. formula_rec_batch_size=formula_rec_batch_size,
  70. )
  71. self.recovery = recovery
  72. def _build_predictor(
  73. self,
  74. layout_model,
  75. text_det_model,
  76. text_rec_model,
  77. table_model,
  78. formula_rec_model,
  79. seal_text_det_model=None,
  80. doc_image_ori_cls_model=None,
  81. doc_image_unwarp_model=None,
  82. ):
  83. super()._build_predictor(
  84. layout_model, text_det_model, text_rec_model, table_model
  85. )
  86. self.formula_predictor = self._create(formula_rec_model)
  87. if seal_text_det_model:
  88. self.curve_pipeline = self._create(
  89. pipeline=OCRPipeline,
  90. text_det_model=seal_text_det_model,
  91. text_rec_model=text_rec_model,
  92. )
  93. else:
  94. self.curve_pipeline = None
  95. if doc_image_ori_cls_model:
  96. self.oricls_predictor = self._create(doc_image_ori_cls_model)
  97. else:
  98. self.oricls_predictor = None
  99. if doc_image_unwarp_model:
  100. self.uvdoc_predictor = self._create(doc_image_unwarp_model)
  101. else:
  102. self.uvdoc_predictor = None
  103. self.img_reader = ReadImage(format="BGR")
  104. self.cropper = CropByBoxes()
  105. def set_predictor(
  106. self,
  107. layout_batch_size=None,
  108. text_det_batch_size=None,
  109. text_rec_batch_size=None,
  110. table_batch_size=None,
  111. doc_image_ori_cls_batch_size=None,
  112. doc_image_unwarp_batch_size=None,
  113. seal_text_det_batch_size=None,
  114. formula_rec_batch_size=None,
  115. device=None,
  116. ):
  117. if text_det_batch_size and text_det_batch_size > 1:
  118. logging.warning(
  119. f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
  120. )
  121. if layout_batch_size:
  122. self.layout_predictor.set_predictor(batch_size=layout_batch_size)
  123. if text_rec_batch_size:
  124. self.ocr_pipeline.text_rec_model.set_predictor(
  125. batch_size=text_rec_batch_size
  126. )
  127. if table_batch_size:
  128. self.table_predictor.set_predictor(batch_size=table_batch_size)
  129. if formula_rec_batch_size:
  130. self.formula_predictor.set_predictor(batch_size=formula_rec_batch_size)
  131. if self.curve_pipeline and seal_text_det_batch_size:
  132. self.curve_pipeline.text_det_model.set_predictor(
  133. batch_size=seal_text_det_batch_size
  134. )
  135. if self.oricls_predictor and doc_image_ori_cls_batch_size:
  136. self.oricls_predictor.set_predictor(batch_size=doc_image_ori_cls_batch_size)
  137. if self.uvdoc_predictor and doc_image_unwarp_batch_size:
  138. self.uvdoc_predictor.set_predictor(batch_size=doc_image_unwarp_batch_size)
  139. if device:
  140. if self.curve_pipeline:
  141. self.curve_pipeline.set_predictor(device=device)
  142. if self.oricls_predictor:
  143. self.oricls_predictor.set_predictor(device=device)
  144. if self.uvdoc_predictor:
  145. self.uvdoc_predictor.set_predictor(device=device)
  146. self.layout_predictor.set_predictor(device=device)
  147. self.ocr_pipeline.set_predictor(device=device)
  148. def predict(
  149. self,
  150. inputs,
  151. use_doc_image_ori_cls_model=True,
  152. use_doc_image_unwarp_model=True,
  153. use_seal_text_det_model=True,
  154. recovery=True,
  155. **kwargs,
  156. ):
  157. self.set_predictor(**kwargs)
  158. # get oricls and uvdoc results
  159. img_info_list = list(self.img_reader(inputs))[0]
  160. oricls_results = []
  161. if self.oricls_predictor and use_doc_image_ori_cls_model:
  162. oricls_results = get_oriclas_results(img_info_list, self.oricls_predictor)
  163. unwarp_result = []
  164. if self.uvdoc_predictor and use_doc_image_unwarp_model:
  165. unwarp_result = get_unwarp_results(img_info_list, self.uvdoc_predictor)
  166. img_list = [img_info["img"] for img_info in img_info_list]
  167. for idx, (img_info, layout_pred) in enumerate(
  168. zip(img_info_list, self.layout_predictor(img_list))
  169. ):
  170. page_id = idx
  171. single_img_res = {
  172. "input_path": "",
  173. "layout_result": DetResult({}),
  174. "ocr_result": OCRResult({}),
  175. "table_ocr_result": [],
  176. "table_result": StructureTableResult([]),
  177. "layout_parsing_result": {},
  178. "oricls_result": TopkResult({}),
  179. "formula_result": TextRecResult({}),
  180. "unwarp_result": DocTrResult({}),
  181. "curve_result": [],
  182. }
  183. # update oricls and uvdoc result
  184. if oricls_results:
  185. single_img_res["oricls_result"] = oricls_results[idx]
  186. if unwarp_result:
  187. single_img_res["unwarp_result"] = unwarp_result[idx]
  188. # update layout result
  189. single_img_res["input_path"] = layout_pred["input_path"]
  190. single_img_res["layout_result"] = layout_pred
  191. single_img = img_info["img"]
  192. table_subs = []
  193. curve_subs = []
  194. formula_subs = []
  195. structure_res = []
  196. ocr_res_with_layout = []
  197. if len(layout_pred["boxes"]) > 0:
  198. subs_of_img = list(self._crop_by_boxes(layout_pred))
  199. # get cropped images
  200. for sub in subs_of_img:
  201. box = sub["box"]
  202. xmin, ymin, xmax, ymax = [int(i) for i in box]
  203. mask_flag = True
  204. if sub["label"].lower() == "table":
  205. table_subs.append(sub)
  206. elif sub["label"].lower() == "seal":
  207. curve_subs.append(sub)
  208. elif sub["label"].lower() == "formula":
  209. formula_subs.append(sub)
  210. else:
  211. if self.recovery and recovery:
  212. # TODO: Why use the entire image?
  213. wht_im = (
  214. np.ones(single_img.shape, dtype=single_img.dtype) * 255
  215. )
  216. wht_im[ymin:ymax, xmin:xmax, :] = sub["img"]
  217. sub_ocr_res = get_ocr_res(self.ocr_pipeline, wht_im)
  218. else:
  219. sub_ocr_res = get_ocr_res(self.ocr_pipeline, sub)
  220. sub_ocr_res["dt_polys"] = get_ori_coordinate_for_table(
  221. xmin, ymin, sub_ocr_res["dt_polys"]
  222. )
  223. layout_label = sub["label"].lower()
  224. # Adapt the user label definition to specify behavior.
  225. if sub_ocr_res and sub["label"].lower() in [
  226. "image",
  227. "figure",
  228. "img",
  229. "fig",
  230. ]:
  231. get_text_in_image = kwargs.get("get_text_in_image", False)
  232. mask_flag = not get_text_in_image
  233. text_in_image = ""
  234. if get_text_in_image:
  235. text_in_image = "".join(sub_ocr_res["rec_text"])
  236. ocr_res_with_layout.append(sub_ocr_res)
  237. structure_res.append(
  238. {
  239. "input_path": sub_ocr_res["input_path"],
  240. "layout_bbox": box,
  241. f"{layout_label}": {
  242. "img": sub["img"],
  243. f"{layout_label}_text": text_in_image,
  244. },
  245. }
  246. )
  247. else:
  248. ocr_res_with_layout.append(sub_ocr_res)
  249. structure_res.append(
  250. {
  251. "input_path": sub_ocr_res["input_path"],
  252. "layout_bbox": box,
  253. f"{layout_label}": "\n".join(
  254. sub_ocr_res["rec_text"]
  255. ),
  256. }
  257. )
  258. if mask_flag:
  259. single_img[ymin:ymax, xmin:xmax, :] = 255
  260. curve_pipeline = self.ocr_pipeline
  261. if self.curve_pipeline and use_seal_text_det_model:
  262. curve_pipeline = self.curve_pipeline
  263. all_curve_res = get_ocr_res(curve_pipeline, curve_subs)
  264. single_img_res["curve_result"] = all_curve_res
  265. if isinstance(all_curve_res, dict):
  266. all_curve_res = [all_curve_res]
  267. for sub, curve_res in zip(curve_subs, all_curve_res):
  268. structure_res.append(
  269. {
  270. "input_path": curve_res["input_path"],
  271. "layout_bbox": sub["box"],
  272. "seal": "".join(curve_res["rec_text"]),
  273. }
  274. )
  275. all_formula_res = get_formula_res(self.formula_predictor, formula_subs)
  276. single_img_res["formula_result"] = all_formula_res
  277. for sub, formula_res in zip(formula_subs, all_formula_res):
  278. structure_res.append(
  279. {
  280. "input_path": formula_res["input_path"],
  281. "layout_bbox": sub["box"],
  282. "formula": "".join(formula_res["rec_text"]),
  283. }
  284. )
  285. use_ocr_without_layout = kwargs.get("use_ocr_without_layout", True)
  286. ocr_res = {
  287. "dt_polys": [],
  288. "rec_text": [],
  289. "input_path": layout_pred["input_path"],
  290. }
  291. if use_ocr_without_layout:
  292. ocr_res = get_ocr_res(self.ocr_pipeline, single_img)
  293. ocr_res["input_path"] = layout_pred["input_path"]
  294. for idx, single_dt_poly in enumerate(ocr_res["dt_polys"]):
  295. structure_res.append(
  296. {
  297. "input_path": ocr_res["input_path"],
  298. "layout_bbox": convert_4point2rect(single_dt_poly),
  299. "text_without_layout": ocr_res["rec_text"][idx],
  300. }
  301. )
  302. # update ocr result
  303. for layout_ocr_res in ocr_res_with_layout:
  304. ocr_res["dt_polys"].extend(layout_ocr_res["dt_polys"])
  305. ocr_res["rec_text"].extend(layout_ocr_res["rec_text"])
  306. ocr_res["rec_score"].extend(layout_ocr_res["rec_score"])
  307. ocr_res["input_path"] = single_img_res["input_path"]
  308. all_table_ocr_res = []
  309. all_table_res, _ = self.get_table_result(table_subs)
  310. # get table text from html
  311. structure_res_table, all_table_ocr_res = get_table_text_from_html(
  312. all_table_res
  313. )
  314. structure_res.extend(structure_res_table)
  315. # sort the layout result by the left top point of the box
  316. structure_res = sorted_layout_boxes(structure_res, w=single_img.shape[1])
  317. structure_res = LayoutParsingResult(
  318. {
  319. "input_path": layout_pred["input_path"],
  320. "parsing_result": structure_res,
  321. }
  322. )
  323. single_img_res["table_result"] = all_table_res
  324. single_img_res["ocr_result"] = ocr_res
  325. single_img_res["table_ocr_result"] = all_table_ocr_res
  326. single_img_res["layout_parsing_result"] = structure_res
  327. single_img_res["layout_parsing_result"]["page_id"] = page_id + 1
  328. yield VisualResult(single_img_res, page_id, inputs)
  329. def get_formula_res(predictor, input):
  330. """get formula res"""
  331. res_list = []
  332. if isinstance(input, list):
  333. img = [im["img"] for im in input]
  334. elif isinstance(input, dict):
  335. img = input["img"]
  336. else:
  337. img = input
  338. for res in predictor(img):
  339. res_list.append(res)
  340. return res_list