dataset.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. import os
  2. from abc import ABC, abstractmethod
  3. from typing import Callable, Iterator
  4. import fitz
  5. from magic_pdf.config.enums import SupportedPdfParseMethod
  6. from magic_pdf.data.schemas import PageInfo
  7. from magic_pdf.data.utils import fitz_doc_to_image
  8. from magic_pdf.filter import classify
  9. class PageableData(ABC):
  10. @abstractmethod
  11. def get_image(self) -> dict:
  12. """Transform data to image."""
  13. pass
  14. @abstractmethod
  15. def get_doc(self) -> fitz.Page:
  16. """Get the pymudoc page."""
  17. pass
  18. @abstractmethod
  19. def get_page_info(self) -> PageInfo:
  20. """Get the page info of the page.
  21. Returns:
  22. PageInfo: the page info of this page
  23. """
  24. pass
  25. @abstractmethod
  26. def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
  27. """draw rectangle.
  28. Args:
  29. rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  30. color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
  31. fill (list[float] | None): fill the board with RGB, None means will not fill with color
  32. fill_opacity (float): opacity of the fill, range from [0, 1]
  33. width (float): the width of board
  34. overlay (bool): fill the color in foreground or background. True means fill in background.
  35. """
  36. pass
  37. @abstractmethod
  38. def insert_text(self, coord, content, fontsize, color):
  39. """insert text.
  40. Args:
  41. coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  42. content (str): the text content
  43. fontsize (int): font size of the text
  44. color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
  45. """
  46. pass
  47. class Dataset(ABC):
  48. @abstractmethod
  49. def __len__(self) -> int:
  50. """The length of the dataset."""
  51. pass
  52. @abstractmethod
  53. def __iter__(self) -> Iterator[PageableData]:
  54. """Yield the page data."""
  55. pass
  56. @abstractmethod
  57. def supported_methods(self) -> list[SupportedPdfParseMethod]:
  58. """The methods that this dataset support.
  59. Returns:
  60. list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
  61. """
  62. pass
  63. @abstractmethod
  64. def data_bits(self) -> bytes:
  65. """The bits used to create this dataset."""
  66. pass
  67. @abstractmethod
  68. def get_page(self, page_id: int) -> PageableData:
  69. """Get the page indexed by page_id.
  70. Args:
  71. page_id (int): the index of the page
  72. Returns:
  73. PageableData: the page doc object
  74. """
  75. pass
  76. @abstractmethod
  77. def dump_to_file(self, file_path: str):
  78. """Dump the file
  79. Args:
  80. file_path (str): the file path
  81. """
  82. pass
  83. @abstractmethod
  84. def apply(self, proc: Callable, *args, **kwargs):
  85. """Apply callable method which.
  86. Args:
  87. proc (Callable): invoke proc as follows:
  88. proc(self, *args, **kwargs)
  89. Returns:
  90. Any: return the result generated by proc
  91. """
  92. pass
  93. @abstractmethod
  94. def classify(self) -> SupportedPdfParseMethod:
  95. """classify the dataset
  96. Returns:
  97. SupportedPdfParseMethod: _description_
  98. """
  99. pass
  100. @abstractmethod
  101. def clone(self):
  102. """clone this dataset
  103. """
  104. pass
  105. class PymuDocDataset(Dataset):
  106. def __init__(self, bits: bytes):
  107. """Initialize the dataset, which wraps the pymudoc documents.
  108. Args:
  109. bits (bytes): the bytes of the pdf
  110. """
  111. self._raw_fitz = fitz.open('pdf', bits)
  112. self._records = [Doc(v) for v in self._raw_fitz]
  113. self._data_bits = bits
  114. self._raw_data = bits
  115. def __len__(self) -> int:
  116. """The page number of the pdf."""
  117. return len(self._records)
  118. def __iter__(self) -> Iterator[PageableData]:
  119. """Yield the page doc object."""
  120. return iter(self._records)
  121. def supported_methods(self) -> list[SupportedPdfParseMethod]:
  122. """The method supported by this dataset.
  123. Returns:
  124. list[SupportedPdfParseMethod]: the supported methods
  125. """
  126. return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
  127. def data_bits(self) -> bytes:
  128. """The pdf bits used to create this dataset."""
  129. return self._data_bits
  130. def get_page(self, page_id: int) -> PageableData:
  131. """The page doc object.
  132. Args:
  133. page_id (int): the page doc index
  134. Returns:
  135. PageableData: the page doc object
  136. """
  137. return self._records[page_id]
  138. def dump_to_file(self, file_path: str):
  139. """Dump the file
  140. Args:
  141. file_path (str): the file path
  142. """
  143. dir_name = os.path.dirname(file_path)
  144. if dir_name not in ('', '.', '..'):
  145. os.makedirs(dir_name, exist_ok=True)
  146. self._raw_fitz.save(file_path)
  147. def apply(self, proc: Callable, *args, **kwargs):
  148. """Apply callable method which.
  149. Args:
  150. proc (Callable): invoke proc as follows:
  151. proc(dataset, *args, **kwargs)
  152. Returns:
  153. Any: return the result generated by proc
  154. """
  155. return proc(self, *args, **kwargs)
  156. def classify(self) -> SupportedPdfParseMethod:
  157. """classify the dataset
  158. Returns:
  159. SupportedPdfParseMethod: _description_
  160. """
  161. return classify(self._data_bits)
  162. def clone(self):
  163. """clone this dataset
  164. """
  165. return PymuDocDataset(self._raw_data)
  166. class ImageDataset(Dataset):
  167. def __init__(self, bits: bytes):
  168. """Initialize the dataset, which wraps the pymudoc documents.
  169. Args:
  170. bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
  171. """
  172. pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
  173. self._raw_fitz = fitz.open('pdf', pdf_bytes)
  174. self._records = [Doc(v) for v in self._raw_fitz]
  175. self._raw_data = bits
  176. self._data_bits = pdf_bytes
  177. def __len__(self) -> int:
  178. """The length of the dataset."""
  179. return len(self._records)
  180. def __iter__(self) -> Iterator[PageableData]:
  181. """Yield the page object."""
  182. return iter(self._records)
  183. def supported_methods(self):
  184. """The method supported by this dataset.
  185. Returns:
  186. list[SupportedPdfParseMethod]: the supported methods
  187. """
  188. return [SupportedPdfParseMethod.OCR]
  189. def data_bits(self) -> bytes:
  190. """The pdf bits used to create this dataset."""
  191. return self._data_bits
  192. def get_page(self, page_id: int) -> PageableData:
  193. """The page doc object.
  194. Args:
  195. page_id (int): the page doc index
  196. Returns:
  197. PageableData: the page doc object
  198. """
  199. return self._records[page_id]
  200. def dump_to_file(self, file_path: str):
  201. """Dump the file
  202. Args:
  203. file_path (str): the file path
  204. """
  205. dir_name = os.path.dirname(file_path)
  206. if dir_name not in ('', '.', '..'):
  207. os.makedirs(dir_name, exist_ok=True)
  208. self._raw_fitz.save(file_path)
  209. def apply(self, proc: Callable, *args, **kwargs):
  210. """Apply callable method which.
  211. Args:
  212. proc (Callable): invoke proc as follows:
  213. proc(dataset, *args, **kwargs)
  214. Returns:
  215. Any: return the result generated by proc
  216. """
  217. return proc(self, *args, **kwargs)
  218. def classify(self) -> SupportedPdfParseMethod:
  219. """classify the dataset
  220. Returns:
  221. SupportedPdfParseMethod: _description_
  222. """
  223. return SupportedPdfParseMethod.OCR
  224. def clone(self):
  225. """clone this dataset
  226. """
  227. return ImageDataset(self._raw_data)
  228. class Doc(PageableData):
  229. """Initialized with pymudoc object."""
  230. def __init__(self, doc: fitz.Page):
  231. self._doc = doc
  232. def get_image(self):
  233. """Return the image info.
  234. Returns:
  235. dict: {
  236. img: np.ndarray,
  237. width: int,
  238. height: int
  239. }
  240. """
  241. return fitz_doc_to_image(self._doc)
  242. def get_doc(self) -> fitz.Page:
  243. """Get the pymudoc object.
  244. Returns:
  245. fitz.Page: the pymudoc object
  246. """
  247. return self._doc
  248. def get_page_info(self) -> PageInfo:
  249. """Get the page info of the page.
  250. Returns:
  251. PageInfo: the page info of this page
  252. """
  253. page_w = self._doc.rect.width
  254. page_h = self._doc.rect.height
  255. return PageInfo(w=page_w, h=page_h)
  256. def __getattr__(self, name):
  257. if hasattr(self._doc, name):
  258. return getattr(self._doc, name)
  259. def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
  260. """draw rectangle.
  261. Args:
  262. rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  263. color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
  264. fill (list[float] | None): fill the board with RGB, None means will not fill with color
  265. fill_opacity (float): opacity of the fill, range from [0, 1]
  266. width (float): the width of board
  267. overlay (bool): fill the color in foreground or background. True means fill in background.
  268. """
  269. self._doc.draw_rect(
  270. rect_coords,
  271. color=color,
  272. fill=fill,
  273. fill_opacity=fill_opacity,
  274. width=width,
  275. overlay=overlay,
  276. )
  277. def insert_text(self, coord, content, fontsize, color):
  278. """insert text.
  279. Args:
  280. coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  281. content (str): the text content
  282. fontsize (int): font size of the text
  283. color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
  284. """
  285. self._doc.insert_text(coord, content, fontsize=fontsize, color=color)