Explorar el Código

Add table recognition success detection

liukaiwen hace 1 año
padre
commit
b18496b0ae
Se han modificado 2 ficheros con 15 adiciones y 5 borrados
  1. 4 1
      magic_pdf/libs/Constants.py
  2. 11 4
      magic_pdf/model/pdf_extract_kit.py

+ 4 - 1
magic_pdf/libs/Constants.py

@@ -8,4 +8,7 @@ CROSS_PAGE = "cross_page"
 block维度自定义字段
 """
 # block中lines是否被删除
-LINES_DELETED = "lines_deleted"
+LINES_DELETED = "lines_deleted"
+
+# table recognition max time default value
+TABLE_MAX_TIME_VALUE = 400

+ 11 - 4
magic_pdf/model/pdf_extract_kit.py

@@ -2,6 +2,7 @@ from loguru import logger
 import os
 import time
 
+from magic_pdf.libs.Constants import TABLE_MAX_TIME_VALUE
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
 try:
@@ -105,6 +106,7 @@ class CustomPEKModel:
         self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
         self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
         self.apply_table = self.table_config.get("is_table_recog_enable", False)
+        self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
         self.apply_ocr = ocr
         logger.info(
             "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
@@ -141,9 +143,8 @@ class CustomPEKModel:
 
         # init structeqtable
         if self.apply_table:
-            max_time = self.table_config.get("max_time", 400)
             self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
-                                                max_time=max_time, _device_=self.device)
+                                                max_time = self.table_max_time, _device_=self.device)
         logger.info('DocAnalysis init done!')
 
     def __call__(self, image):
@@ -290,6 +291,12 @@ class CustomPEKModel:
                     end_time = time.time()
                     run_time = end_time - start_time
                     logger.info(f"------------table recognition processing ends within {run_time}s-----")
-                    layout["latex"] = latex_code
-
+                    if run_time > self.table_max_time:
+                        logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
+                    # 判断是否返回正常
+                    if latex_code and latex_code.strip().endswith('end{tabular}'):
+                        layout["latex"] = latex_code
+                    else:
+                        print(latex_code)
+                        logger.warning(f"------------table recognition processing fails----------")
         return layout_res