# dataset.py
  1. import os
  2. from abc import ABC, abstractmethod
  3. from typing import Callable, Iterator
  4. import fitz
  5. from loguru import logger
  6. from magic_pdf.config.enums import SupportedPdfParseMethod
  7. from magic_pdf.data.schemas import PageInfo
  8. from magic_pdf.data.utils import fitz_doc_to_image
  9. from magic_pdf.filter import classify
  10. class PageableData(ABC):
  11. @abstractmethod
  12. def get_image(self) -> dict:
  13. """Transform data to image."""
  14. pass
  15. @abstractmethod
  16. def get_doc(self) -> fitz.Page:
  17. """Get the pymudoc page."""
  18. pass
  19. @abstractmethod
  20. def get_page_info(self) -> PageInfo:
  21. """Get the page info of the page.
  22. Returns:
  23. PageInfo: the page info of this page
  24. """
  25. pass
  26. @abstractmethod
  27. def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
  28. """draw rectangle.
  29. Args:
  30. rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  31. color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
  32. fill (list[float] | None): fill the board with RGB, None means will not fill with color
  33. fill_opacity (float): opacity of the fill, range from [0, 1]
  34. width (float): the width of board
  35. overlay (bool): fill the color in foreground or background. True means fill in background.
  36. """
  37. pass
  38. @abstractmethod
  39. def insert_text(self, coord, content, fontsize, color):
  40. """insert text.
  41. Args:
  42. coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  43. content (str): the text content
  44. fontsize (int): font size of the text
  45. color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
  46. """
  47. pass
  48. class Dataset(ABC):
  49. @abstractmethod
  50. def __len__(self) -> int:
  51. """The length of the dataset."""
  52. pass
  53. @abstractmethod
  54. def __iter__(self) -> Iterator[PageableData]:
  55. """Yield the page data."""
  56. pass
  57. @abstractmethod
  58. def supported_methods(self) -> list[SupportedPdfParseMethod]:
  59. """The methods that this dataset support.
  60. Returns:
  61. list[SupportedPdfParseMethod]: The supported methods, Valid methods are: OCR, TXT
  62. """
  63. pass
  64. @abstractmethod
  65. def data_bits(self) -> bytes:
  66. """The bits used to create this dataset."""
  67. pass
  68. @abstractmethod
  69. def get_page(self, page_id: int) -> PageableData:
  70. """Get the page indexed by page_id.
  71. Args:
  72. page_id (int): the index of the page
  73. Returns:
  74. PageableData: the page doc object
  75. """
  76. pass
  77. @abstractmethod
  78. def dump_to_file(self, file_path: str):
  79. """Dump the file.
  80. Args:
  81. file_path (str): the file path
  82. """
  83. pass
  84. @abstractmethod
  85. def apply(self, proc: Callable, *args, **kwargs):
  86. """Apply callable method which.
  87. Args:
  88. proc (Callable): invoke proc as follows:
  89. proc(self, *args, **kwargs)
  90. Returns:
  91. Any: return the result generated by proc
  92. """
  93. pass
  94. @abstractmethod
  95. def classify(self) -> SupportedPdfParseMethod:
  96. """classify the dataset.
  97. Returns:
  98. SupportedPdfParseMethod: _description_
  99. """
  100. pass
  101. @abstractmethod
  102. def clone(self):
  103. """clone this dataset."""
  104. pass
  105. class PymuDocDataset(Dataset):
  106. def __init__(self, bits: bytes, lang=None):
  107. """Initialize the dataset, which wraps the pymudoc documents.
  108. Args:
  109. bits (bytes): the bytes of the pdf
  110. """
  111. self._raw_fitz = fitz.open('pdf', bits)
  112. self._records = [Doc(v) for v in self._raw_fitz]
  113. self._data_bits = bits
  114. self._raw_data = bits
  115. self._classify_result = None
  116. if lang == '':
  117. self._lang = None
  118. elif lang == 'auto':
  119. from magic_pdf.model.sub_modules.language_detection.utils import \
  120. auto_detect_lang
  121. self._lang = auto_detect_lang(bits)
  122. logger.info(f'lang: {lang}, detect_lang: {self._lang}')
  123. else:
  124. self._lang = lang
  125. logger.info(f'lang: {lang}')
  126. def __len__(self) -> int:
  127. """The page number of the pdf."""
  128. return len(self._records)
  129. def __iter__(self) -> Iterator[PageableData]:
  130. """Yield the page doc object."""
  131. return iter(self._records)
  132. def supported_methods(self) -> list[SupportedPdfParseMethod]:
  133. """The method supported by this dataset.
  134. Returns:
  135. list[SupportedPdfParseMethod]: the supported methods
  136. """
  137. return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]
  138. def data_bits(self) -> bytes:
  139. """The pdf bits used to create this dataset."""
  140. return self._data_bits
  141. def get_page(self, page_id: int) -> PageableData:
  142. """The page doc object.
  143. Args:
  144. page_id (int): the page doc index
  145. Returns:
  146. PageableData: the page doc object
  147. """
  148. return self._records[page_id]
  149. def dump_to_file(self, file_path: str):
  150. """Dump the file.
  151. Args:
  152. file_path (str): the file path
  153. """
  154. dir_name = os.path.dirname(file_path)
  155. if dir_name not in ('', '.', '..'):
  156. os.makedirs(dir_name, exist_ok=True)
  157. self._raw_fitz.save(file_path)
  158. def apply(self, proc: Callable, *args, **kwargs):
  159. """Apply callable method which.
  160. Args:
  161. proc (Callable): invoke proc as follows:
  162. proc(dataset, *args, **kwargs)
  163. Returns:
  164. Any: return the result generated by proc
  165. """
  166. if 'lang' in kwargs and self._lang is not None:
  167. kwargs['lang'] = self._lang
  168. return proc(self, *args, **kwargs)
  169. def classify(self) -> SupportedPdfParseMethod:
  170. """classify the dataset.
  171. Returns:
  172. SupportedPdfParseMethod: _description_
  173. """
  174. if self._classify_result is None:
  175. self._classify_result = classify(self._data_bits)
  176. return self._classify_result
  177. def clone(self):
  178. """clone this dataset."""
  179. return PymuDocDataset(self._raw_data)
  180. def set_images(self, images):
  181. for i in range(len(self._records)):
  182. self._records[i].set_image(images[i])
  183. class ImageDataset(Dataset):
  184. def __init__(self, bits: bytes):
  185. """Initialize the dataset, which wraps the pymudoc documents.
  186. Args:
  187. bits (bytes): the bytes of the photo which will be converted to pdf first. then converted to pymudoc.
  188. """
  189. pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
  190. self._raw_fitz = fitz.open('pdf', pdf_bytes)
  191. self._records = [Doc(v) for v in self._raw_fitz]
  192. self._raw_data = bits
  193. self._data_bits = pdf_bytes
  194. def __len__(self) -> int:
  195. """The length of the dataset."""
  196. return len(self._records)
  197. def __iter__(self) -> Iterator[PageableData]:
  198. """Yield the page object."""
  199. return iter(self._records)
  200. def supported_methods(self):
  201. """The method supported by this dataset.
  202. Returns:
  203. list[SupportedPdfParseMethod]: the supported methods
  204. """
  205. return [SupportedPdfParseMethod.OCR]
  206. def data_bits(self) -> bytes:
  207. """The pdf bits used to create this dataset."""
  208. return self._data_bits
  209. def get_page(self, page_id: int) -> PageableData:
  210. """The page doc object.
  211. Args:
  212. page_id (int): the page doc index
  213. Returns:
  214. PageableData: the page doc object
  215. """
  216. return self._records[page_id]
  217. def dump_to_file(self, file_path: str):
  218. """Dump the file.
  219. Args:
  220. file_path (str): the file path
  221. """
  222. dir_name = os.path.dirname(file_path)
  223. if dir_name not in ('', '.', '..'):
  224. os.makedirs(dir_name, exist_ok=True)
  225. self._raw_fitz.save(file_path)
  226. def apply(self, proc: Callable, *args, **kwargs):
  227. """Apply callable method which.
  228. Args:
  229. proc (Callable): invoke proc as follows:
  230. proc(dataset, *args, **kwargs)
  231. Returns:
  232. Any: return the result generated by proc
  233. """
  234. return proc(self, *args, **kwargs)
  235. def classify(self) -> SupportedPdfParseMethod:
  236. """classify the dataset.
  237. Returns:
  238. SupportedPdfParseMethod: _description_
  239. """
  240. return SupportedPdfParseMethod.OCR
  241. def clone(self):
  242. """clone this dataset."""
  243. return ImageDataset(self._raw_data)
  244. def set_images(self, images):
  245. for i in range(len(self._records)):
  246. self._records[i].set_image(images[i])
  247. class Doc(PageableData):
  248. """Initialized with pymudoc object."""
  249. def __init__(self, doc: fitz.Page):
  250. self._doc = doc
  251. self._img = None
  252. def get_image(self):
  253. """Return the image info.
  254. Returns:
  255. dict: {
  256. img: np.ndarray,
  257. width: int,
  258. height: int
  259. }
  260. """
  261. if self._img is None:
  262. self._img = fitz_doc_to_image(self._doc)
  263. return self._img
  264. def set_image(self, img):
  265. """
  266. Args:
  267. img (np.ndarray): the image
  268. """
  269. if self._img is None:
  270. self._img = img
  271. def get_doc(self) -> fitz.Page:
  272. """Get the pymudoc object.
  273. Returns:
  274. fitz.Page: the pymudoc object
  275. """
  276. return self._doc
  277. def get_page_info(self) -> PageInfo:
  278. """Get the page info of the page.
  279. Returns:
  280. PageInfo: the page info of this page
  281. """
  282. page_w = self._doc.rect.width
  283. page_h = self._doc.rect.height
  284. return PageInfo(w=page_w, h=page_h)
  285. def __getattr__(self, name):
  286. if hasattr(self._doc, name):
  287. return getattr(self._doc, name)
  288. def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
  289. """draw rectangle.
  290. Args:
  291. rect_coords (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  292. color (list[float] | None): three element tuple which describe the RGB of the board line, None means no board line
  293. fill (list[float] | None): fill the board with RGB, None means will not fill with color
  294. fill_opacity (float): opacity of the fill, range from [0, 1]
  295. width (float): the width of board
  296. overlay (bool): fill the color in foreground or background. True means fill in background.
  297. """
  298. self._doc.draw_rect(
  299. rect_coords,
  300. color=color,
  301. fill=fill,
  302. fill_opacity=fill_opacity,
  303. width=width,
  304. overlay=overlay,
  305. )
  306. def insert_text(self, coord, content, fontsize, color):
  307. """insert text.
  308. Args:
  309. coord (list[float]): four elements array contain the top-left and bottom-right coordinates, [x0, y0, x1, y1]
  310. content (str): the text content
  311. fontsize (int): font size of the text
  312. color (list[float] | None): three element tuple which describe the RGB of the board line, None will use the default font color!
  313. """
  314. self._doc.insert_text(coord, content, fontsize=fontsize, color=color)