api.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import os
  2. from pathlib import Path
  3. from loguru import logger
  4. from magic_pdf.integrations.rag.type import (ElementRelation, LayoutElements,
  5. Node)
  6. from magic_pdf.integrations.rag.utils import inference
  7. class RagPageReader:
  8. def __init__(self, pagedata: LayoutElements):
  9. self.o = [
  10. Node(
  11. category_type=v.category_type,
  12. text=v.text,
  13. image_path=v.image_path,
  14. anno_id=v.anno_id,
  15. latex=v.latex,
  16. html=v.html,
  17. ) for v in pagedata.layout_dets
  18. ]
  19. self.pagedata = pagedata
  20. def __iter__(self):
  21. return iter(self.o)
  22. def get_rel_map(self) -> list[ElementRelation]:
  23. return self.pagedata.extra.element_relation
  24. class RagDocumentReader:
  25. def __init__(self, ragdata: list[LayoutElements]):
  26. self.o = [RagPageReader(v) for v in ragdata]
  27. def __iter__(self):
  28. return iter(self.o)
  29. class DataReader:
  30. def __init__(self, path_or_directory: str, method: str, output_dir: str):
  31. self.path_or_directory = path_or_directory
  32. self.method = method
  33. self.output_dir = output_dir
  34. self.pdfs = []
  35. if os.path.isdir(path_or_directory):
  36. for doc_path in Path(path_or_directory).glob('*.pdf'):
  37. self.pdfs.append(doc_path)
  38. else:
  39. assert path_or_directory.endswith('.pdf')
  40. self.pdfs.append(Path(path_or_directory))
  41. def get_documents_count(self) -> int:
  42. """Returns the number of documents in the directory."""
  43. return len(self.pdfs)
  44. def get_document_result(self, idx: int) -> RagDocumentReader | None:
  45. """
  46. Args:
  47. idx (int): the index of documents under the
  48. directory path_or_directory
  49. Returns:
  50. RagDocumentReader | None: RagDocumentReader is an iterable object,
  51. more details @RagDocumentReader
  52. """
  53. if idx >= self.get_documents_count() or idx < 0:
  54. logger.error(f'invalid idx: {idx}')
  55. return None
  56. res = inference(str(self.pdfs[idx]), self.output_dir, self.method)
  57. if res is None:
  58. logger.warning(f'failed to inference pdf {self.pdfs[idx]}')
  59. return None
  60. return RagDocumentReader(res)
  61. def get_document_filename(self, idx: int) -> Path:
  62. """get the filename of the document."""
  63. return self.pdfs[idx]