formula_rec.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import random
import shutil
import subprocess
import sys
import tempfile

import cv2
import numpy as np
from PIL import Image, ImageDraw

from .base import CVResult
from ...utils import logging
from .ocr import draw_box_txt_fine
from ...utils.fonts import PINGFANG_FONT_FILE_PATH
  27. class FormulaRecResult(CVResult):
  28. def _to_str(self, *args, **kwargs):
  29. return super()._to_str(*args, **kwargs).replace("\\\\", "\\")
  30. def _to_img(
  31. self,
  32. ):
  33. """Draw formula on image"""
  34. image = self._img_reader.read(self["input_path"])
  35. rec_formula = str(self["rec_text"])
  36. image = np.array(image.convert("RGB"))
  37. xywh = crop_white_area(image)
  38. if xywh is not None:
  39. x, y, w, h = xywh
  40. image = image[y : y + h, x : x + w]
  41. image = Image.fromarray(image)
  42. image_width, image_height = image.size
  43. box = [[0, 0], [image_width, 0], [image_width, image_height], [0, image_height]]
  44. try:
  45. img_formula = draw_formula_module(
  46. image.size, box, rec_formula, is_debug=False
  47. )
  48. img_formula = Image.fromarray(img_formula)
  49. render_width, render_height = img_formula.size
  50. resize_height = render_height
  51. resize_width = int(resize_height * image_width / image_height)
  52. image = image.resize((resize_width, resize_height), Image.LANCZOS)
  53. new_image_width = image.width + int(render_width) + 10
  54. new_image = Image.new(
  55. "RGB", (new_image_width, render_height), (255, 255, 255)
  56. )
  57. new_image.paste(image, (0, 0))
  58. new_image.paste(img_formula, (image.width + 10, 0))
  59. return new_image
  60. except subprocess.CalledProcessError as e:
  61. logging.warning(
  62. "Please refer to 2.3 Formula Recognition Pipeline Visualization in Formula Recognition Pipeline Tutorial to install the LaTeX rendering engine at first."
  63. )
  64. return None
  65. class FormulaResult(CVResult):
  66. def _to_str(self, *args, **kwargs):
  67. return super()._to_str(*args, **kwargs).replace("\\\\", "\\")
  68. def _to_img(
  69. self,
  70. ):
  71. """draw formula result"""
  72. boxes = self["dt_polys"]
  73. formulas = self["rec_formula"]
  74. image = self._img_reader.read(self["input_path"])
  75. h, w = image.height, image.width
  76. img_left = image.copy()
  77. img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
  78. random.seed(0)
  79. draw_left = ImageDraw.Draw(img_left)
  80. if formulas is None or len(formulas) != len(boxes):
  81. formulas = [None] * len(boxes)
  82. for idx, (box, formula) in enumerate(zip(boxes, formulas)):
  83. try:
  84. color = (
  85. random.randint(0, 255),
  86. random.randint(0, 255),
  87. random.randint(0, 255),
  88. )
  89. box = np.array(box)
  90. pts = [(x, y) for x, y in box.tolist()]
  91. draw_left.polygon(pts, outline=color, width=8)
  92. draw_left.polygon(box, fill=color)
  93. img_right_text = draw_box_formula_fine(
  94. (w, h),
  95. box,
  96. formula,
  97. is_debug=False,
  98. )
  99. pts = np.array(box, np.int32).reshape((-1, 1, 2))
  100. cv2.polylines(img_right_text, [pts], True, color, 1)
  101. img_right = cv2.bitwise_and(img_right, img_right_text)
  102. except subprocess.CalledProcessError as e:
  103. logging.warning(
  104. "Please refer to 2.3 Formula Recognition Pipeline Visualization in Formula Recognition Pipeline Tutorial to install the LaTeX rendering engine at first."
  105. )
  106. return None
  107. img_left = Image.blend(image, img_left, 0.5)
  108. img_show = Image.new("RGB", (int(w * 2), h), (255, 255, 255))
  109. img_show.paste(img_left, (0, 0, w, h))
  110. img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
  111. return img_show
  112. def get_align_equation(equation):
  113. is_align = False
  114. equation = str(equation) + "\n"
  115. begin_dict = [
  116. r"begin{align}",
  117. r"begin{align*}",
  118. ]
  119. for begin_sym in begin_dict:
  120. if begin_sym in equation:
  121. is_align = True
  122. break
  123. if not is_align:
  124. equation = (
  125. r"\begin{equation}"
  126. + "\n"
  127. + equation.strip()
  128. + r"\nonumber"
  129. + "\n"
  130. + r"\end{equation}"
  131. + "\n"
  132. )
  133. return equation
  134. def generate_tex_file(tex_file_path, equation):
  135. with open(tex_file_path, "w") as fp:
  136. start_template = (
  137. r"\documentclass{article}" + "\n"
  138. r"\usepackage{cite}" + "\n"
  139. r"\usepackage{amsmath,amssymb,amsfonts}" + "\n"
  140. r"\usepackage{graphicx}" + "\n"
  141. r"\usepackage{textcomp}" + "\n"
  142. r"\DeclareMathSizes{14}{14}{9.8}{7}" + "\n"
  143. r"\pagestyle{empty}" + "\n"
  144. r"\begin{document}" + "\n"
  145. r"\begin{large}" + "\n"
  146. )
  147. fp.write(start_template)
  148. equation = get_align_equation(equation)
  149. fp.write(equation)
  150. end_template = r"\end{large}" + "\n" r"\end{document}" + "\n"
  151. fp.write(end_template)
  152. def generate_pdf_file(tex_path, pdf_dir, is_debug=False):
  153. if os.path.exists(tex_path):
  154. command = "pdflatex -halt-on-error -output-directory={} {}".format(
  155. pdf_dir, tex_path
  156. )
  157. if is_debug:
  158. subprocess.check_call(command, shell=True)
  159. else:
  160. devNull = open(os.devnull, "w")
  161. subprocess.check_call(
  162. command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True
  163. )
  164. def crop_white_area(image):
  165. image = np.array(image).astype("uint8")
  166. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  167. _, thresh = cv2.threshold(gray, 240, 255, cv2.THRESH_BINARY_INV)
  168. contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  169. if len(contours) > 0:
  170. x, y, w, h = cv2.boundingRect(np.concatenate(contours))
  171. return [x, y, w, h]
  172. else:
  173. return None
  174. def pdf2img(pdf_path, img_path, is_padding=False):
  175. import fitz
  176. pdfDoc = fitz.open(pdf_path)
  177. if pdfDoc.page_count != 1:
  178. return None
  179. for pg in range(pdfDoc.page_count):
  180. page = pdfDoc[pg]
  181. rotate = int(0)
  182. zoom_x = 2
  183. zoom_y = 2
  184. mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
  185. pix = page.get_pixmap(matrix=mat, alpha=False)
  186. if not os.path.exists(img_path):
  187. os.makedirs(img_path)
  188. pix._writeIMG(img_path, 7, 100)
  189. img = cv2.imread(img_path)
  190. xywh = crop_white_area(img)
  191. if xywh is not None:
  192. x, y, w, h = xywh
  193. img = img[y : y + h, x : x + w]
  194. if is_padding:
  195. img = cv2.copyMakeBorder(
  196. img, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=(255, 255, 255)
  197. )
  198. return img
  199. return None
  200. def draw_formula_module(img_size, box, formula, is_debug=False):
  201. """draw box formula for module"""
  202. box_width, box_height = img_size
  203. with tempfile.TemporaryDirectory() as td:
  204. tex_file_path = os.path.join(td, "temp.tex")
  205. pdf_file_path = os.path.join(td, "temp.pdf")
  206. img_file_path = os.path.join(td, "temp.jpg")
  207. generate_tex_file(tex_file_path, formula)
  208. if os.path.exists(tex_file_path):
  209. generate_pdf_file(tex_file_path, td, is_debug)
  210. formula_img = None
  211. if os.path.exists(pdf_file_path):
  212. formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
  213. if formula_img is not None:
  214. return formula_img
  215. else:
  216. img_right_text = draw_box_txt_fine(
  217. img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
  218. )
  219. return img_right_text
  220. def draw_box_formula_fine(img_size, box, formula, is_debug=False):
  221. """draw box formula for pipeline"""
  222. box_height = int(
  223. math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
  224. )
  225. box_width = int(
  226. math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
  227. )
  228. with tempfile.TemporaryDirectory() as td:
  229. tex_file_path = os.path.join(td, "temp.tex")
  230. pdf_file_path = os.path.join(td, "temp.pdf")
  231. img_file_path = os.path.join(td, "temp.jpg")
  232. generate_tex_file(tex_file_path, formula)
  233. if os.path.exists(tex_file_path):
  234. generate_pdf_file(tex_file_path, td, is_debug)
  235. formula_img = None
  236. if os.path.exists(pdf_file_path):
  237. formula_img = pdf2img(pdf_file_path, img_file_path, is_padding=False)
  238. if formula_img is not None:
  239. formula_h, formula_w = formula_img.shape[:-1]
  240. resize_height = box_height
  241. resize_width = formula_w * resize_height / formula_h
  242. formula_img = cv2.resize(
  243. formula_img, (int(resize_width), int(resize_height))
  244. )
  245. formula_h, formula_w = formula_img.shape[:-1]
  246. pts1 = np.float32(
  247. [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
  248. )
  249. pts2 = np.array(box, dtype=np.float32)
  250. M = cv2.getPerspectiveTransform(pts1, pts2)
  251. formula_img = np.array(formula_img, dtype=np.uint8)
  252. img_right_text = cv2.warpPerspective(
  253. formula_img,
  254. M,
  255. img_size,
  256. flags=cv2.INTER_NEAREST,
  257. borderMode=cv2.BORDER_CONSTANT,
  258. borderValue=(255, 255, 255),
  259. )
  260. else:
  261. img_right_text = draw_box_txt_fine(
  262. img_size, box, "Rendering Failed", PINGFANG_FONT_FILE_PATH
  263. )
  264. return img_right_text