operators.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import copy
  2. import json
  3. import os
  4. from typing import Callable
  5. from magic_pdf.config.enums import SupportedPdfParseMethod
  6. from magic_pdf.data.data_reader_writer import DataWriter
  7. from magic_pdf.data.dataset import Dataset
  8. from magic_pdf.libs.version import __version__
  9. from magic_pdf.filter import classify
  10. from magic_pdf.libs.draw_bbox import draw_model_bbox
  11. from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
  12. from magic_pdf.pipe.operators import PipeResult
  13. from magic_pdf.model import InferenceResultBase
  14. from magic_pdf.libs.version import __version__
  15. from magic_pdf.config.constants import PARSE_TYPE_TXT, PARSE_TYPE_OCR
  16. class InferenceResult(InferenceResultBase):
  17. def __init__(self, inference_results: list, dataset: Dataset):
  18. """Initialized method.
  19. Args:
  20. inference_results (list): the inference result generated by model
  21. dataset (Dataset): the dataset related with model inference result
  22. """
  23. self._infer_res = inference_results
  24. self._dataset = dataset
  25. def draw_model(self, file_path: str) -> None:
  26. """Draw model inference result.
  27. Args:
  28. file_path (str): the output file path
  29. """
  30. dir_name = os.path.dirname(file_path)
  31. base_name = os.path.basename(file_path)
  32. if not os.path.exists(dir_name):
  33. os.makedirs(dir_name, exist_ok=True)
  34. draw_model_bbox(
  35. copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
  36. )
  37. def dump_model(self, writer: DataWriter, file_path: str):
  38. """Dump model inference result to file.
  39. Args:
  40. writer (DataWriter): writer handle
  41. file_path (str): the location of target file
  42. """
  43. writer.write_string(
  44. file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
  45. )
  46. def get_infer_res(self):
  47. """Get the inference result.
  48. Returns:
  49. list: the inference result generated by model
  50. """
  51. return self._infer_res
  52. def apply(self, proc: Callable, *args, **kwargs):
  53. """Apply callable method which.
  54. Args:
  55. proc (Callable): invoke proc as follows:
  56. proc(inference_result, *args, **kwargs)
  57. Returns:
  58. Any: return the result generated by proc
  59. """
  60. return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
  61. def pipe_auto_mode(
  62. self,
  63. imageWriter: DataWriter,
  64. start_page_id=0,
  65. end_page_id=None,
  66. debug_mode=False,
  67. lang=None,
  68. ) -> PipeResult:
  69. """Post-proc the model inference result.
  70. step1: classify the dataset type
  71. step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
  72. Args:
  73. imageWriter (DataWriter): the image writer handle
  74. start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
  75. end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
  76. debug_mode (bool, optional): Defaults to False. will dump more log if enabled
  77. lang (str, optional): Defaults to None.
  78. Returns:
  79. PipeResult: the result
  80. """
  81. pdf_proc_method = classify(self._dataset.data_bits())
  82. if pdf_proc_method == SupportedPdfParseMethod.TXT:
  83. return self.pipe_txt_mode(
  84. imageWriter, start_page_id, end_page_id, debug_mode, lang
  85. )
  86. else:
  87. return self.pipe_ocr_mode(
  88. imageWriter, start_page_id, end_page_id, debug_mode, lang
  89. )
  90. def pipe_txt_mode(
  91. self,
  92. imageWriter: DataWriter,
  93. start_page_id=0,
  94. end_page_id=None,
  95. debug_mode=False,
  96. lang=None,
  97. ) -> PipeResult:
  98. """Post-proc the model inference result, Extract the text using the
  99. third library, such as `pymupdf`
  100. Args:
  101. imageWriter (DataWriter): the image writer handle
  102. start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
  103. end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
  104. debug_mode (bool, optional): Defaults to False. will dump more log if enabled
  105. lang (str, optional): Defaults to None.
  106. Returns:
  107. PipeResult: the result
  108. """
  109. def proc(*args, **kwargs) -> PipeResult:
  110. res = pdf_parse_union(*args, **kwargs)
  111. return PipeResult(res, self._dataset)
  112. res = self.apply(
  113. proc,
  114. self._dataset,
  115. imageWriter,
  116. SupportedPdfParseMethod.TXT,
  117. start_page_id=start_page_id,
  118. end_page_id=end_page_id,
  119. debug_mode=debug_mode,
  120. lang=lang,
  121. )
  122. res['_parse_type'] = PARSE_TYPE_TXT
  123. res['_version_name'] = __version__
  124. return res
  125. def pipe_ocr_mode(
  126. self,
  127. imageWriter: DataWriter,
  128. start_page_id=0,
  129. end_page_id=None,
  130. debug_mode=False,
  131. lang=None,
  132. ) -> PipeResult:
  133. """Post-proc the model inference result, Extract the text using `OCR`
  134. technical.
  135. Args:
  136. imageWriter (DataWriter): the image writer handle
  137. start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
  138. end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
  139. debug_mode (bool, optional): Defaults to False. will dump more log if enabled
  140. lang (str, optional): Defaults to None.
  141. Returns:
  142. PipeResult: the result
  143. """
  144. def proc(*args, **kwargs) -> PipeResult:
  145. res = pdf_parse_union(*args, **kwargs)
  146. return PipeResult(res, self._dataset)
  147. res = self.apply(
  148. proc,
  149. self._dataset,
  150. imageWriter,
  151. SupportedPdfParseMethod.TXT,
  152. start_page_id=start_page_id,
  153. end_page_id=end_page_id,
  154. debug_mode=debug_mode,
  155. lang=lang,
  156. )
  157. res['_parse_type'] = PARSE_TYPE_OCR
  158. res['_version_name'] = __version__
  159. return res