batch_analyze.py

import time

import cv2
import numpy as np
import torch
from loguru import logger
from PIL import Image

from magic_pdf.config.constants import MODEL_NAME
from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
from magic_pdf.data.dataset import Dataset
from magic_pdf.libs.clean_memory import clean_memory
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.model.operators import InferenceResult
from magic_pdf.model.pdf_extract_kit import CustomPEKModel
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram,
    crop_img,
    get_res_list_from_layout_res,
)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
    get_adjusted_mfdetrec_res,
    get_ocr_result_list,
)
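
# Base batch sizes for the layout (DocLayout-YOLO), formula-detection (MFD) and
# formula-recognition (MFR) models; each is multiplied by batch_ratio in
# BatchAnalyze to obtain the effective batch size.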
YOLO_LAYOUT_BASE_BATCH_SIZE = 4
MFD_BASE_BATCH_SIZE = 1
MFR_BASE_BATCH_SIZE = 16


class BatchAnalyze:
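    """Run layout, formula, OCR and table models over a list of page images.

    Layout detection and formula detection/recognition run as cross-page batches
    (scaled by batch_ratio); OCR and table recognition then run per cropped
    region on each page, extending that page's layout results in place.
    """
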
    def __init__(self, model: CustomPEKModel, batch_ratio: int):
        self.model = model
        self.batch_ratio = batch_ratio

    def __call__(self, images: list) -> list:
        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            # layoutlmv3
            images_layout_res = []
            for image in images:
                layout_res = self.model.layout_model(image, ignore_catids=[])
                images_layout_res.append(layout_res)
        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            # doclayout_yolo
            images_layout_res = self.model.layout_model.batch_predict(
                images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
            )
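        # NOTE: only the DocLayout-YOLO path above performs true batched layout
        # inference; LayoutLMv3 is invoked image by image.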

        if self.model.apply_formula:
            # formula detection (MFD)
            images_mfd_res = self.model.mfd_model.batch_predict(
                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
            )
            # formula recognition (MFR)
            images_formula_list = self.model.mfr_model.batch_predict(
                images_mfd_res,
                images,
                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
            )
            for image_index in range(len(images)):
                images_layout_res[image_index] += images_formula_list[image_index]

        # clean up GPU memory (VRAM)
        clean_vram(self.model.device, vram_threshold=8)

        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
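        # Second stage: for each page, crop the regions flagged for OCR / tables
        # and run the corresponding models on the crops.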
        for index in range(len(images)):
            layout_res = images_layout_res[index]
            pil_img = Image.fromarray(images[index])

            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                get_res_list_from_layout_res(layout_res)
            )

            # OCR
            ocr_start = time.time()
            # Process each area that requires OCR processing
            for res in ocr_res_list:
                new_image, useful_list = crop_img(
                    res, pil_img, crop_paste_x=50, crop_paste_y=50
                )
                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                    single_page_mfdetrec_res, useful_list
                )

                # OCR recognition (full OCR, or detection only when apply_ocr is off)
                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
                if self.model.apply_ocr:
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res
                    )[0]
                else:
                    ocr_res = self.model.ocr_model.ocr(
                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                    )[0]

                # Integrate results
                if ocr_res:
                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
                    layout_res.extend(ocr_result_list)

            ocr_cost = round(time.time() - ocr_start, 2)
            if self.model.apply_ocr:
                logger.info(f"ocr time: {ocr_cost}")
            else:
                logger.info(f"det time: {ocr_cost}")

            # table recognition
            if self.model.apply_table:
                table_start = time.time()
                for res in table_res_list:
                    new_image, _ = crop_img(res, pil_img)
                    single_table_start_time = time.time()
                    html_code = None
                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                        with torch.no_grad():
                            table_result = self.model.table_model.predict(
                                new_image, "html"
                            )
                            if len(table_result) > 0:
                                html_code = table_result[0]
                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
                        html_code = self.model.table_model.img2html(new_image)
                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
                        html_code, table_cell_bboxes, elapse = (
                            self.model.table_model.predict(new_image)
                        )
                    run_time = time.time() - single_table_start_time
                    if run_time > self.model.table_max_time:
                        logger.warning(
                            f"table recognition processing exceeds max time {self.model.table_max_time}s"
                        )
                    # check whether a usable HTML result was returned
                    if html_code:
                        expected_ending = html_code.strip().endswith(
                            "</html>"
                        ) or html_code.strip().endswith("</table>")
                        if expected_ending:
                            res["html"] = html_code
                        else:
                            logger.warning(
                                "table recognition processing fails, not found expected HTML table end"
                            )
                    else:
                        logger.warning(
                            "table recognition processing fails, not get html return"
                        )
                logger.info(f"table time: {round(time.time() - table_start, 2)}")

        # per-page layout results, consumed by doc_batch_analyze below
        return images_layout_res


def doc_batch_analyze(
    dataset: Dataset,
    ocr: bool = False,
    show_log: bool = False,
    start_page_id=0,
    end_page_id=None,
    lang=None,
    layout_model=None,
    formula_enable=None,
    table_enable=None,
    batch_ratio: int | None = None,
) -> InferenceResult:
    """Perform batch analysis on a document dataset.

    Args:
        dataset (Dataset): The dataset containing document pages to be analyzed.
        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
        show_log (bool, optional): Flag to enable logging. Defaults to False.
        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means
            analyze up to the last page.
        lang (str, optional): Language for OCR. Defaults to None.
        layout_model (optional): Layout model to be used for analysis. Defaults to None.
        formula_enable (optional): Flag to enable formula detection. Defaults to None.
        table_enable (optional): Flag to enable table detection. Defaults to None.
        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.

    Raises:
        CUDA_NOT_AVAILABLE: If CUDA is not available; batch analysis is not supported in CPU mode.

    Returns:
        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.

    A minimal usage sketch appears at the end of this module.
    """

    if not torch.cuda.is_available():
        raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")

    lang = None if lang == "" else lang
    # TODO: auto detect batch size
    batch_ratio = 1 if batch_ratio is None else batch_ratio
    end_page_id = end_page_id if end_page_id else len(dataset)

    model_manager = ModelSingleton()
    custom_model: CustomPEKModel = model_manager.get_model(
        ocr, show_log, lang, layout_model, formula_enable, table_enable
    )
    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)

    model_json = []

    # batch analyze
    images = []
    for index in range(len(dataset)):
        if start_page_id <= index <= end_page_id:
            page_data = dataset.get_page(index)
            img_dict = page_data.get_image()
            images.append(img_dict["img"])
    analyze_result = batch_model(images)
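
    # Assemble per-page model output; pages outside [start_page_id, end_page_id]
    # get an empty result so page_no stays aligned with the dataset.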
    for index in range(len(dataset)):
        page_data = dataset.get_page(index)
        img_dict = page_data.get_image()
        page_width = img_dict["width"]
        page_height = img_dict["height"]
        if start_page_id <= index <= end_page_id:
            result = analyze_result.pop(0)
        else:
            result = []

        page_info = {"page_no": index, "height": page_height, "width": page_width}
        page_dict = {"layout_dets": result, "page_info": page_info}
        model_json.append(page_dict)

    # TODO: clean memory when gpu memory is not enough
    clean_memory()

    return InferenceResult(model_json, dataset)
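

# Minimal usage sketch (not part of the original module logic). It assumes a
# CUDA-capable GPU and that PymuDocDataset from magic_pdf.data.dataset can be
# built from raw PDF bytes, as shown in the project README; adjust the path and
# dataset construction to your setup.
if __name__ == "__main__":
    from magic_pdf.data.dataset import PymuDocDataset

    # hypothetical input path
    with open("example.pdf", "rb") as f:
        pdf_bytes = f.read()

    ds = PymuDocDataset(pdf_bytes)
    # run layout/formula/OCR/table analysis over all pages with OCR enabled
    inference_result = doc_batch_analyze(ds, ocr=True, batch_ratio=1)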