Browse Source

Merge pull request #2948 from myhloli/dev

feat: enhance heading level feature with conditional imports and error handling
Xiaomeng Zhao 4 tháng trước cách đây
mục cha
commit
33bea91064
3 tập tin đã thay đổi với 59 bổ sung32 xóa
  1. 35 28
      mineru/backend/vlm/token_to_middle_json.py
  2. 22 2
      mineru/cli/gradio_app.py
  3. 2 2
      pyproject.toml

+ 35 - 28
mineru/backend/vlm/token_to_middle_json.py

@@ -1,18 +1,24 @@
 import time
-import cv2
-import numpy as np
 from loguru import logger
-
-from mineru.backend.pipeline.model_init import AtomModelSingleton
+import numpy as np
+import cv2
 from mineru.utils.config_reader import get_llm_aided_config
 from mineru.utils.cut_image import cut_image_and_table
 from mineru.utils.enum_class import ContentType
 from mineru.utils.hash_utils import str_md5
 from mineru.backend.vlm.vlm_magic_model import MagicModel
-from mineru.utils.llm_aided import llm_aided_title
 from mineru.utils.pdf_image_tools import get_crop_img
 from mineru.version import __version__
 
+heading_level_import_success = False
+try:
+    from mineru.utils.llm_aided import llm_aided_title
+    from mineru.backend.pipeline.model_init import AtomModelSingleton
+    heading_level_import_success = True
+except Exception as e:
+    logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
+                   "please execute `pip install mineru[pipeline]` to install the required packages.")
+
 
 def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
     """将token转换为页面信息"""
@@ -37,26 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
         title_aided_config = llm_aided_config.get('title_aided', None)
         if title_aided_config is not None:
             if title_aided_config.get('enable', False):
-                atom_model_manager = AtomModelSingleton()
-                ocr_model = atom_model_manager.get_atom_model(
-                    atom_model_name='ocr',
-                    ocr_show_log=False,
-                    det_db_box_thresh=0.3,
-                    lang='ch_lite'
-                )
-                for title_block in title_blocks:
-                    title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
-                    title_np_img = np.array(title_pil_img)
-                    # 给title_pil_img添加上下左右各50像素白边padding
-                    title_np_img = cv2.copyMakeBorder(
-                        title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+                if heading_level_import_success:
+                    atom_model_manager = AtomModelSingleton()
+                    ocr_model = atom_model_manager.get_atom_model(
+                        atom_model_name='ocr',
+                        ocr_show_log=False,
+                        det_db_box_thresh=0.3,
+                        lang='ch_lite'
                     )
-                    title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
-                    ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
-                    if len(ocr_det_res) > 0:
-                        # 计算所有res的平均高度
-                        avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
-                        title_block['line_avg_height'] = round(avg_height/scale)
+                    for title_block in title_blocks:
+                        title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
+                        title_np_img = np.array(title_pil_img)
+                        # 给title_pil_img添加上下左右各50像素白边padding
+                        title_np_img = cv2.copyMakeBorder(
+                            title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+                        )
+                        title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
+                        ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
+                        if len(ocr_det_res) > 0:
+                            # 计算所有res的平均高度
+                            avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
+                            title_block['line_avg_height'] = round(avg_height/scale)
 
     text_blocks = magic_model.get_text_blocks()
     interline_equation_blocks = magic_model.get_interline_equation_blocks()
@@ -86,15 +93,15 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
 
     """llm优化"""
     llm_aided_config = get_llm_aided_config()
-
     if llm_aided_config is not None:
         """标题优化"""
         title_aided_config = llm_aided_config.get('title_aided', None)
         if title_aided_config is not None:
             if title_aided_config.get('enable', False):
-                llm_aided_title_start_time = time.time()
-                llm_aided_title(middle_json["pdf_info"], title_aided_config)
-                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
+                if heading_level_import_success:
+                    llm_aided_title_start_time = time.time()
+                    llm_aided_title(middle_json["pdf_info"], title_aided_config)
+                    logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
 
     # 关闭pdf文档
     pdf_doc.close()

+ 22 - 2
mineru/cli/gradio_app.py

@@ -114,12 +114,15 @@ async def to_markdown(file_path, end_pages=10, is_ocr=False, formula_enable=True
     return md_content, txt_content, archive_zip_path, new_pdf_path
 
 
-latex_delimiters = [
+latex_delimiters_type_a = [
     {'left': '$$', 'right': '$$', 'display': True},
     {'left': '$', 'right': '$', 'display': False},
+]
+latex_delimiters_type_b = [
     {'left': '\\(', 'right': '\\)', 'display': False},
     {'left': '\\[', 'right': '\\]', 'display': True},
 ]
+latex_delimiters_type_all = latex_delimiters_type_a + latex_delimiters_type_b
 
 header_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'resources', 'header.html')
 with open(header_path, 'r') as header_file:
@@ -234,13 +237,30 @@ def update_interface(backend_choice):
     help="Set the server port for the Gradio app.",
     default=None,
 )
+@click.option(
+    '--latex-delimiters-type',
+    'latex_delimiters_type',
+    type=click.Choice(['a', 'b', 'all']),
+    help="Set the type of LaTeX delimiters to use in Markdown rendering:"
+         "'a' for type '$', 'b' for type '()[]', 'all' for both types.",
+    default='all',
+)
 def main(ctx,
         example_enable, sglang_engine_enable, api_enable, max_convert_pages,
-        server_name, server_port, **kwargs
+        server_name, server_port, latex_delimiters_type, **kwargs
 ):
 
     kwargs.update(arg_parse(ctx))
 
+    if latex_delimiters_type == 'a':
+        latex_delimiters = latex_delimiters_type_a
+    elif latex_delimiters_type == 'b':
+        latex_delimiters = latex_delimiters_type_b
+    elif latex_delimiters_type == 'all':
+        latex_delimiters = latex_delimiters_type_all
+    else:
+        raise ValueError(f"Invalid latex delimiters type: {latex_delimiters_type}.")
+
     if sglang_engine_enable:
         try:
             print("Start init SgLang engine...")

+ 2 - 2
pyproject.toml

@@ -33,6 +33,8 @@ dependencies = [
     "modelscope>=1.26.0",
     "huggingface-hub>=0.32.4",
     "json-repair>=0.46.2",
+    "opencv-python>=4.11.0.86",
+    "fast-langdetect>=0.2.3,<0.3.0",
 ]
 
 [project.optional-dependencies]
@@ -60,7 +62,6 @@ pipeline = [
     "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
-    "fast-langdetect>=0.2.3,<0.3.0",
 ]
 api = [
     "fastapi",
@@ -97,7 +98,6 @@ pipeline_old_linux = [
     "torch>=2.2.2,!=2.5.0,!=2.5.1,<3",
     "torchvision",
     "transformers>=4.49.0,!=4.51.0,<5.0.0",
-    "fast-langdetect>=0.2.3,<0.3.0",
 ]
 
 [project.urls]