浏览代码

feat: enhance table classification logic and add OCR detection flag

myhloli 3 月之前
父节点
当前提交
865b44a517
共有 2 个文件被更改,包括 13 次插入4 次删除
  1. 5 1
      mineru/backend/pipeline/pipeline_analyze.py
  2. 8 3
      mineru/model/table/cls/paddle_table_cls.py

+ 5 - 1
mineru/backend/pipeline/pipeline_analyze.py

@@ -190,7 +190,11 @@ def batch_image_analyze(
             batch_ratio = 1
             logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
 
-    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
+    if str(device).startswith('mps'):
+        enable_ocr_det_batch = False
+    else:
+        enable_ocr_det_batch = True
+    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
     results = batch_model(images_with_extra_info)
 
     clean_memory(get_device())

+ 8 - 3
mineru/model/table/cls/paddle_table_cls.py

@@ -3,6 +3,7 @@ import os
 import cv2
 import numpy as np
 import onnxruntime
+from loguru import logger
 
 from mineru.backend.pipeline.model_list import AtomicModel
 from mineru.utils.enum_class import ModelPath
@@ -63,6 +64,10 @@ class PaddleTableClsModel:
 
     def predict(self, img):
         x = self.preprocess(img)
-        (result,) = self.sess.run(None, {"x": x})
-        label = self.labels[np.argmax(result)]
-        return label
+        result = self.sess.run(None, {"x": x})
+        idx = np.argmax(result)
+        conf = float(np.max(result))
+        # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
+        if idx == 0 and conf < 0.8:
+            idx = 1
+        return self.labels[idx]