|
|
@@ -1,14 +1,58 @@
|
|
|
-import pytest
|
|
|
+import unittest
|
|
|
from PIL import Image
|
|
|
+from lxml import etree
|
|
|
+
|
|
|
from magic_pdf.model.ppTableModel import ppTableModel
|
|
|
|
|
|
-class TestppTableModel:
|
|
|
+
|
|
|
class TestppTableModel(unittest.TestCase):
    """Check the HTML that ``ppTableModel.img2html`` returns for a sample table image."""

    @staticmethod
    def _stripped_texts(elements):
        """Return the whitespace-stripped ``.text`` of every element.

        An element whose ``.text`` is ``None`` contributes an empty string, so
        a missing cell text shows up as a plain comparison failure instead of
        an ``AttributeError``.
        """
        return [(element.text or "").strip() for element in elements]

    def test_image2html(self):
        """Convert the sample table image to HTML and verify structure and key cells.

        The returned HTML is parsed with lxml rather than compared against one
        exact string; the test checks that the table skeleton is present, and
        that the header row, the first body row and the second-to-last body
        row hold the expected texts.
        """
        # Open the image as a context manager so its file handle is released
        # once the model has consumed it.
        with Image.open("tests/test_table/assets/table.jpg") as img:
            # Adjust the table model path for the local environment.
            # NOTE(review): "model_dir" is a machine-specific absolute path and
            # "device" is "cuda" -- presumably this only runs on a host with
            # that model cache and a GPU; confirm before enabling it in CI.
            config = {"device": "cuda",
                      "model_dir": "/home/quyuan/.cache/modelscope/hub/opendatalab/PDF-Extract-Kit/models/TabRec/TableMaster"}
            table_model = ppTableModel(config)
            res = table_model.img2html(img)

        # Verify that the generated HTML matches expectations.
        parser = etree.HTMLParser()
        tree = etree.fromstring(res, parser)

        # Check the HTML structure.
        for tag in ("table", "thead", "tbody", "tr", "td"):
            self.assertIsNotNone(tree.find(f".//{tag}"),
                                 f"HTML should contain a <{tag}> element")

        # Check the header row.
        headers = tree.xpath('//thead/tr/td/b')
        self.assertEqual(len(headers), 5, "Thead should have 5 columns")
        self.assertEqual(self._stripped_texts(headers),
                         ["Methods", "R", "P", "F", "FPS"],
                         "Header texts should be Methods, R, P, F, FPS")

        # Check the first data row.
        first_row = tree.xpath('//tbody/tr[1]/td')
        self.assertEqual(len(first_row), 5, "First row should have 5 cells")
        self.assertEqual(self._stripped_texts(first_row),
                         ["SegLink[26]", "70.0", "86.0", "77.0", "8.9"],
                         "First row cells do not match the expected values")

        # Check the second-to-last data row. Every cell, including the fifth,
        # is validated through its own text.
        second_last_row = tree.xpath('//tbody/tr[position()=last()-1]/td')
        self.assertEqual(len(second_last_row), 5, "second_last_row should have 5 cells")
        self.assertEqual(self._stripped_texts(second_last_row),
                         ["Ours (SynText)", "80.68", "85.40", "82.97", "12.68"],
                         "Second-to-last row cells do not match the expected values")


if __name__ == "__main__":
    unittest.main()
|