Jelajahi Sumber

Merge pull request #3275 from myhloli/dev

Dev
Xiaomeng Zhao 3 bulan lalu
induk
melakukan
5db0f7b9ce

+ 19 - 17
mineru/backend/pipeline/batch_analyze.py

@@ -10,7 +10,7 @@ from ...utils.config_reader import get_formula_enable, get_table_enable
 from ...utils.model_utils import crop_img, get_res_list_from_layout_res
 from ...utils.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list, OcrConfidence
 
-YOLO_LAYOUT_BASE_BATCH_SIZE = 8
+YOLO_LAYOUT_BASE_BATCH_SIZE = 1
 MFD_BASE_BATCH_SIZE = 1
 MFR_BASE_BATCH_SIZE = 16
 OCR_DET_BASE_BATCH_SIZE = 16
@@ -251,17 +251,33 @@ class BatchAnalyze:
             for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
                 _lang = table_res_dict['lang']
 
+                # 调整图片方向
+                img_orientation_cls_model = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.ImgOrientationCls,
+                )
+                try:
+                    table_img = img_orientation_cls_model.predict(
+                        table_res_dict["table_img"]
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"Image orientation classification failed: {e}, using original image"
+                    )
+                    table_img = table_res_dict["table_img"]
+
                 # 有线表/无线表分类
                 table_cls_model = atom_model_manager.get_atom_model(
                     atom_model_name=AtomicModel.TableCls,
                 )
+                table_cls_score = 0.5
                 try:
-                    table_label = table_cls_model.predict(table_res_dict["table_img"])
+                    table_label, table_cls_score = table_cls_model.predict(table_img)
                 except Exception as e:
                     table_label = AtomicModel.WirelessTable
                     logger.warning(
                         f"Table classification failed: {e}, using default model {table_label}"
                     )
+                # table_label = AtomicModel.WirelessTable
                 # logger.debug(f"Table classification result: {table_label}")
                 if table_label not in [AtomicModel.WirelessTable, AtomicModel.WiredTable]:
                     raise ValueError(
@@ -274,21 +290,7 @@ class BatchAnalyze:
                     lang=_lang,
                 )
 
-                # 调整图片方向
-                img_orientation_cls_model = atom_model_manager.get_atom_model(
-                    atom_model_name=AtomicModel.ImgOrientationCls,
-                )
-                try:
-                    table_img = img_orientation_cls_model.predict(
-                        table_res_dict["table_img"]
-                    )
-                except Exception as e:
-                    logger.warning(
-                        f"Image orientation classification failed: {e}, using original image"
-                    )
-                    table_img = table_res_dict["table_img"]
-
-                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img)
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_img, table_cls_score)
                 # 判断是否返回正常
                 if html_code:
                     # 检查html_code是否包含'<table>'和'</table>'

+ 30 - 8
mineru/backend/pipeline/model_init.py

@@ -19,7 +19,11 @@ from ...utils.models_download_utils import auto_download_and_get_model_root_path
 def img_orientation_cls_model_init():
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
-        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang="ch_lite"
+        atom_model_name=AtomicModel.OCR,
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        lang="ch_lite",
+        enable_merge_det_boxes=False
     )
     cls_model = PaddleOrientationClsModel(ocr_engine)
     return cls_model
@@ -32,7 +36,11 @@ def table_cls_model_init():
 def wired_table_model_init(lang=None):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
-        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
+        atom_model_name=AtomicModel.OCR,
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        lang=lang,
+        enable_merge_det_boxes=False
     )
     table_model = UnetTableModel(ocr_engine)
     return table_model
@@ -41,7 +49,11 @@ def wired_table_model_init(lang=None):
 def wireless_table_model_init(lang=None):
     atom_model_manager = AtomModelSingleton()
     ocr_engine = atom_model_manager.get_atom_model(
-        atom_model_name="ocr", det_db_box_thresh=0.5, det_db_unclip_ratio=1.6, lang=lang
+        atom_model_name=AtomicModel.OCR,
+        det_db_box_thresh=0.5,
+        det_db_unclip_ratio=1.6,
+        lang=lang,
+        enable_merge_det_boxes=False
     )
     table_model = RapidTableModel(ocr_engine)
     return table_model
@@ -67,21 +79,23 @@ def doclayout_yolo_model_init(weight, device='cpu'):
 
 def ocr_model_init(det_db_box_thresh=0.3,
                    lang=None,
-                   use_dilation=True,
                    det_db_unclip_ratio=1.8,
+                   enable_merge_det_boxes=True
                    ):
     if lang is not None and lang != '':
         model = PytorchPaddleOCR(
             det_db_box_thresh=det_db_box_thresh,
             lang=lang,
-            use_dilation=use_dilation,
+            use_dilation=True,
             det_db_unclip_ratio=det_db_unclip_ratio,
+            enable_merge_det_boxes=enable_merge_det_boxes,
         )
     else:
         model = PytorchPaddleOCR(
             det_db_box_thresh=det_db_box_thresh,
-            use_dilation=use_dilation,
+            use_dilation=True,
             det_db_unclip_ratio=det_db_unclip_ratio,
+            enable_merge_det_boxes=enable_merge_det_boxes,
         )
     return model
 
@@ -99,8 +113,14 @@ class AtomModelSingleton:
 
         lang = kwargs.get('lang', None)
 
-        if atom_model_name in [AtomicModel.OCR, AtomicModel.WiredTable, AtomicModel.WirelessTable]:
+        if atom_model_name in [AtomicModel.WiredTable, AtomicModel.WirelessTable]:
             key = (atom_model_name, lang)
+        elif atom_model_name in [AtomicModel.OCR]:
+            key = (atom_model_name,
+                   kwargs.get('det_db_box_thresh', 0.3),
+                   lang, kwargs.get('det_db_unclip_ratio', 1.8),
+                   kwargs.get('enable_merge_det_boxes', True)
+                   )
         else:
             key = atom_model_name
 
@@ -127,8 +147,10 @@ def atom_model_init(model_name: str, **kwargs):
         )
     elif model_name == AtomicModel.OCR:
         atom_model = ocr_model_init(
-            kwargs.get('det_db_box_thresh'),
+            kwargs.get('det_db_box_thresh', 0.3),
             kwargs.get('lang'),
+            kwargs.get('det_db_unclip_ratio', 1.8),
+            kwargs.get('enable_merge_det_boxes', True)
         )
     elif model_name == AtomicModel.WirelessTable:
         atom_model = wireless_table_model_init(

+ 3 - 3
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -21,13 +21,14 @@ from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_bloc
 from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
     remove_overlaps_min_spans, txt_spans_extract
 from mineru.version import __version__
-from mineru.utils.hash_utils import str_md5
+from mineru.utils.hash_utils import bytes_md5
 
 
 def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer, page_index, ocr_enable=False, formula_enabled=True):
     scale = image_dict["scale"]
     page_pil_img = image_dict["img_pil"]
-    page_img_md5 = str_md5(image_dict["img_base64"])
+    # page_img_md5 = str_md5(image_dict["img_base64"])
+    page_img_md5 = bytes_md5(page_pil_img.tobytes())
     page_w, page_h = map(int, page.get_size())
     magic_model = MagicModel(page_model_info, scale)
 
@@ -210,7 +211,6 @@ def result_to_middle_json(model_list, images_list, pdf_doc, image_writer, lang=N
         atom_model_manager = AtomModelSingleton()
         ocr_model = atom_model_manager.get_atom_model(
             atom_model_name='ocr',
-            ocr_show_log=False,
             det_db_box_thresh=0.3,
             lang=lang
         )

+ 2 - 1
mineru/backend/pipeline/pipeline_analyze.py

@@ -6,6 +6,7 @@ from loguru import logger
 
 from .model_init import MineruPipelineModel
 from mineru.utils.config_reader import get_device
+from ...utils.enum_class import ImageType
 from ...utils.pdf_classify import classify
 from ...utils.pdf_image_tools import load_images_from_pdf
 from ...utils.model_utils import get_vram, clean_memory
@@ -98,7 +99,7 @@ def doc_analyze(
         _lang = lang_list[pdf_idx]
 
         # 收集每个数据集中的页面
-        images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
+        images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
         all_image_lists.append(images_list)
         all_pdf_docs.append(pdf_doc)
         for page_idx in range(len(images_list)):

+ 3 - 1
mineru/backend/vlm/token_to_middle_json.py

@@ -8,6 +8,7 @@ from mineru.utils.enum_class import ContentType
 from mineru.utils.hash_utils import str_md5
 from mineru.backend.vlm.vlm_magic_model import MagicModel
 from mineru.utils.pdf_image_tools import get_crop_img
+from mineru.utils.pdf_reader import base64_to_pil_image
 from mineru.version import __version__
 
 heading_level_import_success = False
@@ -32,7 +33,8 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
     # 提取所有完整块,每个块从<|box_start|>开始到<|md_end|>或<|im_end|>结束
 
     scale = image_dict["scale"]
-    page_pil_img = image_dict["img_pil"]
+    # page_pil_img = image_dict["img_pil"]
+    page_pil_img = base64_to_pil_image(image_dict["img_base64"])
     page_img_md5 = str_md5(image_dict["img_base64"])
     width, height = map(int, page.get_size())
 

+ 2 - 1
mineru/backend/vlm/vlm_analyze.py

@@ -8,6 +8,7 @@ from mineru.utils.pdf_image_tools import load_images_from_pdf
 from .base_predictor import BasePredictor
 from .predictor import get_predictor
 from .token_to_middle_json import result_to_middle_json
+from ...utils.enum_class import ImageType
 from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 
@@ -53,7 +54,7 @@ def doc_analyze(
         predictor = ModelSingleton().get_model(backend, model_path, server_url, **kwargs)
 
     # load_images_start = time.time()
-    images_list, pdf_doc = load_images_from_pdf(pdf_bytes)
+    images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.BASE64)
     images_base64_list = [image_dict["img_base64"] for image_dict in images_list]
     # load_images_time = round(time.time() - load_images_start, 2)
     # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")

+ 5 - 1
mineru/model/layout/doclayout_yolo.py

@@ -60,10 +60,14 @@ class DocLayoutYOLOModel:
         with tqdm(total=len(images), desc="Layout Predict") as pbar:
             for idx in range(0, len(images), batch_size):
                 batch = images[idx: idx + batch_size]
+                if batch_size == 1:
+                    conf = 0.9 * self.conf
+                else:
+                    conf = self.conf
                 predictions = self.model.predict(
                     batch,
                     imgsz=self.imgsz,
-                    conf=self.conf,
+                    conf=conf,
                     iou=self.iou,
                     verbose=False,
                 )

+ 5 - 2
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -56,6 +56,7 @@ class PytorchPaddleOCR(TextSystem):
         args = parser.parse_args(args)
 
         self.lang = kwargs.get('lang', 'ch')
+        self.enable_merge_det_boxes = kwargs.get("enable_merge_det_boxes", True)
 
         device = get_device()
         if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
@@ -135,7 +136,8 @@ class PytorchPaddleOCR(TextSystem):
                         continue
                     dt_boxes = sorted_boxes(dt_boxes)
                     # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
-                    dt_boxes = merge_det_boxes(dt_boxes)
+                    if self.enable_merge_det_boxes:
+                        dt_boxes = merge_det_boxes(dt_boxes)
                     if mfd_res:
                         dt_boxes = update_det_boxes(dt_boxes, mfd_res)
                     tmp_res = [box.tolist() for box in dt_boxes]
@@ -172,7 +174,8 @@ class PytorchPaddleOCR(TextSystem):
         dt_boxes = sorted_boxes(dt_boxes)
 
         # merge_det_boxes 和 update_det_boxes 都会把poly转成bbox再转回poly,因此需要过滤所有倾斜程度较大的文本框
-        dt_boxes = merge_det_boxes(dt_boxes)
+        if self.enable_merge_det_boxes:
+            dt_boxes = merge_det_boxes(dt_boxes)
 
         if mfd_res:
             dt_boxes = update_det_boxes(dt_boxes, mfd_res)

+ 9 - 10
mineru/model/ori_cls/paddle_ori_cls.py

@@ -4,6 +4,8 @@ import os
 import cv2
 import numpy as np
 import onnxruntime
+from loguru import logger
+
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
@@ -90,7 +92,7 @@ class PaddleOrientationClsModel:
                     # elif aspect_ratio > 1.2:  # Wider than tall - horizontal text
                     #     horizontal_count += 1
 
-                if vertical_count >= len(det_res) * 0.3:
+                if vertical_count >= len(det_res) * 0.28 and vertical_count >= 3:
                     is_rotated = True
                 # logger.debug(f"Text orientation analysis: vertical={vertical_count}, det_res={len(det_res)}, rotated={is_rotated}")
 
@@ -100,16 +102,13 @@ class PaddleOrientationClsModel:
                     x = self.preprocess(img)
                     (result,) = self.sess.run(None, {"x": x})
                     label = self.labels[np.argmax(result)]
-
-                    if label == "90":
-                        rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
-                        img = cv2.rotate(np.asarray(img), rotation)
-                    elif label == "180":
-                        rotation = cv2.ROTATE_180
-                        img = cv2.rotate(np.asarray(img), rotation)
-                    elif label == "270":
+                    # logger.debug(f"Orientation classification result: {label}")
+                    if label == "270":
                         rotation = cv2.ROTATE_90_CLOCKWISE
                         img = cv2.rotate(np.asarray(img), rotation)
+                    elif label == "90":
+                        rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
+                        img = cv2.rotate(np.asarray(img), rotation)
                     else:
-                        img = np.array(img)
+                        pass
         return img

+ 3 - 4
mineru/model/table/cls/paddle_table_cls.py

@@ -24,8 +24,7 @@ class PaddleTableClsModel:
 
     def preprocess(self, img):
         # PIL图像转cv2
-        img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img = np.array(img)
         # 放大图片,使其最短边长为256
         h, w = img.shape[:2]
         scale = 256 / min(h, w)
@@ -68,6 +67,6 @@ class PaddleTableClsModel:
         idx = np.argmax(result)
         conf = float(np.max(result))
         # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
-        if idx == 0 and conf < 0.9:
+        if idx == 0 and conf < 0.8:
             idx = 1
-        return self.labels[idx]
+        return self.labels[idx], conf

+ 1 - 1
mineru/model/table/rec/rapid_table.py

@@ -22,7 +22,7 @@ class RapidTableModel(object):
         self.ocr_engine = ocr_engine
 
 
-    def predict(self, image):
+    def predict(self, image, table_cls_score):
         bgr_image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
         # Continue with OCR on potentially rotated image
         ocr_result = self.ocr_engine.ocr(bgr_image)[0]

+ 5 - 34
mineru/model/table/rec/unet_table/table_line_rec_utils.py

@@ -38,7 +38,7 @@ def min_area_rect(coords):
     box = image_location_sort_box(box)
 
     x1, y1, x2, y2, x3, y3, x4, y4 = box
-    degree, w, h, cx, cy = calculate_center_rotate_angle(box)
+    w, h = calculate_center_rotate_angle(box)
     if w < h:
         xmin = (x1 + x2) / 2
         xmax = (x3 + x4) / 2
@@ -50,9 +50,6 @@ def min_area_rect(coords):
         xmax = (x2 + x3) / 2
         ymin = (y1 + y4) / 2
         ymax = (y2 + y3) / 2
-    # degree,w,h,cx,cy = solve(box)
-    # x1,y1,x2,y2,x3,y3,x4,y4 = box
-    # return {'degree':degree,'w':w,'h':h,'cx':cx,'cy':cy}
     return [xmin, ymin, xmax, ymax]
 
 
@@ -65,21 +62,8 @@ def image_location_sort_box(box):
 
 
 def calculate_center_rotate_angle(box):
-    """
-    绕 cx,cy点 w,h 旋转 angle 的坐标,能一定程度缓解图片的内部倾斜,但是还是依赖模型稳妥
-    x = cx-w/2
-    y = cy-h/2
-    x1-cx = -w/2*cos(angle) +h/2*sin(angle)
-    y1 -cy= -w/2*sin(angle) -h/2*cos(angle)
-
-    h(x1-cx) = -wh/2*cos(angle) +hh/2*sin(angle)
-    w(y1 -cy)= -ww/2*sin(angle) -hw/2*cos(angle)
-    (hh+ww)/2sin(angle) = h(x1-cx)-w(y1 -cy)
 
-    """
     x1, y1, x2, y2, x3, y3, x4, y4 = box[:8]
-    cx = (x1 + x3 + x2 + x4) / 4.0
-    cy = (y1 + y3 + y4 + y2) / 4.0
     w = (
         np.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
         + np.sqrt((x3 - x4) ** 2 + (y3 - y4) ** 2)
@@ -88,11 +72,8 @@ def calculate_center_rotate_angle(box):
         np.sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2)
         + np.sqrt((x1 - x4) ** 2 + (y1 - y4) ** 2)
     ) / 2
-    # x = cx-w/2
-    # y = cy-h/2
-    sinA = (h * (x1 - cx) - w * (y1 - cy)) * 1.0 / (h * h + w * w) * 2
-    angle = np.arcsin(sinA)
-    return angle, w, h, cx, cy
+
+    return w, h
 
 
 def _order_points(pts):
@@ -241,18 +222,8 @@ def min_area_rect_box(
         box = box.reshape((8,)).tolist()
         box = image_location_sort_box(box)
         x1, y1, x2, y2, x3, y3, x4, y4 = box
-        angle, w, h, cx, cy = calculate_center_rotate_angle(box)
-        # if adjustBox:
-        #     x1, y1, x2, y2, x3, y3, x4, y4 = xy_rotate_box(cx, cy, w + 5, h + 5, angle=0, degree=None)
-        #     x1, x4 = max(x1, 0), max(x4, 0)
-        #     y1, y2 = max(y1, 0), max(y2, 0)
-
-        # if w > 32 and h > 32 and flag:
-        #     if abs(angle / np.pi * 180) < 20:
-        #         if filtersmall and (w < 10 or h < 10):
-        #             continue
-        #         boxes.append([x1, y1, x2, y2, x3, y3, x4, y4])
-        # else:
+        w, h = calculate_center_rotate_angle(box)
+
         if w * h < 0.5 * W * H:
             if filtersmall and (
                 w < 15 or h < 15

+ 83 - 13
mineru/model/table/rec/unet_table/unet_table.py

@@ -1,5 +1,4 @@
 import html
-import logging
 import os
 import time
 import traceback
@@ -9,6 +8,7 @@ from typing import List, Optional, Union, Dict, Any
 import cv2
 import numpy as np
 from loguru import logger
+from rapid_table import RapidTableInput, RapidTable
 
 from mineru.utils.enum_class import ModelPath
 from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
@@ -48,6 +48,7 @@ class UnetTableRecognition:
         self,
         img: InputType,
         ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
+        ocr_engine = None,
         **kwargs,
     ) -> UnetTableOutput:
         s = time.perf_counter()
@@ -61,7 +62,7 @@ class UnetTableRecognition:
         img = self.load_img(img)
         polygons, rotated_polygons = self.table_structure(img, **kwargs)
         if polygons is None:
-            logging.warning("polygons is None.")
+            # logger.warning("polygons is None.")
             return UnetTableOutput("", None, None, 0.0)
 
         try:
@@ -85,7 +86,7 @@ class UnetTableRecognition:
                 )
             cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
             # 如果有识别框没有ocr结果,直接进行rec补充
-            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map)
+            cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map, ocr_engine)
             # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
             t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
             # 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
@@ -102,7 +103,7 @@ class UnetTableRecognition:
             elapse = time.perf_counter() - s
 
         except Exception:
-            logging.warning(traceback.format_exc())
+            logger.warning(traceback.format_exc())
             return UnetTableOutput("", None, None, 0.0)
         return UnetTableOutput(pred_html, polygons, logi_points, elapse)
 
@@ -151,14 +152,51 @@ class UnetTableRecognition:
         img: np.ndarray,
         sorted_polygons: np.ndarray,
         cell_box_map: Dict[int, List[str]],
+        ocr_engine
     ) -> Dict[int, List[Any]]:
         """找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
+        img_crop_info_list = []
+        img_crop_list = []
         for i in range(sorted_polygons.shape[0]):
             if cell_box_map.get(i):
                 continue
             box = sorted_polygons[i]
-            cell_box_map[i] = [[box, "", 1]]
+            if ocr_engine is None:
+                logger.warning(f"No OCR engine provided for box {i}: {box}")
+                continue
+            # 从img中截取对应的区域
+            x1, y1, x2, y2 = box[0][0], box[0][1], box[2][0], box[2][1]
+            if x1 >= x2 or y1 >= y2:
+                logger.warning(f"Invalid box coordinates: {box}")
+                continue
+            img_crop = img[int(y1):int(y2), int(x1):int(x2)]
+            img_crop_list.append(img_crop)
+            img_crop_info_list.append([i, box])
             continue
+
+        if len(img_crop_list) > 0:
+            # 进行ocr识别
+            ocr_result = ocr_engine.ocr(img_crop_list, det=False)
+            if not ocr_result or not isinstance(ocr_result, list) or len(ocr_result) == 0:
+                logger.warning("OCR engine returned no results or invalid result for image crops.")
+                return cell_box_map
+            ocr_res_list = ocr_result[0]
+            if not isinstance(ocr_res_list, list) or len(ocr_res_list) != len(img_crop_list):
+                logger.warning("OCR result list length does not match image crop list length.")
+                return cell_box_map
+            for j, ocr_res in enumerate(ocr_res_list):
+                img_crop_info_list[j].append(ocr_res)
+
+
+            for i, box, ocr_res in img_crop_info_list:
+                # 处理ocr结果
+                ocr_text, ocr_score = ocr_res
+                # logger.debug(f"OCR result for box {i}: {ocr_text} with score {ocr_score}")
+                if ocr_score < 0.9:
+                    # logger.warning(f"Low confidence OCR result for box {i}: {ocr_text} with score {ocr_score}")
+                    continue
+                cell_box_map[i] = [[box, ocr_text, ocr_score]]
+
         return cell_box_map
 
 
@@ -170,11 +208,14 @@ def escape_html(input_string):
 class UnetTableModel:
     def __init__(self, ocr_engine):
         model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.unet_structure), ModelPath.unet_structure)
-        input_args = UnetTableInput(model_path=model_path)
-        self.table_model = UnetTableRecognition(input_args)
+        wired_input_args = UnetTableInput(model_path=model_path)
+        self.wired_table_model = UnetTableRecognition(wired_input_args)
+        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
+        wireless_input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
+        self.wireless_table_model = RapidTable(wireless_input_args)
         self.ocr_engine = ocr_engine
 
-    def predict(self, img):
+    def predict(self, img, table_cls_score):
         bgr_img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
         ocr_result = self.ocr_engine.ocr(bgr_img)[0]
 
@@ -188,11 +229,40 @@ class UnetTableModel:
             ocr_result = None
         if ocr_result:
             try:
-                table_results = self.table_model(np.asarray(img), ocr_result)
-                html_code = table_results.pred_html
-                table_cell_bboxes = table_results.cell_bboxes
-                logic_points = table_results.logic_points
-                elapse = table_results.elapse
+                wired_table_results = self.wired_table_model(np.asarray(img), ocr_result, self.ocr_engine)
+                wired_html_code = wired_table_results.pred_html
+                wired_table_cell_bboxes = wired_table_results.cell_bboxes
+                wired_logic_points = wired_table_results.logic_points
+                wired_elapse = wired_table_results.elapse
+
+                wireless_table_results = self.wireless_table_model(np.asarray(img), ocr_result)
+                wireless_html_code = wireless_table_results.pred_html
+                wireless_table_cell_bboxes = wireless_table_results.cell_bboxes
+                wireless_logic_points = wireless_table_results.logic_points
+                wireless_elapse = wireless_table_results.elapse
+
+                wired_len = len(wired_table_cell_bboxes) if wired_table_cell_bboxes is not None else 0
+                wireless_len = len(wireless_table_cell_bboxes) if wireless_table_cell_bboxes is not None else 0
+                # logger.debug(f"wired table cell bboxes: {wired_len}, wireless table cell bboxes: {wireless_len}")
+                # 计算两种模型检测的单元格数量差异
+                gap_of_len = wireless_len - wired_len
+                # 判断是否使用无线表格模型的结果
+                if (
+                    wired_len <= round(wireless_len * 0.5)  # 有线模型检测到的单元格数太少(低于无线模型的50%)
+                    or ((wireless_len < wired_len) and (wired_len < (2 * wireless_len)) and table_cls_score <= 0.949)  # 有线模型检测到的单元格数反而更多
+                    or (0 <= gap_of_len <= 5 and wired_len <= round(wireless_len * 0.75))  # 两者相差不大但有线模型结果较少
+                    or (gap_of_len == 0 and wired_len <= 4)  # 单元格数量完全相等且总量小于等于4
+                ):
+                    # logger.debug("fall back to wireless table model")
+                    html_code = wireless_html_code
+                    table_cell_bboxes = wireless_table_cell_bboxes
+                    logic_points = wireless_logic_points
+                else:
+                    html_code = wired_html_code
+                    table_cell_bboxes = wired_table_cell_bboxes
+                    logic_points = wired_logic_points
+
+                elapse = wired_elapse + wireless_elapse
                 return html_code, table_cell_bboxes, logic_points, elapse
             except Exception as e:
                 logger.exception(e)

+ 6 - 1
mineru/utils/enum_class.py

@@ -66,4 +66,9 @@ class ModelPath:
 
 class SplitFlag:
     CROSS_PAGE = 'cross_page'
-    LINES_DELETED = 'lines_deleted'
+    LINES_DELETED = 'lines_deleted'
+
+
+class ImageType:
+    PIL = 'pil_img'
+    BASE64 = 'base64_img'

+ 1 - 1
mineru/utils/model_utils.py

@@ -290,7 +290,7 @@ def remove_overlaps_low_confidence_blocks(combined_res_list, overlap_threshold=0
                                                                              overlap_threshold)]
 
         # If it contains 2 or more small blocks
-        if len(blocks_inside) >= 3:
+        if len(blocks_inside) >= 2:
             # 计算小block的平均分数
             avg_score = sum(s for _, s, _ in blocks_inside) / len(blocks_inside)
 

+ 19 - 13
mineru/utils/pdf_classify.py

@@ -24,11 +24,11 @@ def classify(pdf_bytes):
     Returns:
         str: 'txt' 表示可以直接提取文本,'ocr' 表示需要OCR
     """
-    try:
-        # 从字节数据加载PDF
-        sample_pdf_bytes = extract_pages(pdf_bytes)
-        pdf = pdfium.PdfDocument(sample_pdf_bytes)
 
+    # 从字节数据加载PDF
+    sample_pdf_bytes = extract_pages(pdf_bytes)
+    pdf = pdfium.PdfDocument(sample_pdf_bytes)
+    try:
         # 获取PDF页数
         page_count = len(pdf)
 
@@ -42,19 +42,25 @@ def classify(pdf_bytes):
         # 设置阈值:如果每页平均少于50个有效字符,认为需要OCR
         chars_threshold = 50
 
+        # 检查平均字符数和无效字符
         if (get_avg_cleaned_chars_per_page(pdf, pages_to_check) < chars_threshold) or detect_invalid_chars(sample_pdf_bytes):
             return 'ocr'
-        else:
 
-            if get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check) >= 0.8:
-                return 'ocr'
+        # 检查图像覆盖率
+        if get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check) >= 0.8:
+            return 'ocr'
+
+        return 'txt'
 
-            return 'txt'
     except Exception as e:
         logger.error(f"判断PDF类型时出错: {e}")
         # 出错时默认使用OCR
         return 'ocr'
 
+    finally:
+        # 无论执行哪个路径,都确保PDF被关闭
+        pdf.close()
+
 
 def get_avg_cleaned_chars_per_page(pdf_doc, pages_to_check):
     # 总字符数
@@ -78,8 +84,6 @@ def get_avg_cleaned_chars_per_page(pdf_doc, pages_to_check):
 
     # logger.debug(f"PDF分析: 平均每页清理后{avg_cleaned_chars_per_page:.1f}字符")
 
-    pdf_doc.close()  # 关闭PDF文档
-
     return avg_cleaned_chars_per_page
 
 
@@ -158,6 +162,9 @@ def get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check):
 
         page_count += 1
 
+    # 关闭资源
+    pdf_stream.close()
+
     # 如果没有处理任何页面,返回0
     if page_count == 0:
         return 0.0
@@ -166,9 +173,6 @@ def get_high_image_coverage_ratio(sample_pdf_bytes, pages_to_check):
     high_coverage_ratio = high_image_coverage_pages / page_count
     # logger.debug(f"PDF分析: 高图像覆盖页面比例: {high_coverage_ratio:.2f}")
 
-    # 关闭资源
-    pdf_stream.close()
-
     return high_coverage_ratio
 
 
@@ -205,6 +209,7 @@ def extract_pages(src_pdf_bytes: bytes) -> bytes:
     try:
         # 将选择的页面导入新文档
         sample_docs.import_pages(pdf, page_indices)
+        pdf.close()
 
         # 将新PDF保存到内存缓冲区
         output_buffer = BytesIO()
@@ -213,6 +218,7 @@ def extract_pages(src_pdf_bytes: bytes) -> bytes:
         # 获取字节数据
         return output_buffer.getvalue()
     except Exception as e:
+        pdf.close()
         logger.exception(e)
         return b''  # 出错时返回空字节
 

+ 10 - 6
mineru/utils/pdf_image_tools.py

@@ -7,27 +7,30 @@ from PIL import Image
 
 from mineru.data.data_reader_writer import FileBasedDataWriter
 from mineru.utils.pdf_reader import image_to_b64str, image_to_bytes, page_to_image
+from .enum_class import ImageType
 from .hash_utils import str_sha256
 
 
-def pdf_page_to_image(page: pdfium.PdfPage, dpi=200) -> dict:
+def pdf_page_to_image(page: pdfium.PdfPage, dpi=200, image_type=ImageType.PIL) -> dict:
     """Render a pdfium page with page_to_image and return the image together with its scale.
 
     Args:
         page (pdfium.PdfPage): the page to render.
         dpi (int, optional): resolution forwarded to page_to_image. Defaults to 200.
+        image_type (ImageType, optional): ImageType.BASE64 stores a base64 string; any other value stores the PIL image. Defaults to ImageType.PIL.
 
     Returns:
         dict: {'scale': float} plus 'img_base64' (str) when image_type is ImageType.BASE64, otherwise 'img_pil' (PIL image).
     """
     pil_img, scale = page_to_image(page, dpi=dpi)
-    img_base64 = image_to_b64str(pil_img)
-
     image_dict = {
-        "img_base64": img_base64,
-        "img_pil": pil_img,
         "scale": scale,
     }
+    if image_type == ImageType.BASE64:
+        image_dict["img_base64"] = image_to_b64str(pil_img)
+    else:
+        image_dict["img_pil"] = pil_img
+
     return image_dict
 
 
@@ -36,6 +39,7 @@ def load_images_from_pdf(
     dpi=200,
     start_page_id=0,
     end_page_id=None,
+    image_type=ImageType.PIL,  # PIL or BASE64
 ):
     images_list = []
     pdf_doc = pdfium.PdfDocument(pdf_bytes)
@@ -48,7 +52,7 @@ def load_images_from_pdf(
     for index in range(0, pdf_page_num):
         if start_page_id <= index <= end_page_id:
             page = pdf_doc[index]
-            image_dict = pdf_page_to_image(page, dpi=dpi)
+            image_dict = pdf_page_to_image(page, dpi=dpi, image_type=image_type)
             images_list.append(image_dict)
 
     return images_list, pdf_doc

+ 32 - 18
mineru/utils/pdf_reader.py

@@ -9,8 +9,8 @@ from pypdfium2 import PdfBitmap, PdfDocument, PdfPage
 
 def page_to_image(
     page: PdfPage,
-    dpi: int = 144,  # changed from 200 to 144
-    max_width_or_height: int = 2560,  # changed from 4500 to 2560
+    dpi: int = 200,
+    max_width_or_height: int = 3500,  # changed from 4500 to 3500
 ) -> (Image.Image, float):
     scale = dpi / 72
 
@@ -19,19 +19,21 @@ def page_to_image(
         scale = max_width_or_height / long_side_length
 
     bitmap: PdfBitmap = page.render(scale=scale)  # type: ignore
+
+    image = bitmap.to_pil()
     try:
-        image = bitmap.to_pil()
-    finally:
-        try:
-            bitmap.close()
-        except Exception:
-            pass
+        bitmap.close()
+    except Exception as e:
+        logger.error(f"Failed to close bitmap: {e}")
     return image, scale
 
 
+
+
 def image_to_bytes(
     image: Image.Image,
-    image_format: str = "PNG",  # 也可以用 "JPEG"
+    # image_format: str = "PNG",  # 也可以用 "JPEG"
+    image_format: str = "JPEG",
 ) -> bytes:
     with BytesIO() as image_buffer:
         image.save(image_buffer, format=image_format)
@@ -40,16 +42,26 @@ def image_to_bytes(
 
 def image_to_b64str(
     image: Image.Image,
-    image_format: str = "PNG",  # 也可以用 "JPEG"
+    # image_format: str = "PNG",  # "JPEG" can also be used
+    image_format: str = "JPEG",  # default changed from "PNG"
 ) -> str:
     image_bytes = image_to_bytes(image, image_format)
     return base64.b64encode(image_bytes).decode("utf-8")
 
 
+def base64_to_pil_image(
+    base64_str: str,
+) -> Image.Image:
+    """Decode a base64-encoded image string and return it as an RGB-mode PIL image."""
+    image_bytes = base64.b64decode(base64_str)
+    with BytesIO(image_bytes) as image_buffer:
+        return Image.open(image_buffer).convert("RGB")
+
+
 def pdf_to_images(
     pdf: str | bytes | PdfDocument,
-    dpi: int = 144,
-    max_width_or_height: int = 2560,
+    dpi: int = 200,
+    max_width_or_height: int = 3500,
     start_page_id: int = 0,
     end_page_id: int | None = None,
 ) -> list[Image.Image]:
@@ -76,11 +88,12 @@ def pdf_to_images(
 
 def pdf_to_images_bytes(
     pdf: str | bytes | PdfDocument,
-    dpi: int = 144,
-    max_width_or_height: int = 2560,
+    dpi: int = 200,
+    max_width_or_height: int = 3500,
     start_page_id: int = 0,
     end_page_id: int | None = None,
-    image_format: str = "PNG",
+    # image_format: str = "PNG",  # "JPEG" can also be used
+    image_format: str = "JPEG",  # default changed from "PNG"
 ) -> list[bytes]:
     images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
     return [image_to_bytes(image, image_format) for image in images]
@@ -88,11 +101,12 @@ def pdf_to_images_bytes(
 
 def pdf_to_images_b64strs(
     pdf: str | bytes | PdfDocument,
-    dpi: int = 144,
-    max_width_or_height: int = 2560,
+    dpi: int = 200,
+    max_width_or_height: int = 3500,
     start_page_id: int = 0,
     end_page_id: int | None = None,
-    image_format: str = "PNG",
+    # image_format: str = "PNG",  # "JPEG" can also be used
+    image_format: str = "JPEG",  # default changed from "PNG"
 ) -> list[str]:
     images = pdf_to_images(pdf, dpi, max_width_or_height, start_page_id, end_page_id)
     return [image_to_b64str(image, image_format) for image in images]