operators.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import copy
  2. import json
  3. import os
  4. from typing import Callable
  5. from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
  6. from magic_pdf.config.enums import SupportedPdfParseMethod
  7. from magic_pdf.data.data_reader_writer import DataWriter
  8. from magic_pdf.data.dataset import Dataset
  9. from magic_pdf.filter import classify
  10. from magic_pdf.libs.draw_bbox import draw_model_bbox
  11. from magic_pdf.libs.version import __version__
  12. from magic_pdf.model import InferenceResultBase
  13. from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
  14. from magic_pdf.pipe.operators import PipeResult
  15. class InferenceResult(InferenceResultBase):
  16. def __init__(self, inference_results: list, dataset: Dataset):
  17. """Initialized method.
  18. Args:
  19. inference_results (list): the inference result generated by model
  20. dataset (Dataset): the dataset related with model inference result
  21. """
  22. self._infer_res = inference_results
  23. self._dataset = dataset
  24. def draw_model(self, file_path: str) -> None:
  25. """Draw model inference result.
  26. Args:
  27. file_path (str): the output file path
  28. """
  29. dir_name = os.path.dirname(file_path)
  30. base_name = os.path.basename(file_path)
  31. if not os.path.exists(dir_name):
  32. os.makedirs(dir_name, exist_ok=True)
  33. draw_model_bbox(
  34. copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
  35. )
  36. def dump_model(self, writer: DataWriter, file_path: str):
  37. """Dump model inference result to file.
  38. Args:
  39. writer (DataWriter): writer handle
  40. file_path (str): the location of target file
  41. """
  42. writer.write_string(
  43. file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
  44. )
  45. def get_infer_res(self):
  46. """Get the inference result.
  47. Returns:
  48. list: the inference result generated by model
  49. """
  50. return self._infer_res
  51. def apply(self, proc: Callable, *args, **kwargs):
  52. """Apply callable method which.
  53. Args:
  54. proc (Callable): invoke proc as follows:
  55. proc(inference_result, *args, **kwargs)
  56. Returns:
  57. Any: return the result generated by proc
  58. """
  59. return proc(copy.deepcopy(self._infer_res), *args, **kwargs)
  60. def pipe_auto_mode(
  61. self,
  62. imageWriter: DataWriter,
  63. start_page_id=0,
  64. end_page_id=None,
  65. debug_mode=False,
  66. lang=None,
  67. ) -> PipeResult:
  68. """Post-proc the model inference result.
  69. step1: classify the dataset type
  70. step2: based the result of step1, using `pipe_txt_mode` or `pipe_ocr_mode`
  71. Args:
  72. imageWriter (DataWriter): the image writer handle
  73. start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
  74. end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
  75. debug_mode (bool, optional): Defaults to False. will dump more log if enabled
  76. lang (str, optional): Defaults to None.
  77. Returns:
  78. PipeResult: the result
  79. """
  80. pdf_proc_method = classify(self._dataset.data_bits())
  81. if pdf_proc_method == SupportedPdfParseMethod.TXT:
  82. return self.pipe_txt_mode(
  83. imageWriter, start_page_id, end_page_id, debug_mode, lang
  84. )
  85. else:
  86. return self.pipe_ocr_mode(
  87. imageWriter, start_page_id, end_page_id, debug_mode, lang
  88. )
  89. def pipe_txt_mode(
  90. self,
  91. imageWriter: DataWriter,
  92. start_page_id=0,
  93. end_page_id=None,
  94. debug_mode=False,
  95. lang=None,
  96. ) -> PipeResult:
  97. """Post-proc the model inference result, Extract the text using the
  98. third library, such as `pymupdf`
  99. Args:
  100. imageWriter (DataWriter): the image writer handle
  101. start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
  102. end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
  103. debug_mode (bool, optional): Defaults to False. will dump more log if enabled
  104. lang (str, optional): Defaults to None.
  105. Returns:
  106. PipeResult: the result
  107. """
  108. def proc(*args, **kwargs) -> PipeResult:
  109. res = pdf_parse_union(*args, **kwargs)
  110. res['_parse_type'] = PARSE_TYPE_TXT
  111. res['_version_name'] = __version__
  112. if 'lang' in kwargs and kwargs['lang'] is not None:
  113. res['lang'] = kwargs['lang']
  114. return PipeResult(res, self._dataset)
  115. res = self.apply(
  116. proc,
  117. self._dataset,
  118. imageWriter,
  119. SupportedPdfParseMethod.TXT,
  120. start_page_id=start_page_id,
  121. end_page_id=end_page_id,
  122. debug_mode=debug_mode,
  123. lang=lang,
  124. )
  125. return res
  126. def pipe_ocr_mode(
  127. self,
  128. imageWriter: DataWriter,
  129. start_page_id=0,
  130. end_page_id=None,
  131. debug_mode=False,
  132. lang=None,
  133. ) -> PipeResult:
  134. """Post-proc the model inference result, Extract the text using `OCR`
  135. technical.
  136. Args:
  137. imageWriter (DataWriter): the image writer handle
  138. start_page_id (int, optional): Defaults to 0. Let user select some pages He/She want to process
  139. end_page_id (int, optional): Defaults to the last page index of dataset. Let user select some pages He/She want to process
  140. debug_mode (bool, optional): Defaults to False. will dump more log if enabled
  141. lang (str, optional): Defaults to None.
  142. Returns:
  143. PipeResult: the result
  144. """
  145. def proc(*args, **kwargs) -> PipeResult:
  146. res = pdf_parse_union(*args, **kwargs)
  147. res['_parse_type'] = PARSE_TYPE_OCR
  148. res['_version_name'] = __version__
  149. if 'lang' in kwargs and kwargs['lang'] is not None:
  150. res['lang'] = kwargs['lang']
  151. return PipeResult(res, self._dataset)
  152. res = self.apply(
  153. proc,
  154. self._dataset,
  155. imageWriter,
  156. SupportedPdfParseMethod.OCR,
  157. start_page_id=start_page_id,
  158. end_page_id=end_page_id,
  159. debug_mode=debug_mode,
  160. lang=lang,
  161. )
  162. return res