Browse Source

Merge pull request #2850 from myhloli/dev

Dev
Xiaomeng Zhao 4 months ago
parent
commit
21bd73ea03

+ 1 - 1
mineru/backend/pipeline/batch_analyze.py

@@ -323,7 +323,7 @@ class BatchAnalyze:
                                                layout_res_item['poly'][4], layout_res_item['poly'][5]]
                             layout_res_width = layout_res_bbox[2] - layout_res_bbox[0]
                             layout_res_height = layout_res_bbox[3] - layout_res_bbox[1]
-                            if ocr_text in ['(204号', '(20', '(2', '(2号'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
+                            if ocr_text in ['(204号', '(20', '(2', '(2号', '(20号'] and ocr_score < 0.8 and layout_res_width < layout_res_height:
                                 layout_res_item['category_id'] = 16
 
                     total_processed += len(img_crop_list)

+ 50 - 2
mineru/backend/vlm/token_to_middle_json.py

@@ -1,9 +1,16 @@
-import re
+import time
+import cv2
+import numpy as np
+from loguru import logger
 
+from mineru.backend.pipeline.model_init import AtomModelSingleton
+from mineru.utils.config_reader import get_llm_aided_config
 from mineru.utils.cut_image import cut_image_and_table
-from mineru.utils.enum_class import BlockType, ContentType
+from mineru.utils.enum_class import ContentType
 from mineru.utils.hash_utils import str_md5
 from mineru.backend.vlm.vlm_magic_model import MagicModel
+from mineru.utils.llm_aided import llm_aided_title
+from mineru.utils.pdf_image_tools import get_crop_img
 from mineru.version import __version__
 
 
@@ -23,6 +30,34 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
     image_blocks = magic_model.get_image_blocks()
     table_blocks = magic_model.get_table_blocks()
     title_blocks = magic_model.get_title_blocks()
+
+    # 如果有标题优化需求,则对title_blocks截图det
+    llm_aided_config = get_llm_aided_config()
+    if llm_aided_config is not None:
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            if title_aided_config.get('enable', False):
+                atom_model_manager = AtomModelSingleton()
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang='ch_lite'
+                )
+                for title_block in title_blocks:
+                    title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
+                    title_np_img = np.array(title_pil_img)
+                    # 给title_pil_img添加上下左右各50像素白边padding
+                    title_np_img = cv2.copyMakeBorder(
+                        title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+                    )
+                    title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
+                    ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
+                    if len(ocr_det_res) > 0:
+                        # 计算所有res的平均高度
+                        avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
+                        title_block['line_avg_height'] = round(avg_height/scale)
+
     text_blocks = magic_model.get_text_blocks()
     interline_equation_blocks = magic_model.get_interline_equation_blocks()
 
@@ -48,6 +83,19 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
         image_dict = images_list[index]
         page_info = token_to_page_info(token, image_dict, page, image_writer, index)
         middle_json["pdf_info"].append(page_info)
+
+    """llm优化"""
+    llm_aided_config = get_llm_aided_config()
+
+    if llm_aided_config is not None:
+        """标题优化"""
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            if title_aided_config.get('enable', False):
+                llm_aided_title_start_time = time.time()
+                llm_aided_title(middle_json["pdf_info"], title_aided_config)
+                logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
+
     # 关闭pdf文档
     pdf_doc.close()
     return middle_json

+ 11 - 4
mineru/cli/gradio_app.py

@@ -209,14 +209,14 @@ def update_interface(backend_choice):
     'mem_fraction_static',
     type=float,
     help="Set the static memory fraction for SgLang engine. ",
-    default=0.5,
+    default=None,
 )
 @click.option(
     '--enable-torch-compile',
     'torch_compile_enable',
     type=bool,
     help="Enable torch compile for SgLang engine. ",
-    default=True,
+    default=False,
 )
 @click.option(
     '--enable-api',
@@ -231,12 +231,19 @@ def main(example_enable, sglang_engine_enable, mem_fraction_static, torch_compil
             print("Start init SgLang engine...")
             from mineru.backend.vlm.vlm_analyze import ModelSingleton
             modelsingleton = ModelSingleton()
+
+            model_params = {
+                "enable_torch_compile": torch_compile_enable
+            }
+            # 只有当mem_fraction_static不为None时才添加该参数
+            if mem_fraction_static is not None:
+                model_params["mem_fraction_static"] = mem_fraction_static
+
             predictor = modelsingleton.get_model(
                 "sglang-engine",
                 None,
                 None,
-                mem_fraction_static=mem_fraction_static,
-                enable_torch_compile=torch_compile_enable,
+                **model_params
             )
             print("SgLang engine init successfully.")
         except Exception as e:

+ 13 - 9
mineru/utils/llm_aided.py

@@ -1,7 +1,7 @@
 # Copyright (c) Opendatalab. All rights reserved.
 from loguru import logger
 from openai import OpenAI
-import ast
+import json_repair
 
 from mineru.backend.pipeline.pipeline_middle_json_mkcontent import merge_para_with_text
 
@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config):
             if block["type"] == "title":
                 origin_title_list.append(block)
                 title_text = merge_para_with_text(block)
-                page_line_height_list = []
-                for line in block['lines']:
-                    bbox = line['bbox']
-                    page_line_height_list.append(int(bbox[3] - bbox[1]))
-                if len(page_line_height_list) > 0:
-                    line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
+
+                if 'line_avg_height' in block:
+                    line_avg_height = block['line_avg_height']
                 else:
-                    line_avg_height = int(block['bbox'][3] - block['bbox'][1])
+                    title_block_line_height_list = []
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        title_block_line_height_list.append(int(bbox[3] - bbox[1]))
+                    if len(title_block_line_height_list) > 0:
+                        line_avg_height = sum(title_block_line_height_list) / len(title_block_line_height_list)
+                    else:
+                        line_avg_height = int(block['bbox'][3] - block['bbox'][1])
+
                 title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
                 i += 1
     # logger.info(f"Title list: {title_dict}")
@@ -91,7 +96,6 @@ Corrected title list:
             if "</think>" in content:
                 idx = content.index("</think>") + len("</think>")
                 content = content[idx:].strip()
-            import json_repair
             dict_completion = json_repair.loads(content)
             dict_completion = {int(k): int(v) for k, v in dict_completion.items()}