models.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. import copy
  2. import json
  3. import os
  4. from typing import Callable
  5. from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
  6. from magic_pdf.config.enums import SupportedPdfParseMethod
  7. from magic_pdf.data.data_reader_writer import DataWriter
  8. from magic_pdf.data.dataset import Dataset
  9. from magic_pdf.libs.draw_bbox import draw_model_bbox
  10. from magic_pdf.libs.version import __version__
  11. from magic_pdf.operators.pipes import PipeResult
  12. from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
  13. from magic_pdf.operators import InferenceResultBase
  class InferenceResult(InferenceResultBase):
      """Model inference output bound to the dataset it was computed from.

      Instances can draw or dump the raw inference result, hand a copy of it to
      an arbitrary callable (``apply``), or post-process it into a
      ``PipeResult`` via ``pipe_txt_mode`` / ``pipe_ocr_mode``.
      """

      def __init__(self, inference_results: list, dataset: Dataset):
          """Initialize the wrapper.

          Args:
              inference_results (list): the inference result generated by the model
              dataset (Dataset): the dataset the inference result was produced from
          """
          self._infer_res = inference_results
          self._dataset = dataset

      def draw_model(self, file_path: str) -> None:
          """Draw the model inference result to ``file_path``.

          The parent directory of ``file_path`` is created if it does not
          exist. The drawing helper receives a deep copy of the inference
          result, so the stored result is never modified.

          Args:
              file_path (str): the output file path
          """
          # os.path.dirname() returns '' for a bare file name such as 'out.pdf',
          # and os.makedirs('') raises FileNotFoundError. Fall back to the
          # current directory explicitly instead of using an empty path.
          dir_name = os.path.dirname(file_path) or os.curdir
          base_name = os.path.basename(file_path)
          # exist_ok=True already tolerates an existing directory, so no
          # separate os.path.exists() check (and its check-then-create race).
          os.makedirs(dir_name, exist_ok=True)
          draw_model_bbox(
              copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
          )

      def dump_model(self, writer: DataWriter, file_path: str):
          """Dump the model inference result to ``file_path`` as indented JSON.

          Non-ASCII characters are written unescaped (``ensure_ascii=False``).

          Args:
              writer (DataWriter): writer handle
              file_path (str): the location of the target file
          """
          writer.write_string(
              file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
          )

      def get_infer_res(self):
          """Get the inference result.

          Returns:
              list: the inference result generated by the model. This is the
              stored list itself, not a copy; use ``apply`` to work on a copy.
          """
          return self._infer_res

      def apply(self, proc: Callable, *args, **kwargs):
          """Apply a callable to a deep copy of the inference result.

          Args:
              proc (Callable): invoked as ``proc(inference_result, *args, **kwargs)``,
                  where ``inference_result`` is a deep copy, so ``proc`` may mutate
                  it without affecting this object.

          Returns:
              Any: the value returned by ``proc``
          """
          return proc(copy.deepcopy(self._infer_res), *args, **kwargs)

      def _pipe(
          self,
          parse_method,
          parse_type,
          image_writer: DataWriter,
          start_page_id=0,
          end_page_id=None,
          debug_mode=False,
          lang=None,
      ) -> PipeResult:
          """Shared implementation of ``pipe_txt_mode`` and ``pipe_ocr_mode``.

          Runs ``pdf_parse_union`` through ``apply`` (that is, on a deep copy of
          the inference result) with ``parse_method``. It then tags the parse
          result with ``parse_type``, the library version and, when given,
          ``lang``.

          Args:
              parse_method (SupportedPdfParseMethod): the parse method forwarded
                  to ``pdf_parse_union``
              parse_type: the value stored under ``'_parse_type'`` in the result
              image_writer (DataWriter): the image writer handle
              start_page_id (int, optional): first page to process. Defaults to 0.
              end_page_id (int, optional): last page to process. Defaults to None.
              debug_mode (bool, optional): enables extra logging. Defaults to False.
              lang (str, optional): stored under ``'lang'`` in the result when
                  not None. Defaults to None.

          Returns:
              PipeResult: the tagged parse result wrapped with this dataset
          """

          def proc(*args, **kwargs) -> PipeResult:
              res = pdf_parse_union(*args, **kwargs)
              res['_parse_type'] = parse_type
              res['_version_name'] = __version__
              if kwargs.get('lang') is not None:
                  res['lang'] = kwargs['lang']
              return PipeResult(res, self._dataset)

          # Route through self.apply so pdf_parse_union receives a deep copy of
          # the inference result rather than the stored list.
          return self.apply(
              proc,
              self._dataset,
              image_writer,
              parse_method,
              start_page_id=start_page_id,
              end_page_id=end_page_id,
              debug_mode=debug_mode,
              lang=lang,
          )

      def pipe_txt_mode(
          self,
          imageWriter: DataWriter,
          start_page_id=0,
          end_page_id=None,
          debug_mode=False,
          lang=None,
      ) -> PipeResult:
          """Post-process the model inference result, extracting text with a
          third-party library such as ``pymupdf``.

          Args:
              imageWriter (DataWriter): the image writer handle
              start_page_id (int, optional): Defaults to 0. First page the user
                  wants to process.
              end_page_id (int, optional): Defaults to the last page index of the
                  dataset. Last page the user wants to process.
              debug_mode (bool, optional): Defaults to False. Dumps more logs
                  when enabled.
              lang (str, optional): Defaults to None.

          Returns:
              PipeResult: the result
          """
          return self._pipe(
              SupportedPdfParseMethod.TXT,
              PARSE_TYPE_TXT,
              imageWriter,
              start_page_id=start_page_id,
              end_page_id=end_page_id,
              debug_mode=debug_mode,
              lang=lang,
          )

      def pipe_ocr_mode(
          self,
          imageWriter: DataWriter,
          start_page_id=0,
          end_page_id=None,
          debug_mode=False,
          lang=None,
      ) -> PipeResult:
          """Post-process the model inference result, extracting text with OCR.

          Args:
              imageWriter (DataWriter): the image writer handle
              start_page_id (int, optional): Defaults to 0. First page the user
                  wants to process.
              end_page_id (int, optional): Defaults to the last page index of the
                  dataset. Last page the user wants to process.
              debug_mode (bool, optional): Defaults to False. Dumps more logs
                  when enabled.
              lang (str, optional): Defaults to None.

          Returns:
              PipeResult: the result
          """
          return self._pipe(
              SupportedPdfParseMethod.OCR,
              PARSE_TYPE_OCR,
              imageWriter,
              start_page_id=start_page_id,
              end_page_id=end_page_id,
              debug_mode=debug_mode,
              lang=lang,
          )