Ver código fonte

Merge pull request #854 from myhloli/add-structeqtable

feat(table): upgrade StructEqTable model and integrate into PDF Extract Kit
Xiaomeng Zhao 1 ano atrás
pai
commit
dc31c97b8a

+ 12 - 9
magic_pdf/model/pdf_extract_kit.py

@@ -38,15 +38,13 @@ except ImportError as e:
 from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
 from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
 from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
-# from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
+from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
 from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
-        # table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
-        logger.error("StructEqTable is under upgrade, the current version does not support it.")
-        exit(1)
+        table_model = StructTableModel(model_path, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
         config = {
             "model_dir": model_path,
@@ -393,7 +391,7 @@ class CustomPEKModel:
             elif int(res['category_id']) in [5]:
                 table_res_list.append(res)
 
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and self.device != 'cpu':
             properties = torch.cuda.get_device_properties(self.device)
             total_memory = properties.total_memory / (1024 ** 3)  # 将字节转换为 GB
             if total_memory <= 10:
@@ -463,7 +461,9 @@ class CustomPEKModel:
                 html_code = None
                 if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
                     with torch.no_grad():
-                        latex_code = self.table_model.image2latex(new_image)[0]
+                        table_result = self.table_model.predict(new_image, "html")
+                        if len(table_result) > 0:
+                            html_code = table_result[0]
                 else:
                     html_code = self.table_model.img2html(new_image)
 
@@ -474,14 +474,17 @@ class CustomPEKModel:
                 # 判断是否返回正常
 
                 if latex_code:
-                    expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
-                        'end{table}')
+                    expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
                     if expected_ending:
                         res["latex"] = latex_code
                     else:
                         logger.warning(f"table recognition processing fails, not found expected LaTeX table end")
                 elif html_code:
-                    res["html"] = html_code
+                    expected_ending = html_code.strip().endswith('</html>') or html_code.strip().endswith('</table>')
+                    if expected_ending:
+                        res["html"] = html_code
+                    else:
+                        logger.warning(f"table recognition processing fails, not found expected HTML table end")
                 else:
                     logger.warning(f"table recognition processing fails, not get latex or html return")
             logger.info(f"table time: {round(time.time() - table_start, 2)}")

+ 24 - 21
magic_pdf/model/pek_sub_modules/structeqtable/StructTableModel.py

@@ -1,28 +1,31 @@
-from loguru import logger
-
-try:
-    from struct_eqtable.model import StructTable
-except ImportError:
-    logger.error("StructEqTable is under upgrade, the current version does not support it.")
-from pypandoc import convert_text
+import torch
+from struct_eqtable import build_model
 
 
 class StructTableModel:
-    def __init__(self, model_path, max_new_tokens=2048, max_time=400, device = 'cpu'):
+    def __init__(self, model_path, max_new_tokens=1024, max_time=60):
         # init
-        self.model_path = model_path
-        self.max_new_tokens = max_new_tokens # maximum output tokens length
-        self.max_time = max_time # timeout for processing in seconds
-        if device == 'cuda':
-            self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time).cuda()
+        assert torch.cuda.is_available(), "CUDA must be available for StructEqTable model."
+        self.model = build_model(
+            model_ckpt=model_path,
+            max_new_tokens=max_new_tokens,
+            max_time=max_time,
+            lmdeploy=False,
+            flash_attn=False,
+            batch_size=1,
+        ).cuda()
+        self.default_format = "html"
+
+    def predict(self, images, output_format=None, **kwargs):
+
+        if output_format is None:
+            output_format = self.default_format
         else:
-            self.model = StructTable(self.model_path, self.max_new_tokens, self.max_time)
+            if output_format not in ['latex', 'markdown', 'html']:
+                raise ValueError(f"Output format {output_format} is not supported.")
 
-    def image2latex(self, image) -> str:
-        table_latex = self.model.forward(image)
-        return table_latex
+        results = self.model(
+            images, output_format=output_format
+        )
 
-    def image2html(self, image) -> str:
-        table_latex = self.image2latex(image)
-        table_html = convert_text(table_latex, 'html', format='latex')
-        return table_html
+        return results

+ 3 - 3
magic_pdf/model/ppTableModel.py

@@ -39,9 +39,9 @@ class ppTableModel(object):
             image = np.array(image)
         pred_res, _ = self.table_sys(image)
         pred_html = pred_res["html"]
-        res = '<td><table  border="1">' + pred_html.replace("<html><body><table>", "").replace("</table></body></html>",
-                                                                                               "") + "</table></td>\n"
-        return res
+        # res = '<td><table  border="1">' + pred_html.replace("<html><body><table>", "").replace(
+        # "</table></body></html>","") + "</table></td>\n"
+        return pred_html
 
     def parse_args(self, **kwargs):
         parser = init_args()

+ 3 - 2
setup.py

@@ -43,8 +43,9 @@ if __name__ == '__main__':
                      "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",  # 解决linux的段异常问题
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",  # windows版本3.0.0b1效率下降,需锁定2.6.1
-                     "pypandoc",  # 表格解析latex转html
-                     "struct-eqtable==0.1.0",  # 表格解析
+                     "struct-eqtable==0.3.2",  # 表格解析
+                     "einops",  # struct-eqtable依赖
+                     "accelerate",  # struct-eqtable依赖
                      "doclayout_yolo==0.0.2",  # doclayout_yolo
                      "detectron2"
                      ],