소스 검색

feat: enhance heading level feature with conditional imports and error handling

myhloli 4 달 전
부모
커밋
7e6926ffb8
2개의 파일이 변경되었으며, 37줄이 추가되고 30줄이 삭제되었습니다
  1. 35 28
      mineru/backend/vlm/token_to_middle_json.py
  2. 2 2
      pyproject.toml

+ 35 - 28
mineru/backend/vlm/token_to_middle_json.py

@@ -1,18 +1,24 @@
 import time
-import cv2
-import numpy as np
 from loguru import logger
-
-from mineru.backend.pipeline.model_init import AtomModelSingleton
+import numpy as np
+import cv2
 from mineru.utils.config_reader import get_llm_aided_config
 from mineru.utils.cut_image import cut_image_and_table
 from mineru.utils.enum_class import ContentType
 from mineru.utils.hash_utils import str_md5
 from mineru.backend.vlm.vlm_magic_model import MagicModel
-from mineru.utils.llm_aided import llm_aided_title
 from mineru.utils.pdf_image_tools import get_crop_img
 from mineru.version import __version__
 
+heading_level_import_success = False
+try:
+    from mineru.utils.llm_aided import llm_aided_title
+    from mineru.backend.pipeline.model_init import AtomModelSingleton
+    heading_level_import_success = True
+except Exception as e:
+    logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
+                   "please execute `pip install mineru[pipeline]` to install the required packages.")
+
 
 def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
     """将token转换为页面信息"""
@@ -37,26 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
         title_aided_config = llm_aided_config.get('title_aided', None)
         if title_aided_config is not None:
             if title_aided_config.get('enable', False):
-                atom_model_manager = AtomModelSingleton()
-                ocr_model = atom_model_manager.get_atom_model(
-                    atom_model_name='ocr',
-                    ocr_show_log=False,
-                    det_db_box_thresh=0.3,
-                    lang='ch_lite'
-                )
-                for title_block in title_blocks:
-                    title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
-                    title_np_img = np.array(title_pil_img)
-                    # 给title_pil_img添加上下左右各50像素白边padding
-                    title_np_img = cv2.copyMakeBorder(
-                        title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+                if heading_level_import_success:
+                    atom_model_manager = AtomModelSingleton()
+                    ocr_model = atom_model_manager.get_atom_model(
+                        atom_model_name='ocr',
+                        ocr_show_log=False,
+                        det_db_box_thresh=0.3,
+                        lang='ch_lite'
                     )
-                    title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
-                    ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
-                    if len(ocr_det_res) > 0:
-                        # 计算所有res的平均高度
-                        avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
-                        title_block['line_avg_height'] = round(avg_height/scale)
+                    for title_block in title_blocks:
+                        title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
+                        title_np_img = np.array(title_pil_img)
+                        # 给title_pil_img添加上下左右各50像素白边padding
+                        title_np_img = cv2.copyMakeBorder(
+                            title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+                        )
+                        title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
+                        ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
+                        if len(ocr_det_res) > 0:
+                            # 计算所有res的平均高度
+                            avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
+                            title_block['line_avg_height'] = round(avg_height/scale)
 
     text_blocks = magic_model.get_text_blocks()
     interline_equation_blocks = magic_model.get_interline_equation_blocks()
@@ -86,15 +93,15 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
 
     """llm优化"""
     llm_aided_config = get_llm_aided_config()
-
     if llm_aided_config is not None:
         """标题优化"""
         title_aided_config = llm_aided_config.get('title_aided', None)
         if title_aided_config is not None:
             if title_aided_config.get('enable', False):
-                llm_aided_title_start_time = time.time()
-                llm_aided_title(middle_json["pdf_info"], title_aided_config)
-                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
+                if heading_level_import_success:
+                    llm_aided_title_start_time = time.time()
+                    llm_aided_title(middle_json["pdf_info"], title_aided_config)
+                    logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
 
     # 关闭pdf文档
     pdf_doc.close()

+ 2 - 2
pyproject.toml

@@ -33,6 +33,8 @@ dependencies = [
     "modelscope>=1.26.0",
     "huggingface-hub>=0.32.4",
     "json-repair>=0.46.2",
+    "opencv-python>=4.11.0.86",
+    "fast-langdetect>=0.2.3,<0.3.0",
 ]
 
 [project.optional-dependencies]
@@ -60,7 +62,6 @@ pipeline = [
     "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
-    "fast-langdetect>=0.2.3,<0.3.0",
 ]
 api = [
     "fastapi",
@@ -97,7 +98,6 @@ pipeline_old_linux = [
     "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
-    "fast-langdetect>=0.2.3,<0.3.0",
 ]
 
 [project.urls]