test_api.py 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import json
  2. import os
  3. import shutil
  4. import tempfile
  5. from magic_pdf.integrations.rag.api import DataReader, RagDocumentReader
  6. from magic_pdf.integrations.rag.type import CategoryType
  7. from magic_pdf.integrations.rag.utils import \
  8. convert_middle_json_to_layout_elements
  9. def test_rag_document_reader():
  10. # setup
  11. unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
  12. os.makedirs(unitest_dir, exist_ok=True)
  13. temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
  14. os.makedirs(temp_output_dir, exist_ok=True)
  15. # test
  16. with open('tests/test_integrations/test_rag/assets/middle.json') as f:
  17. json_data = json.load(f)
  18. res = convert_middle_json_to_layout_elements(json_data, temp_output_dir)
  19. doc = RagDocumentReader(res)
  20. assert len(list(iter(doc))) == 1
  21. page = list(iter(doc))[0]
  22. assert len(list(iter(page))) == 10
  23. assert len(page.get_rel_map()) == 3
  24. item = list(iter(page))[0]
  25. assert item.category_type == CategoryType.text
  26. # teardown
  27. shutil.rmtree(temp_output_dir)
  28. def test_data_reader():
  29. # setup
  30. unitest_dir = '/tmp/magic_pdf/unittest/integrations/rag'
  31. os.makedirs(unitest_dir, exist_ok=True)
  32. temp_output_dir = tempfile.mkdtemp(dir=unitest_dir)
  33. os.makedirs(temp_output_dir, exist_ok=True)
  34. # test
  35. data_reader = DataReader('tests/test_integrations/test_rag/assets', 'ocr',
  36. temp_output_dir)
  37. assert data_reader.get_documents_count() == 2
  38. for idx in range(data_reader.get_documents_count()):
  39. document = data_reader.get_document_result(idx)
  40. assert document is not None
  41. # teardown
  42. shutil.rmtree(temp_output_dir)