소스 검색

fix: copilot suggestion

Sidney233 2 달 전
부모
커밋
65b2ddc07f

+ 1 - 1
mineru/backend/pipeline/batch_analyze.py

@@ -262,7 +262,7 @@ class BatchAnalyze:
                 atom_model_name=AtomicModel.ImgOrientationCls,
             )
             try:
-                img_orientation_cls_model.batch_predict(table_res_list_all_page, self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
+                img_orientation_cls_model.batch_predict(table_res_list_all_page, batch_size=self.batch_ratio * OCR_DET_BASE_BATCH_SIZE)
             except Exception as e:
                 logger.warning(
                     f"Image orientation classification failed: {e}, using original image"

+ 1 - 1
mineru/model/table/rec/slanet_plus/table_structure.py

@@ -16,7 +16,7 @@ from typing import Any, Dict, List, Tuple
 
 import numpy as np
 
-from .table_stucture_utils import (
+from .table_structure_utils import (
     OrtInferSession,
     TableLabelDecode,
     TablePreprocess,

+ 0 - 0
mineru/model/table/rec/slanet_plus/table_stucture_utils.py → mineru/model/table/rec/slanet_plus/table_structure_utils.py


+ 27 - 29
mineru/model/table/rec/unet_table/main.py

@@ -10,7 +10,6 @@ import numpy as np
 import cv2
 from PIL import Image
 from loguru import logger
-from  ..slanet_plus.main import RapidTableInput, RapidTable
 
 from .table_structure_unet import TSRUnet
 
@@ -262,31 +261,30 @@ class UnetTableModel:
                 if len(item) == 2 and isinstance(item[1], tuple)
             ]
 
-        if ocr_result:
-            try:
-                wired_table_results = self.wired_table_model(np_img, ocr_result)
-
-                wired_html_code = wired_table_results.pred_html
-
-                wired_len = count_table_cells_physical(wired_html_code)
-                wireless_len = count_table_cells_physical(wireless_html_code)
-
-                # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
-                # 计算两种模型检测的单元格数量差异
-                gap_of_len = wireless_len - wired_len
-                # 判断是否使用无线表格模型的结果
-                if (
-                    wired_len <= int(wireless_len * 0.55)+1  # 有线模型检测到的单元格数太少(低于无线模型的50%)
-                    # or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # 有线模型检测到的单元格数反而更多
-                    or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
-                    or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
-                ):
-                    # logger.debug("fall back to wireless table model")
-                    html_code = wireless_html_code
-                else:
-                    html_code = wired_html_code
-
-                return html_code
-            except Exception as e:
-                logger.exception(e)
-        return None
+        try:
+            wired_table_results = self.wired_table_model(np_img, ocr_result)
+
+            wired_html_code = wired_table_results.pred_html
+
+            wired_len = count_table_cells_physical(wired_html_code)
+            wireless_len = count_table_cells_physical(wireless_html_code)
+
+            # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
+            # 计算两种模型检测的单元格数量差异
+            gap_of_len = wireless_len - wired_len
+            # 判断是否使用无线表格模型的结果
+            if (
+                wired_len <= int(wireless_len * 0.55)+1  # 有线模型检测到的单元格数太少(不超过无线模型的55%向下取整后加1)
+                # or ((round(wireless_len*1.2) < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.94)  # 有线模型检测到的单元格数反而更多
+                or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
+                or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
+            ):
+                # logger.debug("fall back to wireless table model")
+                html_code = wireless_html_code
+            else:
+                html_code = wired_html_code
+
+            return html_code
+        except Exception as e:
+            logger.exception(e)
+            return None