# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import copy
from functools import partial

import numpy as np
from PIL import Image, ImageDraw, ImageFont

from ....utils.fonts import PINGFANG_FONT
from ...common.result import (
    BaseCVResult,
    HtmlMixin,
    JsonMixin,
    MarkdownMixin,
    XlsxMixin,
)
from ..layout_parsing.result_v2 import (
    format_centered_by_html,
    format_first_line_func,
    format_image_plain_func,
    format_image_scaled_by_html_func,
    format_text_plain_func,
    format_title_func,
    simplify_table_func,
)
  35. VISUALIZE_INDEX_LABELS = [
  36. "text",
  37. "formula",
  38. "inline_formula",
  39. "display_formula",
  40. "algorithm",
  41. "reference",
  42. "reference_content",
  43. "content",
  44. "abstract",
  45. "paragraph_title",
  46. "doc_title",
  47. "vertical_text",
  48. "ocr",
  49. ]
  50. class PaddleOCRVLBlock(object):
  51. """PaddleOCRVL Block Class"""
  52. def __init__(self, label, bbox, content="") -> None:
  53. """
  54. Initialize a PaddleOCRVLBlock object.
  55. Args:
  56. label (str): Label assigned to the block.
  57. bbox (list): Bounding box coordinates of the block.
  58. content (str, optional): Content of the block. Defaults to an empty string.
  59. """
  60. self.label = label
  61. self.bbox = list(map(int, bbox))
  62. self.content = content
  63. self.image = None
  64. def __str__(self) -> str:
  65. """
  66. Return a string representation of the block.
  67. """
  68. _str = f"\n\n#################\nlabel:\t{self.label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
  69. return _str
  70. def __repr__(self) -> str:
  71. """
  72. Return a string representation of the block.
  73. """
  74. _str = f"\n\n#################\nlabel:\t{self.label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
  75. return _str
  76. def merge_formula_and_number(formula, formula_number):
  77. """
  78. Merge a formula and its formula number for display.
  79. Args:
  80. formula (str): The formula string.
  81. formula_number (str): The formula number string.
  82. Returns:
  83. str: The merged formula with tag.
  84. """
  85. formula = formula.replace("$$", "")
  86. merge_formula = r"{} \tag*{{{}}}".format(formula, formula_number)
  87. return f"$${merge_formula}$$"
  88. def format_chart2table_func(block):
  89. lines_list = block.content.split("\n")
  90. # 提取表头和内容
  91. header = lines_list[0].split("|")
  92. rows = [line.split("|") for line in lines_list[1:]]
  93. # 构造HTML表格
  94. html = "<table border=1 style='margin: auto; width: max-content;'>\n"
  95. html += (
  96. " <thead><tr>"
  97. + "".join(
  98. f"<th style='text-align: center;'>{cell.strip()}</th>" for cell in header
  99. )
  100. + "</tr></thead>\n"
  101. )
  102. html += " <tbody>\n"
  103. for row in rows:
  104. html += (
  105. " <tr>"
  106. + "".join(
  107. f"<td style='text-align: center;'>{cell.strip()}</td>" for cell in row
  108. )
  109. + "</tr>\n"
  110. )
  111. html += " </tbody>\n"
  112. html += "</table>"
  113. return html
  114. def format_table_center_func(block):
  115. tabel_content = block.content
  116. tabel_content = tabel_content.replace(
  117. "<table>", "<table border=1 style='margin: auto; width: max-content;'>"
  118. )
  119. tabel_content = tabel_content.replace("<th>", "<th style='text-align: center;'>")
  120. tabel_content = tabel_content.replace("<td>", "<td style='text-align: center;'>")
  121. return tabel_content
  122. def build_handle_funcs_dict(
  123. *,
  124. text_func,
  125. image_func,
  126. chart_func,
  127. table_func,
  128. formula_func,
  129. seal_func,
  130. ):
  131. """
  132. Build a dictionary mapping block labels to their formatting functions.
  133. Args:
  134. text_func: Function to format text blocks.
  135. image_func: Function to format image blocks.
  136. chart_func: Function to format chart blocks.
  137. table_func: Function to format table blocks.
  138. formula_func: Function to format formula blocks.
  139. seal_func: Function to format seal blocks.
  140. Returns:
  141. dict: A mapping from block label to handler function.
  142. """
  143. return {
  144. "paragraph_title": format_title_func,
  145. "abstract_title": format_title_func,
  146. "reference_title": format_title_func,
  147. "content_title": format_title_func,
  148. "doc_title": lambda block: f"# {block.content}".replace("-\n", "").replace(
  149. "\n", " "
  150. ),
  151. "table_title": text_func,
  152. "figure_title": text_func,
  153. "chart_title": text_func,
  154. "vision_footnote": lambda block: block.content.replace("\n\n", "\n").replace(
  155. "\n", "\n\n"
  156. ),
  157. "text": lambda block: block.content.replace("\n\n", "\n").replace("\n", "\n\n"),
  158. "ocr": lambda block: block.content.replace("\n\n", "\n").replace("\n", "\n\n"),
  159. "vertical_text": lambda block: block.content.replace("\n\n", "\n").replace(
  160. "\n", "\n\n"
  161. ),
  162. "reference_content": lambda block: block.content.replace("\n\n", "\n").replace(
  163. "\n", "\n\n"
  164. ),
  165. "abstract": partial(
  166. format_first_line_func,
  167. templates=["摘要", "abstract"],
  168. format_func=lambda l: f"## {l}\n",
  169. spliter=" ",
  170. ),
  171. "content": lambda block: block.content.replace("-\n", " \n").replace(
  172. "\n", " \n"
  173. ),
  174. "image": image_func,
  175. "chart": chart_func,
  176. "formula": formula_func,
  177. "display_formula": formula_func,
  178. "inline_formula": formula_func,
  179. "table": table_func,
  180. "reference": partial(
  181. format_first_line_func,
  182. templates=["参考文献", "references"],
  183. format_func=lambda l: f"## {l}",
  184. spliter="\n",
  185. ),
  186. "algorithm": lambda block: block.content.strip("\n"),
  187. "seal": seal_func,
  188. }
  189. class PaddleOCRVLResult(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
  190. """
  191. PaddleOCRVLResult class for holding and formatting OCR/VL parsing results.
  192. """
  193. def __init__(self, data) -> None:
  194. """
  195. Initializes a new instance of the class with the specified data.
  196. Args:
  197. data: The input data for the parsing result.
  198. """
  199. super().__init__(data)
  200. HtmlMixin.__init__(self)
  201. XlsxMixin.__init__(self)
  202. MarkdownMixin.__init__(self)
  203. JsonMixin.__init__(self)
  204. def _to_img(self) -> dict[str, np.ndarray]:
  205. """
  206. Convert the parsing result to a dictionary of images.
  207. Returns:
  208. dict: Keys are names, values are numpy arrays (images).
  209. """
  210. from ..layout_parsing.utils import get_show_color
  211. res_img_dict = {}
  212. model_settings = self["model_settings"]
  213. if model_settings["use_doc_preprocessor"]:
  214. for key, value in self["doc_preprocessor_res"].img.items():
  215. res_img_dict[key] = value
  216. if self["model_settings"]["use_layout_detection"]:
  217. res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
  218. # for layout ordering image
  219. image = Image.fromarray(self["doc_preprocessor_res"]["output_img"][:, :, ::-1])
  220. draw = ImageDraw.Draw(image, "RGBA")
  221. font_size = int(0.018 * int(image.width)) + 2
  222. font = ImageFont.truetype(PINGFANG_FONT.path, font_size, encoding="utf-8")
  223. parsing_result = self["parsing_res_list"]
  224. order_index = 0
  225. for block in parsing_result:
  226. bbox = block.bbox
  227. label = block.label
  228. fill_color = get_show_color(label, False)
  229. draw.rectangle(bbox, fill=fill_color)
  230. if label in VISUALIZE_INDEX_LABELS:
  231. text_position = (bbox[2] + 2, bbox[1] - font_size // 2)
  232. if int(image.width) - bbox[2] < font_size:
  233. text_position = (
  234. int(bbox[2] - font_size * 1.1),
  235. bbox[1] - font_size // 2,
  236. )
  237. draw.text(text_position, str(order_index + 1), font=font, fill="red")
  238. order_index += 1
  239. res_img_dict["layout_order_res"] = image
  240. return res_img_dict
  241. def _to_html(self) -> dict[str, str]:
  242. """
  243. Converts the prediction to its corresponding HTML representation.
  244. Returns:
  245. dict: The str type HTML representation result.
  246. """
  247. res_html_dict = {}
  248. if len(self["table_res_list"]) > 0:
  249. for sno in range(len(self["table_res_list"])):
  250. table_res = self["table_res_list"][sno]
  251. table_region_id = table_res["table_region_id"]
  252. key = f"table_{table_region_id}"
  253. res_html_dict[key] = table_res.html["pred"]
  254. return res_html_dict
  255. def _to_xlsx(self) -> dict[str, str]:
  256. """
  257. Converts the prediction HTML to an XLSX file path.
  258. Returns:
  259. dict: The str type XLSX representation result.
  260. """
  261. res_xlsx_dict = {}
  262. if len(self["table_res_list"]) > 0:
  263. for sno in range(len(self["table_res_list"])):
  264. table_res = self["table_res_list"][sno]
  265. table_region_id = table_res["table_region_id"]
  266. key = f"table_{table_region_id}"
  267. res_xlsx_dict[key] = table_res.xlsx["pred"]
  268. return res_xlsx_dict
  269. def _to_str(self, *args, **kwargs) -> dict[str, str]:
  270. """
  271. Converts the instance's attributes to a dictionary and then to a string.
  272. Args:
  273. *args: Additional positional arguments passed to the base class method.
  274. **kwargs: Additional keyword arguments passed to the base class method.
  275. Returns:
  276. dict: A dictionary with the instance's attributes converted to strings.
  277. """
  278. data = {}
  279. data["input_path"] = self["input_path"]
  280. data["page_index"] = self["page_index"]
  281. model_settings = self["model_settings"]
  282. data["model_settings"] = model_settings
  283. if self["model_settings"]["use_doc_preprocessor"]:
  284. data["doc_preprocessor_res"] = self["doc_preprocessor_res"].str["res"]
  285. if self["model_settings"]["use_layout_detection"]:
  286. data["layout_det_res"] = self["layout_det_res"].str["res"]
  287. parsing_res_list = self["parsing_res_list"]
  288. parsing_res_list = [
  289. {
  290. "block_label": parsing_res.label,
  291. "block_content": parsing_res.content,
  292. "block_bbox": parsing_res.bbox,
  293. }
  294. for parsing_res in parsing_res_list
  295. ]
  296. data["parsing_res_list"] = parsing_res_list
  297. return JsonMixin._to_str(data, *args, **kwargs)
  298. def _to_json(self, *args, **kwargs) -> dict[str, str]:
  299. """
  300. Converts the object's data to a JSON dictionary.
  301. Args:
  302. *args: Positional arguments passed to the JsonMixin._to_json method.
  303. **kwargs: Keyword arguments passed to the JsonMixin._to_json method.
  304. Returns:
  305. dict: A dictionary containing the object's data in JSON format.
  306. """
  307. data = {}
  308. data["input_path"] = self["input_path"]
  309. data["page_index"] = self["page_index"]
  310. model_settings = self["model_settings"]
  311. data["model_settings"] = model_settings
  312. if self["model_settings"].get("format_block_content", False):
  313. original_image_width = self["doc_preprocessor_res"]["output_img"].shape[1]
  314. format_text_func = lambda block: format_centered_by_html(
  315. format_text_plain_func(block)
  316. )
  317. format_image_func = lambda block: format_centered_by_html(
  318. format_image_scaled_by_html_func(
  319. block,
  320. original_image_width=original_image_width,
  321. )
  322. )
  323. if self["model_settings"].get("use_chart_recognition", False):
  324. format_chart_func = format_chart2table_func
  325. else:
  326. format_chart_func = format_image_func
  327. format_seal_func = format_image_func
  328. format_table_func = lambda block: "\n" + format_table_center_func(block)
  329. format_formula_func = lambda block: block.content
  330. handle_funcs_dict = build_handle_funcs_dict(
  331. text_func=format_text_func,
  332. image_func=format_image_func,
  333. chart_func=format_chart_func,
  334. table_func=format_table_func,
  335. formula_func=format_formula_func,
  336. seal_func=format_seal_func,
  337. )
  338. parsing_res_list = self["parsing_res_list"]
  339. parsing_res_list_json = []
  340. order_index = 1
  341. for idx, parsing_res in enumerate(parsing_res_list):
  342. label = parsing_res.label
  343. if label in VISUALIZE_INDEX_LABELS:
  344. order = order_index
  345. order_index += 1
  346. else:
  347. order = None
  348. res_dict = {
  349. "block_label": parsing_res.label,
  350. "block_content": parsing_res.content,
  351. "block_bbox": parsing_res.bbox,
  352. "block_id": idx,
  353. "block_order": order,
  354. }
  355. if self["model_settings"].get("format_block_content", False):
  356. if handle_funcs_dict.get(parsing_res.label):
  357. res_dict["block_content"] = handle_funcs_dict[parsing_res.label](
  358. parsing_res
  359. )
  360. else:
  361. res_dict["block_content"] = parsing_res.content
  362. parsing_res_list_json.append(res_dict)
  363. data["parsing_res_list"] = parsing_res_list_json
  364. if self["model_settings"]["use_doc_preprocessor"]:
  365. data["doc_preprocessor_res"] = self["doc_preprocessor_res"].json["res"]
  366. if self["model_settings"]["use_layout_detection"]:
  367. data["layout_det_res"] = self["layout_det_res"].json["res"]
  368. return JsonMixin._to_json(data, *args, **kwargs)
  369. def _to_markdown(self, pretty=True, show_formula_number=False) -> dict:
  370. """
  371. Save the parsing result to a Markdown file.
  372. Args:
  373. pretty (Optional[bool]): whether to pretty markdown by HTML, default by True.
  374. show_formula_number (bool): whether to show formula numbers.
  375. Returns:
  376. dict: Markdown information with text and images.
  377. """
  378. original_image_width = self["doc_preprocessor_res"]["output_img"].shape[1]
  379. if pretty:
  380. format_text_func = lambda block: format_centered_by_html(
  381. format_text_plain_func(block)
  382. )
  383. format_image_func = lambda block: format_centered_by_html(
  384. format_image_scaled_by_html_func(
  385. block,
  386. original_image_width=original_image_width,
  387. )
  388. )
  389. else:
  390. format_text_func = lambda block: block.content
  391. format_image_func = format_image_plain_func
  392. format_chart_func = (
  393. format_chart2table_func
  394. if self["model_settings"]["use_chart_recognition"]
  395. else format_image_func
  396. )
  397. if pretty:
  398. format_table_func = lambda block: "\n" + format_table_center_func(block)
  399. else:
  400. format_table_func = lambda block: simplify_table_func("\n" + block.content)
  401. format_formula_func = lambda block: block.content
  402. format_seal_func = format_image_func
  403. handle_funcs_dict = build_handle_funcs_dict(
  404. text_func=format_text_func,
  405. image_func=format_image_func,
  406. chart_func=format_chart_func,
  407. table_func=format_table_func,
  408. formula_func=format_formula_func,
  409. seal_func=format_seal_func,
  410. )
  411. markdown_content = ""
  412. markdown_info = {}
  413. markdown_info["markdown_images"] = {}
  414. for idx, block in enumerate(self["parsing_res_list"]):
  415. label = block.label
  416. if block.image is not None:
  417. markdown_info["markdown_images"][block.image["path"]] = block.image[
  418. "img"
  419. ]
  420. handle_func = handle_funcs_dict.get(label, None)
  421. if (
  422. show_formula_number
  423. and (label == "display_formula" or label == "formula")
  424. and idx != len(self["parsing_res_list"]) - 1
  425. ):
  426. next_block = self["parsing_res_list"][idx + 1]
  427. next_block_label = next_block.label
  428. if next_block_label == "formula_number":
  429. block.content = merge_formula_and_number(
  430. block.content, next_block.content
  431. )
  432. if handle_func:
  433. markdown_content += (
  434. "\n\n" + handle_func(block)
  435. if markdown_content
  436. else handle_func(block)
  437. )
  438. markdown_info["page_index"] = self["page_index"]
  439. markdown_info["input_path"] = self["input_path"]
  440. markdown_info["markdown_texts"] = markdown_content
  441. for img in self["imgs_in_doc"]:
  442. markdown_info["markdown_images"][img["path"]] = img["img"]
  443. return markdown_info