Explorar el Código

feat: improve heading level feature with enhanced configuration and error handling

myhloli hace 4 meses
padre
commit
2d742bcaa3
Se han modificado 1 ficheros con 38 adiciones y 44 borrados
  1. 38 44
      mineru/backend/vlm/token_to_middle_json.py

+ 38 - 44
mineru/backend/vlm/token_to_middle_json.py

@@ -11,13 +11,18 @@ from mineru.utils.pdf_image_tools import get_crop_img
 from mineru.version import __version__
 
 heading_level_import_success = False
-try:
-    from mineru.utils.llm_aided import llm_aided_title
-    from mineru.backend.pipeline.model_init import AtomModelSingleton
-    heading_level_import_success = True
-except Exception as e:
-    logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
-                   "please execute `pip install mineru[pipeline]` to install the required packages.")
+llm_aided_config = get_llm_aided_config()
+if llm_aided_config is not None:
+    title_aided_config = llm_aided_config.get('title_aided', None)
+    if title_aided_config is not None:
+        if title_aided_config.get('enable', False):
+            try:
+                from mineru.utils.llm_aided import llm_aided_title
+                from mineru.backend.pipeline.model_init import AtomModelSingleton
+                heading_level_import_success = True
+            except Exception as e:
+                logger.warning("The heading level feature cannot be used. If you need to use the heading level feature, "
+                               "please execute `pip install mineru[core]` to install the required packages.")
 
 
 def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dict:
@@ -38,32 +43,27 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
     title_blocks = magic_model.get_title_blocks()
 
     # 如果有标题优化需求,则对title_blocks截图det
-    llm_aided_config = get_llm_aided_config()
-    if llm_aided_config is not None:
-        title_aided_config = llm_aided_config.get('title_aided', None)
-        if title_aided_config is not None:
-            if title_aided_config.get('enable', False):
-                if heading_level_import_success:
-                    atom_model_manager = AtomModelSingleton()
-                    ocr_model = atom_model_manager.get_atom_model(
-                        atom_model_name='ocr',
-                        ocr_show_log=False,
-                        det_db_box_thresh=0.3,
-                        lang='ch_lite'
-                    )
-                    for title_block in title_blocks:
-                        title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
-                        title_np_img = np.array(title_pil_img)
-                        # 给title_pil_img添加上下左右各50像素白边padding
-                        title_np_img = cv2.copyMakeBorder(
-                            title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
-                        )
-                        title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
-                        ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
-                        if len(ocr_det_res) > 0:
-                            # 计算所有res的平均高度
-                            avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
-                            title_block['line_avg_height'] = round(avg_height/scale)
+    if heading_level_import_success:
+        atom_model_manager = AtomModelSingleton()
+        ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.3,
+            lang='ch_lite'
+        )
+        for title_block in title_blocks:
+            title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
+            title_np_img = np.array(title_pil_img)
+            # 给title_pil_img添加上下左右各50像素白边padding
+            title_np_img = cv2.copyMakeBorder(
+                title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+            )
+            title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
+            ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
+            if len(ocr_det_res) > 0:
+                # 计算所有res的平均高度
+                avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
+                title_block['line_avg_height'] = round(avg_height/scale)
 
     text_blocks = magic_model.get_text_blocks()
     interline_equation_blocks = magic_model.get_interline_equation_blocks()
@@ -91,17 +91,11 @@ def result_to_middle_json(token_list, images_list, pdf_doc, image_writer):
         page_info = token_to_page_info(token, image_dict, page, image_writer, index)
         middle_json["pdf_info"].append(page_info)
 
-    """llm优化"""
-    llm_aided_config = get_llm_aided_config()
-    if llm_aided_config is not None:
-        """标题优化"""
-        title_aided_config = llm_aided_config.get('title_aided', None)
-        if title_aided_config is not None:
-            if title_aided_config.get('enable', False):
-                if heading_level_import_success:
-                    llm_aided_title_start_time = time.time()
-                    llm_aided_title(middle_json["pdf_info"], title_aided_config)
-                    logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
+    """llm优化标题分级"""
+    if heading_level_import_success:
+        llm_aided_title_start_time = time.time()
+        llm_aided_title(middle_json["pdf_info"], title_aided_config)
+        logger.info(f'llm aided title time: {round(time.time() - llm_aided_title_start_time, 2)}')
 
     # 关闭pdf文档
     pdf_doc.close()