dataset.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. import os
  2. from abc import ABC, abstractmethod
  3. from typing import Callable, Iterator
  4. import fitz
  5. from magic_pdf.config.enums import SupportedPdfParseMethod
  6. from magic_pdf.data.schemas import PageInfo
  7. from magic_pdf.data.utils import fitz_doc_to_image
  8. from magic_pdf.filter import classify
  class PageableData(ABC):
      """Abstract single-page interface: render, inspect and annotate one page."""

      @abstractmethod
      def get_image(self) -> dict:
          """Render this page as an image.

          Returns:
              dict: the rendered image data of the page.
          """
          ...

      @abstractmethod
      def get_doc(self) -> fitz.Page:
          """Return the underlying PyMuPDF page object.

          Returns:
              fitz.Page: the wrapped page.
          """
          ...

      @abstractmethod
      def get_page_info(self) -> PageInfo:
          """Describe this page.

          Returns:
              PageInfo: the page info of this page.
          """
          ...

      @abstractmethod
      def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
          """Draw a rectangle onto the page.

          The PyMuPDF-backed implementation in this module (``Doc``) forwards
          every argument unchanged to ``fitz.Page.draw_rect``, so the arguments
          follow that method's conventions.
          """
          ...

      @abstractmethod
      def insert_text(self, coord, content, fontsize, color):
          """Write ``content`` onto the page at ``coord``.

          The PyMuPDF-backed implementation in this module (``Doc``) forwards
          to ``fitz.Page.insert_text`` with ``fontsize`` and ``color``.
          """
          ...
  class Dataset(ABC):
      """Abstract collection of pages that can be iterated, indexed and saved."""

      @abstractmethod
      def __len__(self) -> int:
          """Return how many pages the dataset holds."""
          ...

      @abstractmethod
      def __iter__(self) -> Iterator[PageableData]:
          """Iterate over the pages of the dataset."""
          ...

      @abstractmethod
      def supported_methods(self) -> list[SupportedPdfParseMethod]:
          """List the parse methods this dataset can be processed with.

          Returns:
              list[SupportedPdfParseMethod]: the supported methods; valid
              entries are OCR and TXT.
          """
          ...

      @abstractmethod
      def data_bits(self) -> bytes:
          """Return the bytes this dataset was built from."""
          ...

      @abstractmethod
      def get_page(self, page_id: int) -> PageableData:
          """Look up a single page by position.

          Args:
              page_id (int): index of the requested page.

          Returns:
              PageableData: the page at ``page_id``.
          """
          ...

      @abstractmethod
      def dump_to_file(self, file_path: str):
          """Write the dataset's document to ``file_path``."""
          ...

      @abstractmethod
      def apply(self, proc: Callable, *args, **kwargs):
          """Run ``proc`` on this dataset.

          The implementations in this module call
          ``proc(self, *args, **kwargs)`` and return its result.
          """
          ...

      @abstractmethod
      def classify(self) -> SupportedPdfParseMethod:
          """Pick the parse method to use for this dataset."""
          ...
  class PymuDocDataset(Dataset):
      """Dataset backed by a PDF document opened with PyMuPDF (fitz)."""

      def __init__(self, bits: bytes):
          """Open the PDF and wrap each of its pages in a ``Doc``.

          Args:
              bits (bytes): the bytes of the pdf
          """
          self._raw_fitz = fitz.open('pdf', bits)
          self._records = [Doc(v) for v in self._raw_fitz]
          self._data_bits = bits
          # The input is already a PDF, so the raw input and the PDF bits coincide.
          self._raw_data = bits

      def __len__(self) -> int:
          """Return the number of pages in the pdf."""
          return len(self._records)

      def __iter__(self) -> Iterator[PageableData]:
          """Iterate over the wrapped page objects."""
          return iter(self._records)

      def supported_methods(self) -> list[SupportedPdfParseMethod]:
          """Return the parse methods supported by this dataset.

          Returns:
              list[SupportedPdfParseMethod]: OCR and TXT.
          """
          return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]

      def data_bits(self) -> bytes:
          """Return the pdf bytes used to create this dataset."""
          return self._data_bits

      def get_page(self, page_id: int) -> PageableData:
          """Return the page object at ``page_id``.

          Args:
              page_id (int): the page index.

          Returns:
              PageableData: the page doc object.
          """
          return self._records[page_id]

      def dump_to_file(self, file_path: str):
          """Save the underlying PDF to ``file_path``.

          The parent directory is created when it does not exist yet; paths
          whose directory part is '', '.' or '..' are written without calling
          ``os.makedirs``.

          Args:
              file_path (str): destination path of the PDF.
          """
          dir_name = os.path.dirname(file_path)
          if dir_name not in ('', '.', '..'):
              os.makedirs(dir_name, exist_ok=True)
          self._raw_fitz.save(file_path)

      def apply(self, proc: Callable, *args, **kwargs):
          """Call ``proc(self, *args, **kwargs)`` and return its result."""
          return proc(self, *args, **kwargs)

      def classify(self) -> SupportedPdfParseMethod:
          """Classify the PDF bytes to choose a parse method.

          Returns:
              SupportedPdfParseMethod: the result of ``magic_pdf.filter.classify``
              on the PDF bytes of this dataset.
          """
          # ``classify`` below is the module-level function imported from
          # magic_pdf.filter: a method's own name is not in scope inside its body.
          return classify(self._data_bits)
  class ImageDataset(Dataset):
      """Dataset built from image bytes, converted to PDF with PyMuPDF (fitz)."""

      def __init__(self, bits: bytes):
          """Convert the image to PDF and wrap each resulting page in a ``Doc``.

          Args:
              bits (bytes): the bytes of the image. They are converted to PDF
                  bytes first, and those PDF bytes are then opened with PyMuPDF.
          """
          # Only the converted PDF needs to stay open; close the intermediate
          # image document as soon as its PDF rendition has been produced.
          with fitz.open(stream=bits) as img_doc:
              pdf_bytes = img_doc.convert_to_pdf()
          self._raw_fitz = fitz.open('pdf', pdf_bytes)
          self._records = [Doc(v) for v in self._raw_fitz]
          # Keep the original image bytes as well as the converted PDF bytes.
          self._raw_data = bits
          self._data_bits = pdf_bytes

      def __len__(self) -> int:
          """Return the number of pages in the dataset."""
          return len(self._records)

      def __iter__(self) -> Iterator[PageableData]:
          """Iterate over the wrapped page objects."""
          return iter(self._records)

      def supported_methods(self) -> list[SupportedPdfParseMethod]:
          """Return the parse methods supported by this dataset.

          Returns:
              list[SupportedPdfParseMethod]: only OCR.
          """
          return [SupportedPdfParseMethod.OCR]

      def data_bits(self) -> bytes:
          """Return the converted pdf bytes (not the original image bytes)."""
          return self._data_bits

      def get_page(self, page_id: int) -> PageableData:
          """Return the page object at ``page_id``.

          Args:
              page_id (int): the page index.

          Returns:
              PageableData: the page doc object.
          """
          return self._records[page_id]

      def dump_to_file(self, file_path: str):
          """Save the converted PDF to ``file_path``.

          The parent directory is created when it does not exist yet; paths
          whose directory part is '', '.' or '..' are written without calling
          ``os.makedirs``.

          Args:
              file_path (str): destination path of the PDF.
          """
          dir_name = os.path.dirname(file_path)
          if dir_name not in ('', '.', '..'):
              os.makedirs(dir_name, exist_ok=True)
          self._raw_fitz.save(file_path)

      def apply(self, proc: Callable, *args, **kwargs):
          """Call ``proc(self, *args, **kwargs)`` and return its result."""
          return proc(self, *args, **kwargs)

      def classify(self) -> SupportedPdfParseMethod:
          """Return OCR unconditionally, the only method this dataset supports."""
          return SupportedPdfParseMethod.OCR
  class Doc(PageableData):
      """A single PyMuPDF (fitz) page exposed through the PageableData interface.

      Attributes not defined on this class are looked up on the wrapped
      ``fitz.Page`` (see ``__getattr__``).
      """

      def __init__(self, doc: fitz.Page):
          """Wrap a PyMuPDF page.

          Args:
              doc (fitz.Page): the page to wrap.
          """
          self._doc = doc

      def get_image(self):
          """Return the image info of the page.

          Returns:
              dict: the result of ``fitz_doc_to_image`` for this page, shaped as
              {
                  img: np.ndarray,
                  width: int,
                  height: int
              }
          """
          return fitz_doc_to_image(self._doc)

      def get_doc(self) -> fitz.Page:
          """Get the wrapped PyMuPDF page.

          Returns:
              fitz.Page: the wrapped page object.
          """
          return self._doc

      def get_page_info(self) -> PageInfo:
          """Get the page info of the page.

          Returns:
              PageInfo: width and height taken from the page's ``rect``.
          """
          page_w = self._doc.rect.width
          page_h = self._doc.rect.height
          return PageInfo(w=page_w, h=page_h)

      def __getattr__(self, name):
          """Delegate unknown attribute lookups to the wrapped page.

          Raises:
              AttributeError: if neither this object nor the wrapped page
                  defines ``name``.
          """
          # __getattr__ only runs after normal lookup fails. If '_doc' itself is
          # missing (e.g. an instance created without __init__ during copying),
          # touching self._doc would re-enter this method forever.
          if name == '_doc':
              raise AttributeError(name)
          try:
              return getattr(self._doc, name)
          except AttributeError:
              # Previously this fell through and returned None, which made
              # hasattr() always True and hid typos; raise per the data model.
              raise AttributeError(
                  f"{type(self).__name__!r} object has no attribute {name!r}"
              ) from None

      def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
          """Draw a rectangle by forwarding every argument to ``fitz.Page.draw_rect``."""
          self._doc.draw_rect(
              rect_coords,
              color=color,
              fill=fill,
              fill_opacity=fill_opacity,
              width=width,
              overlay=overlay,
          )

      def insert_text(self, coord, content, fontsize, color):
          """Insert ``content`` at ``coord`` via ``fitz.Page.insert_text``."""
          self._doc.insert_text(coord, content, fontsize=fontsize, color=color)