result.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import cv2
  17. import PIL
  18. import math
  19. import random
  20. import tempfile
  21. import subprocess
  22. import numpy as np
  23. from pathlib import Path
  24. from PIL import Image, ImageDraw, ImageFont
  25. from ...common.result import BaseCVResult
  26. from ....utils import logging
  27. from ....utils.fonts import PINGFANG_FONT_FILE_PATH
  28. class FormulaRecResult(BaseCVResult):
  29. def _to_str(self, *args, **kwargs):
  30. return super()._to_str(*args, **kwargs).replace("\\\\", "\\")
  31. def _to_img(
  32. self,
  33. ):
  34. """Draw formula on image"""
  35. image = Image.fromarray(self._input_img)
  36. try:
  37. env_valid()
  38. except subprocess.CalledProcessError as e:
  39. logging.warning(
  40. "Please refer to 2.3 Formula Recognition Pipeline Visualization in Formula Recognition Pipeline Tutorial to install the LaTeX rendering engine at first."
  41. )
  42. return image
  43. rec_formula = str(self["rec_formula"])
  44. image = np.array(image.convert("RGB"))
  45. xywh = crop_white_area(image)
  46. if xywh is not None:
  47. x, y, w, h = xywh
  48. image = image[y : y + h, x : x + w]
  49. image = Image.fromarray(image)
  50. image_width, image_height = image.size
  51. box = [[0, 0], [image_width, 0], [image_width, image_height], [0, image_height]]
  52. try:
  53. img_formula = draw_formula_module(
  54. image.size, box, rec_formula, is_debug=False
  55. )
  56. img_formula = Image.fromarray(img_formula)
  57. render_width, render_height = img_formula.size
  58. resize_height = render_height
  59. resize_width = int(resize_height * image_width / image_height)
  60. image = image.resize((resize_width, resize_height), Image.LANCZOS)
  61. new_image_width = image.width + int(render_width) + 10
  62. new_image = Image.new(
  63. "RGB", (new_image_width, render_height), (255, 255, 255)
  64. )
  65. new_image.paste(image, (0, 0))
  66. new_image.paste(img_formula, (image.width + 10, 0))
  67. return new_image
  68. except subprocess.CalledProcessError as e:
  69. logging.warning("Syntax error detected in formula, rendering failed.")
  70. return image
  71. def get_align_equation(equation):
  72. is_align = False
  73. equation = str(equation) + "\n"
  74. begin_dict = [
  75. r"begin{align}",
  76. r"begin{align*}",
  77. ]
  78. for begin_sym in begin_dict:
  79. if begin_sym in equation:
  80. is_align = True
  81. break
  82. if not is_align:
  83. equation = (
  84. r"\begin{equation}"
  85. + "\n"
  86. + equation.strip()
  87. + r"\nonumber"
  88. + "\n"
  89. + r"\end{equation}"
  90. + "\n"
  91. )
  92. return equation
  93. def generate_tex_file(tex_file_path, equation):
  94. with open(tex_file_path, "w") as fp:
  95. start_template = (
  96. r"\documentclass{article}" + "\n"
  97. r"\usepackage{cite}" + "\n"
  98. r"\usepackage{amsmath,amssymb,amsfonts}" + "\n"
  99. r"\usepackage{graphicx}" + "\n"
  100. r"\usepackage{textcomp}" + "\n"
  101. r"\DeclareMathSizes{14}{14}{9.8}{7}" + "\n"
  102. r"\pagestyle{empty}" + "\n"
  103. r"\begin{document}" + "\n"
  104. r"\begin{large}" + "\n"
  105. )
  106. fp.write(start_template)
  107. equation = get_align_equation(equation)
  108. fp.write(equation)
  109. end_template = r"\end{large}" + "\n" r"\end{document}" + "\n"
  110. fp.write(end_template)
  111. def generate_pdf_file(tex_path, pdf_dir, is_debug=False):
  112. if os.path.exists(tex_path):
  113. command = "pdflatex -halt-on-error -output-directory={} {}".format(
  114. pdf_dir, tex_path
  115. )
  116. if is_debug:
  117. subprocess.check_call(command, shell=True)
  118. else:
  119. devNull = open(os.devnull, "w")
  120. subprocess.check_call(
  121. command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
  122. )
  123. def crop_white_area(image):
  124. image = np.array(image).astype("uint8")
  125. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  126. _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
  127. contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  128. if len(contours) > 0:
  129. x, y, w, h = cv2.boundingRect(np.concatenate(contours))
  130. return [x, y, w, h]
  131. else:
  132. return None
  133. def pdf2img(pdf_path, img_path, is_padding=False):
  134. import fitz
  135. pdfDoc = fitz.open(pdf_path)
  136. if pdfDoc.page_count != 1:
  137. return None
  138. for pg in range(pdfDoc.page_count):
  139. page = pdfDoc[pg]
  140. rotate = int(0)
  141. zoom_x = 2
  142. zoom_y = 2
  143. mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
  144. pix = page.get_pixmap(matrix=mat, alpha=False)
  145. if not os.path.exists(img_path):
  146. os.makedirs(img_path)
  147. pix._writeIMG(img_path, 7, 100)
  148. img = cv2.imread(img_path)
  149. xywh = crop_white_area(img)
  150. if xywh is not None:
  151. x, y, w, h = xywh
  152. img = img[y : y + h, x : x + w]
  153. if is_padding:
  154. img = cv2.copyMakeBorder(
  155. img, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=(255, 255, 255)
  156. )
  157. return img
  158. return None
  159. def draw_formula_module(img_size, box, formula, is_debug=False):
  160. """draw box formula for module"""
  161. box_width, box_height = img_size
  162. with tempfile.TemporaryDirectory() as td:
  163. tex_file_path = os.path.join(td, "temp.tex")
  164. pdf_file_path = os.path.join(td, "temp.pdf")
  165. img_file_path = os.path.join(td, "temp.jpg")
  166. generate_tex_file(tex_file_path, formula)
  167. if os.path.exists(tex_file_path):
  168. generate_pdf_file(tex_file_path, td, is_debug)
  169. formula_img = None
  170. if os.path.exists(pdf_file_path):
  171. formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
  172. if formula_img is not None:
  173. return formula_img
  174. else:
  175. img_right_text = draw_box_txt_fine(
  176. img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
  177. )
  178. return img_right_text
  179. def env_valid():
  180. with tempfile.TemporaryDirectory() as td:
  181. tex_file_path = os.path.join(td, "temp.tex")
  182. pdf_file_path = os.path.join(td, "temp.pdf")
  183. img_file_path = os.path.join(td, "temp.jpg")
  184. formula = "a+b=c"
  185. is_debug = False
  186. generate_tex_file(tex_file_path, formula)
  187. if os.path.exists(tex_file_path):
  188. generate_pdf_file(tex_file_path, td, is_debug)
  189. if os.path.exists(pdf_file_path):
  190. formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
  191. def draw_box_formula_fine(img_size, box, formula, is_debug=False):
  192. """draw box formula for pipeline"""
  193. box_height = int(
  194. math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
  195. )
  196. box_width = int(
  197. math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
  198. )
  199. with tempfile.TemporaryDirectory() as td:
  200. tex_file_path = os.path.join(td, "temp.tex")
  201. pdf_file_path = os.path.join(td, "temp.pdf")
  202. img_file_path = os.path.join(td, "temp.jpg")
  203. generate_tex_file(tex_file_path, formula)
  204. if os.path.exists(tex_file_path):
  205. generate_pdf_file(tex_file_path, td, is_debug)
  206. formula_img = None
  207. if os.path.exists(pdf_file_path):
  208. formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
  209. if formula_img is not None:
  210. formula_h, formula_w = formula_img.shape[:-1]
  211. resize_height = box_height
  212. resize_width = formula_w * resize_height / formula_h
  213. formula_img = cv2.resize(
  214. formula_img, (int(resize_width), int(resize_height))
  215. )
  216. formula_h, formula_w = formula_img.shape[:-1]
  217. pts1 = np.float32(
  218. [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
  219. )
  220. pts2 = np.array(box, dtype=np.float32)
  221. M = cv2.getPerspectiveTransform(pts1, pts2)
  222. formula_img = np.array(formula_img, dtype=np.uint8)
  223. img_right_text = cv2.warpPerspective(
  224. formula_img,
  225. M,
  226. img_size,
  227. flags=cv2.INTER_NEAREST,
  228. borderMode=cv2.BORDER_CONSTANT,
  229. borderValue=(255, 255, 255),
  230. )
  231. else:
  232. img_right_text = draw_box_txt_fine(
  233. img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
  234. )
  235. return img_right_text
  236. def draw_box_txt_fine(img_size, box, txt, font_path):
  237. """draw box text"""
  238. box_height = int(
  239. math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
  240. )
  241. box_width = int(
  242. math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
  243. )
  244. if box_height > 2 * box_width and box_height > 30:
  245. img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
  246. draw_text = ImageDraw.Draw(img_text)
  247. if txt:
  248. font = create_font(txt, (box_height, box_width), font_path)
  249. draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
  250. img_text = img_text.transpose(Image.ROTATE_270)
  251. else:
  252. img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
  253. draw_text = ImageDraw.Draw(img_text)
  254. if txt:
  255. font = create_font(txt, (box_width, box_height), font_path)
  256. draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
  257. pts1 = np.float32(
  258. [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
  259. )
  260. pts2 = np.array(box, dtype=np.float32)
  261. M = cv2.getPerspectiveTransform(pts1, pts2)
  262. img_text = np.array(img_text, dtype=np.uint8)
  263. img_right_text = cv2.warpPerspective(
  264. img_text,
  265. M,
  266. img_size,
  267. flags=cv2.INTER_NEAREST,
  268. borderMode=cv2.BORDER_CONSTANT,
  269. borderValue=(255, 255, 255),
  270. )
  271. return img_right_text
  272. def create_font(txt, sz, font_path):
  273. """create font"""
  274. font_size = int(sz[1] * 0.8)
  275. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  276. if int(PIL.__version__.split(".")[0]) < 10:
  277. length = font.getsize(txt)[0]
  278. else:
  279. length = font.getlength(txt)
  280. if length > sz[0]:
  281. font_size = int(font_size * sz[0] / length)
  282. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  283. return font