dataset.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. from abc import ABC, abstractmethod
  2. from typing import Iterator
  3. import fitz
  4. from magic_pdf.config.enums import SupportedPdfParseMethod
  5. from magic_pdf.data.schemas import PageInfo
  6. from magic_pdf.data.utils import fitz_doc_to_image
  7. class PageableData(ABC):
  8. @abstractmethod
  9. def get_image(self) -> dict:
  10. """Transform data to image."""
  11. pass
  12. @abstractmethod
  13. def get_doc(self) -> fitz.Page:
  14. """Get the pymudoc page."""
  15. pass
  16. @abstractmethod
  17. def get_page_info(self) -> PageInfo:
  18. """Get the page info of the page.
  19. Returns:
  20. PageInfo: the page info of this page
  21. """
  22. pass
  23. class Dataset(ABC):
  24. @abstractmethod
  25. def __len__(self) -> int:
  26. """The length of the dataset."""
  27. pass
  28. @abstractmethod
  29. def __iter__(self) -> Iterator[PageableData]:
  30. """Yield the page data."""
  31. pass
  32. @abstractmethod
  33. def supported_methods(self) -> list[SupportedPdfParseMethod]:
  34. """The methods that this dataset support.
  35. Returns:
  36. list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
  37. """
  38. pass
  39. @abstractmethod
  40. def data_bits(self) -> bytes:
  41. """The bits used to create this dataset."""
  42. pass
  43. @abstractmethod
  44. def get_page(self, page_id: int) -> PageableData:
  45. """Get the page indexed by page_id.
  46. Args:
  47. page_id (int): the index of the page
  48. Returns:
  49. PageableData: the page doc object
  50. """
  51. pass
  52. class PymuDocDataset(Dataset):
  53. def __init__(self, bits: bytes):
  54. """Initialize the dataset, which wraps the pymudoc documents.
  55. Args:
  56. bits (bytes): the bytes of the pdf
  57. """
  58. self._records = [Doc(v) for v in fitz.open('pdf', bits)]
  59. self._data_bits = bits
  60. self._raw_data = bits
  61. def __len__(self) -> int:
  62. """The page number of the pdf."""
  63. return len(self._records)
  64. def __iter__(self) -> Iterator[PageableData]:
  65. """Yield the page doc object."""
  66. return iter(self._records)
  67. def supported_methods(self) -> list[SupportedPdfParseMethod]:
  68. """The method supported by this dataset.
  69. Returns:
  70. list[SupportedPdfParseMethod]: the supported methods
  71. """
  72. return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
  73. def data_bits(self) -> bytes:
  74. """The pdf bits used to create this dataset."""
  75. return self._data_bits
  76. def get_page(self, page_id: int) -> PageableData:
  77. """The page doc object.
  78. Args:
  79. page_id (int): the page doc index
  80. Returns:
  81. PageableData: the page doc object
  82. """
  83. return self._records[page_id]
  84. class ImageDataset(Dataset):
  85. def __init__(self, bits: bytes):
  86. """Initialize the dataset, which wraps the pymudoc documents.
  87. Args:
  88. bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
  89. """
  90. pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
  91. self._records = [Doc(v) for v in fitz.open('pdf', pdf_bytes)]
  92. self._raw_data = bits
  93. self._data_bits = pdf_bytes
  94. def __len__(self) -> int:
  95. """The length of the dataset."""
  96. return len(self._records)
  97. def __iter__(self) -> Iterator[PageableData]:
  98. """Yield the page object."""
  99. return iter(self._records)
  100. def supported_methods(self):
  101. """The method supported by this dataset.
  102. Returns:
  103. list[SupportedPdfParseMethod]: the supported methods
  104. """
  105. return [SupportedPdfParseMethod.OCR]
  106. def data_bits(self) -> bytes:
  107. """The pdf bits used to create this dataset."""
  108. return self._data_bits
  109. def get_page(self, page_id: int) -> PageableData:
  110. """The page doc object.
  111. Args:
  112. page_id (int): the page doc index
  113. Returns:
  114. PageableData: the page doc object
  115. """
  116. return self._records[page_id]
  117. class Doc(PageableData):
  118. """Initialized with pymudoc object."""
  119. def __init__(self, doc: fitz.Page):
  120. self._doc = doc
  121. def get_image(self):
  122. """Return the imge info.
  123. Returns:
  124. dict: {
  125. img: np.ndarray,
  126. width: int,
  127. height: int
  128. }
  129. """
  130. return fitz_doc_to_image(self._doc)
  131. def get_doc(self) -> fitz.Page:
  132. """Get the pymudoc object.
  133. Returns:
  134. fitz.Page: the pymudoc object
  135. """
  136. return self._doc
  137. def get_page_info(self) -> PageInfo:
  138. """Get the page info of the page.
  139. Returns:
  140. PageInfo: the page info of this page
  141. """
  142. page_w = self._doc.rect.width
  143. page_h = self._doc.rect.height
  144. return PageInfo(w=page_w, h=page_h)
  145. def __getattr__(self, name):
  146. if hasattr(self._doc, name):
  147. return getattr(self._doc, name)