# test_rapidtable.py

  1. import unittest
  2. import os
  3. from PIL import Image
  4. from lxml import etree
  5. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  6. from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
  7. class TestppTableModel(unittest.TestCase):
  8. def test_image2html(self):
  9. img = Image.open(os.path.join(os.path.dirname(__file__), "assets/table.jpg"))
  10. atom_model_manager = AtomModelSingleton()
  11. ocr_engine = atom_model_manager.get_atom_model(
  12. atom_model_name='ocr',
  13. ocr_show_log=False,
  14. det_db_box_thresh=0.5,
  15. det_db_unclip_ratio=1.6,
  16. lang='ch'
  17. )
  18. table_model = RapidTableModel(ocr_engine, 'slanet_plus')
  19. html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(img)
  20. # 验证生成的 HTML 是否符合预期
  21. parser = etree.HTMLParser()
  22. tree = etree.fromstring(html_code, parser)
  23. # 检查 HTML 结构
  24. assert tree.find('.//table') is not None, "HTML should contain a <table> element"
  25. assert tree.find('.//tr') is not None, "HTML should contain a <tr> element"
  26. assert tree.find('.//td') is not None, "HTML should contain a <td> element"
  27. # 检查具体的表格内容
  28. headers = tree.xpath('//table/tr[1]/td')
  29. assert len(headers) == 5, "Thead should have 5 columns"
  30. assert headers[0].text and headers[0].text.strip() == "Methods", "First header should be 'Methods'"
  31. assert headers[1].text and headers[1].text.strip() == "R", "Second header should be 'R'"
  32. assert headers[2].text and headers[2].text.strip() == "P", "Third header should be 'P'"
  33. assert headers[3].text and headers[3].text.strip() == "F", "Fourth header should be 'F'"
  34. assert headers[4].text and headers[4].text.strip() == "FPS", "Fifth header should be 'FPS'"
  35. # 检查第一行数据
  36. first_row = tree.xpath('//table/tr[2]/td')
  37. assert len(first_row) == 5, "First row should have 5 cells"
  38. assert first_row[0].text and first_row[0].text.strip() == "SegLink[26]", "First cell should be 'SegLink[26]'"
  39. assert first_row[1].text and first_row[1].text.strip() == "70.0", "Second cell should be '70.0'"
  40. assert first_row[2].text and first_row[2].text.strip() == "86.0", "Third cell should be '86.0'"
  41. assert first_row[3].text and first_row[3].text.strip() == "77.0", "Fourth cell should be '77.0'"
  42. assert first_row[4].text and first_row[4].text.strip() == "8.9", "Fifth cell should be '8.9'"
  43. # 检查倒数第二行数据
  44. second_last_row = tree.xpath('//table/tr[position()=last()-1]/td')
  45. assert len(second_last_row) == 5, "second_last_row should have 5 cells"
  46. assert second_last_row[0].text and second_last_row[0].text.strip() == "Ours (SynText)", "First cell should be 'Ours (SynText)'"
  47. assert second_last_row[1].text and second_last_row[1].text.strip() == "80.68", "Second cell should be '80.68'"
  48. assert second_last_row[2].text and second_last_row[2].text.strip() == "85.40", "Third cell should be '85.40'"
  49. # assert second_last_row[3].text and second_last_row[3].text.strip() == "82.97", "Fourth cell should be '82.97'"
  50. # assert second_last_row[3].text and second_last_row[4].text.strip() == "12.68", "Fifth cell should be '12.68'"
  51. if __name__ == "__main__":
  52. unittest.main()