Parcourir la source

feat: enhance title block processing with average height calculation and padding for OCR

myhloli il y a 4 mois
Parent
commit
06db3d173b
2 fichiers modifiés avec 45 ajouts et 8 suppressions
  1. 33 1
      mineru/backend/vlm/token_to_middle_json.py
  2. 12 7
      mineru/utils/llm_aided.py

+ 33 - 1
mineru/backend/vlm/token_to_middle_json.py

@@ -1,12 +1,16 @@
 import time
+import cv2
+import numpy as np
 from loguru import logger
 
+from mineru.backend.pipeline.model_init import AtomModelSingleton
 from mineru.utils.config_reader import get_llm_aided_config
 from mineru.utils.cut_image import cut_image_and_table
-from mineru.utils.enum_class import BlockType, ContentType
+from mineru.utils.enum_class import ContentType
 from mineru.utils.hash_utils import str_md5
 from mineru.backend.vlm.vlm_magic_model import MagicModel
 from mineru.utils.llm_aided import llm_aided_title
+from mineru.utils.pdf_image_tools import get_crop_img
 from mineru.version import __version__
 
 
@@ -26,6 +30,34 @@ def token_to_page_info(token, image_dict, page, image_writer, page_index) -> dic
     image_blocks = magic_model.get_image_blocks()
     table_blocks = magic_model.get_table_blocks()
     title_blocks = magic_model.get_title_blocks()
+
+    # 如果有标题优化需求,则对title_blocks截图det
+    llm_aided_config = get_llm_aided_config()
+    if llm_aided_config is not None:
+        title_aided_config = llm_aided_config.get('title_aided', None)
+        if title_aided_config is not None:
+            if title_aided_config.get('enable', False):
+                atom_model_manager = AtomModelSingleton()
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang='ch_lite'
+                )
+                for title_block in title_blocks:
+                    title_pil_img = get_crop_img(title_block['bbox'], page_pil_img, scale)
+                    title_np_img = np.array(title_pil_img)
+                    # 给title_pil_img添加上下左右各50像素白边padding
+                    title_np_img = cv2.copyMakeBorder(
+                        title_np_img, 50, 50, 50, 50, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+                    )
+                    title_img = cv2.cvtColor(title_np_img, cv2.COLOR_RGB2BGR)
+                    ocr_det_res = ocr_model.ocr(title_img, rec=False)[0]
+                    if len(ocr_det_res) > 0:
+                        # 计算所有res的平均高度
+                        avg_height = np.mean([box[2][1] - box[0][1] for box in ocr_det_res])
+                        title_block['line_avg_height'] = round(avg_height/scale)
+
     text_blocks = magic_model.get_text_blocks()
     interline_equation_blocks = magic_model.get_interline_equation_blocks()
 

+ 12 - 7
mineru/utils/llm_aided.py

@@ -20,14 +20,19 @@ def llm_aided_title(page_info_list, title_aided_config):
             if block["type"] == "title":
                 origin_title_list.append(block)
                 title_text = merge_para_with_text(block)
-                page_line_height_list = []
-                for line in block['lines']:
-                    bbox = line['bbox']
-                    page_line_height_list.append(int(bbox[3] - bbox[1]))
-                if len(page_line_height_list) > 0:
-                    line_avg_height = sum(page_line_height_list) / len(page_line_height_list)
+
+                if 'line_avg_height' in block:
+                    line_avg_height = block['line_avg_height']
                 else:
-                    line_avg_height = int(block['bbox'][3] - block['bbox'][1])
+                    title_block_line_height_list = []
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        title_block_line_height_list.append(int(bbox[3] - bbox[1]))
+                    if len(title_block_line_height_list) > 0:
+                        line_avg_height = sum(title_block_line_height_list) / len(title_block_line_height_list)
+                    else:
+                        line_avg_height = int(block['bbox'][3] - block['bbox'][1])
+
                 title_dict[f"{i}"] = [title_text, line_avg_height, int(page_info['page_idx']) + 1]
                 i += 1
     # logger.info(f"Title list: {title_dict}")