Kaynağa Gözat

feat(layout): improve title block handling and layout detection

- Merge title blocks that are close to each other horizontally
- Adjust line insertion logic for title blocks- Increase image size and decrease confidence threshold for layout detection
- Update DocLayoutYOLO model weights
- Refactor drawing of bounding boxes for different block types
myhloli 10 ay önce
ebeveyn
işleme
c20e9a1e84

+ 1 - 1
docker/ascend_npu/requirements.txt

@@ -17,7 +17,7 @@ paddlepaddle==3.0.0b1
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
+doclayout_yolo==0.0.2b1
 rapidocr-paddle
 rapidocr-onnxruntime
 rapid_table==0.3.0

+ 1 - 1
docker/china/requirements.txt

@@ -16,7 +16,7 @@ paddleocr==2.7.3
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
+doclayout_yolo==0.0.2b1
 rapidocr-paddle
 rapidocr-onnxruntime
 rapid_table==0.3.0

+ 1 - 1
docker/global/requirements.txt

@@ -16,7 +16,7 @@ paddleocr==2.7.3
 struct-eqtable==0.3.2
 einops
 accelerate
-doclayout_yolo==0.0.2
+doclayout_yolo==0.0.2b1
 rapidocr-paddle
 rapidocr-onnxruntime
 rapid_table==0.3.0

+ 14 - 2
magic_pdf/libs/draw_bbox.py

@@ -362,12 +362,24 @@ def draw_line_sort_bbox(pdf_info, pdf_bytes, out_path, filename):
     for page in pdf_info:
         page_line_list = []
         for block in page['preproc_blocks']:
-            if block['type'] in [BlockType.Text, BlockType.Title, BlockType.InterlineEquation]:
+            if block['type'] in [BlockType.Text]:
                 for line in block['lines']:
                     bbox = line['bbox']
                     index = line['index']
                     page_line_list.append({'index': index, 'bbox': bbox})
-            if block['type'] in [BlockType.Image, BlockType.Table]:
+            elif block['type'] in [BlockType.Title, BlockType.InterlineEquation]:
+                if 'virtual_lines' in block:
+                    if len(block['virtual_lines']) > 0 and block['virtual_lines'][0].get('index', None) is not None:
+                        for line in block['virtual_lines']:
+                            bbox = line['bbox']
+                            index = line['index']
+                            page_line_list.append({'index': index, 'bbox': bbox})
+                else:
+                    for line in block['lines']:
+                        bbox = line['bbox']
+                        index = line['index']
+                        page_line_list.append({'index': index, 'bbox': bbox})
+            elif block['type'] in [BlockType.Image, BlockType.Table]:
                 for sub_block in block['blocks']:
                     if sub_block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
                         if len(sub_block['virtual_lines']) > 0 and sub_block['virtual_lines'][0].get('index', None) is not None:

+ 19 - 19
magic_pdf/model/pdf_extract_kit.py

@@ -144,7 +144,7 @@ class CustomPEKModel:
                         model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                     )
                 ),
-                device=self.device,
+                device='cpu' if str(self.device).startswith("mps") else self.device,
             )
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             self.layout_model = atom_model_manager.get_atom_model(
@@ -192,24 +192,24 @@ class CustomPEKModel:
             layout_res = self.layout_model(image, ignore_catids=[])
         elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
             # doclayout_yolo
-            if height > width:
-                input_res = {"poly":[0,0,width,0,width,height,0,height]}
-                new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
-                paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
-                layout_res = self.layout_model.predict(new_image)
-                for res in layout_res:
-                    p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
-                    p1 = p1 - paste_x + xmin
-                    p2 = p2 - paste_y + ymin
-                    p3 = p3 - paste_x + xmin
-                    p4 = p4 - paste_y + ymin
-                    p5 = p5 - paste_x + xmin
-                    p6 = p6 - paste_y + ymin
-                    p7 = p7 - paste_x + xmin
-                    p8 = p8 - paste_y + ymin
-                    res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
-            else:
-                layout_res = self.layout_model.predict(image)
+            # if height > width:
+            #     input_res = {"poly":[0,0,width,0,width,height,0,height]}
+            #     new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
+            #     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
+            #     layout_res = self.layout_model.predict(new_image)
+            #     for res in layout_res:
+            #         p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
+            #         p1 = p1 - paste_x + xmin
+            #         p2 = p2 - paste_y + ymin
+            #         p3 = p3 - paste_x + xmin
+            #         p4 = p4 - paste_y + ymin
+            #         p5 = p5 - paste_x + xmin
+            #         p6 = p6 - paste_y + ymin
+            #         p7 = p7 - paste_x + xmin
+            #         p8 = p8 - paste_y + ymin
+            #         res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
+            # else:
+            layout_res = self.layout_model.predict(image)
 
         layout_cost = round(time.time() - layout_start, 2)
         logger.info(f'layout detection time: {layout_cost}')

+ 7 - 3
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -9,7 +9,11 @@ class DocLayoutYOLOModel(object):
     def predict(self, image):
         layout_res = []
         doclayout_yolo_res = self.model.predict(
-            image, imgsz=1024, conf=0.25, iou=0.45, verbose=False, device=self.device
+            image,
+            imgsz=1280,
+            conf=0.10,
+            iou=0.45,
+            verbose=False, device=self.device
         )[0]
         for xyxy, conf, cla in zip(
             doclayout_yolo_res.boxes.xyxy.cpu(),
@@ -32,8 +36,8 @@ class DocLayoutYOLOModel(object):
                 image_res.cpu()
                 for image_res in self.model.predict(
                     images[index : index + batch_size],
-                    imgsz=1024,
-                    conf=0.25,
+                    imgsz=1280,
+                    conf=0.10,
                     iou=0.45,
                     verbose=False,
                     device=self.device,

+ 84 - 25
magic_pdf/pdf_parse_union_core_v2.py

@@ -12,7 +12,7 @@ from loguru import logger
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.ocr_content_type import BlockType, ContentType
 from magic_pdf.data.dataset import Dataset, PageableData
-from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio
+from magic_pdf.libs.boxbase import calculate_overlap_area_in_bbox1_area_ratio, __is_overlaps_y_exceeds_threshold
 from magic_pdf.libs.clean_memory import clean_memory
 from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_llm_aided_config, get_device
 from magic_pdf.libs.convert_utils import dict_to_list
@@ -365,10 +365,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
                 block['index'] = median_value
 
             # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
-            if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
-                block['virtual_lines'] = copy.deepcopy(block['lines'])
-                block['lines'] = copy.deepcopy(block['real_lines'])
-                del block['real_lines']
+            if block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.Title, BlockType.InterlineEquation]:
+                if 'real_lines' in block:
+                    block['virtual_lines'] = copy.deepcopy(block['lines'])
+                    block['lines'] = copy.deepcopy(block['real_lines'])
+                    del block['real_lines']
     else:
         # 使用xycut排序
         block_bboxes = []
@@ -417,7 +418,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     block_weight = x1 - x0
 
     # 如果block高度小于n行正文,则直接返回block的bbox
-    if line_height * 3 < block_height:
+    if line_height * 2 < block_height:
         if (
             block_height > page_h * 0.25 and page_w * 0.5 > block_weight > page_w * 0.25
         ):  # 可能是双列结构,可以切细点
@@ -425,16 +426,16 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
         else:
             # 如果block的宽度超过0.4页面宽度,则将block分成3行(是一种复杂布局,图不能切的太细)
             if block_weight > page_w * 0.4:
-                line_height = (y1 - y0) / 3
                 lines = 3
+                line_height = (y1 - y0) / lines
             elif block_weight > page_w * 0.25:  # (可能是三列结构,也切细点)
                 lines = int(block_height / line_height) + 1
             else:  # 判断长宽比
                 if block_height / block_weight > 1.2:  # 细长的不分
                     return [[x0, y0, x1, y1]]
                 else:  # 不细长的还是分成两行
-                    line_height = (y1 - y0) / 2
                     lines = 2
+                    line_height = (y1 - y0) / lines
 
         # 确定从哪个y位置开始绘制线条
         current_y = y0
@@ -453,30 +454,32 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
+
+    def add_lines_to_block(b):
+        line_bboxes = insert_lines_into_block(b['bbox'], line_height, page_w, page_h)
+        b['lines'] = []
+        for line_bbox in line_bboxes:
+            b['lines'].append({'bbox': line_bbox, 'spans': []})
+        page_line_list.extend(line_bboxes)
+
     for block in fix_blocks:
         if block['type'] in [
-            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
+            BlockType.Text, BlockType.Title,
             BlockType.ImageCaption, BlockType.ImageFootnote,
             BlockType.TableCaption, BlockType.TableFootnote
         ]:
             if len(block['lines']) == 0:
-                bbox = block['bbox']
-                lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
-                for line in lines:
-                    block['lines'].append({'bbox': line, 'spans': []})
-                page_line_list.extend(lines)
+                add_lines_to_block(block)
+            elif block['type'] in [BlockType.Title] and len(block['lines']) == 1 and (block['bbox'][3] - block['bbox'][1]) > line_height * 2:
+                block['real_lines'] = copy.deepcopy(block['lines'])
+                add_lines_to_block(block)
             else:
                 for line in block['lines']:
                     bbox = line['bbox']
                     page_line_list.append(bbox)
-        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
-            bbox = block['bbox']
+        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]:
             block['real_lines'] = copy.deepcopy(block['lines'])
-            lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
-            block['lines'] = []
-            for line in lines:
-                block['lines'].append({'bbox': line, 'spans': []})
-            page_line_list.extend(lines)
+            add_lines_to_block(block)
 
     if len(page_line_list) > 200:  # layoutreader最高支持512line
         return None
@@ -663,12 +666,68 @@ def parse_page_core(
     discarded_blocks = magic_model.get_discarded(page_id)
     text_blocks = magic_model.get_text_blocks(page_id)
     title_blocks = magic_model.get_title_blocks(page_id)
-    inline_equations, interline_equations, interline_equation_blocks = (
-        magic_model.get_equations(page_id)
-    )
-
+    inline_equations, interline_equations, interline_equation_blocks = magic_model.get_equations(page_id)
     page_w, page_h = magic_model.get_page_size(page_id)
 
+    def merge_title_blocks(blocks, x_distance_threshold=0.1*page_w):
+        def merge_two_blocks(b1, b2):
+            # 合并两个标题块的边界框
+            x_min = min(b1['bbox'][0], b2['bbox'][0])
+            y_min = min(b1['bbox'][1], b2['bbox'][1])
+            x_max = max(b1['bbox'][2], b2['bbox'][2])
+            y_max = max(b1['bbox'][3], b2['bbox'][3])
+            merged_bbox = (x_min, y_min, x_max, y_max)
+
+            # 合并两个标题块的文本内容
+            merged_score = (b1['score'] + b2['score']) / 2
+
+            return {'bbox': merged_bbox, 'score': merged_score}
+
+        # 按 y 轴重叠度聚集标题块
+        y_overlapping_blocks = []
+        while blocks:
+            block1 = blocks.pop(0)
+            current_row = [block1]
+            to_remove = []
+            for block2 in blocks:
+                if __is_overlaps_y_exceeds_threshold(block1['bbox'], block2['bbox'], 0.9):
+                    current_row.append(block2)
+                    to_remove.append(block2)
+            for b in to_remove:
+                blocks.remove(b)
+            y_overlapping_blocks.append(current_row)
+
+        # 按x轴坐标排序并合并标题块
+        merged_blocks = []
+        for row in y_overlapping_blocks:
+            if len(row) == 1:
+                merged_blocks.append(row[0])
+                continue
+
+            # 按x轴坐标排序
+            row.sort(key=lambda x: x['bbox'][0])
+
+            merged_block = row[0]
+            for i in range(1, len(row)):
+                left_block = merged_block
+                right_block = row[i]
+
+                left_height = left_block['bbox'][3] - left_block['bbox'][1]
+                right_height = right_block['bbox'][3] - right_block['bbox'][1]
+
+                if right_block['bbox'][0] - left_block['bbox'][2] < x_distance_threshold and left_height * 0.95 < right_height < left_height * 1.05:
+                    merged_block = merge_two_blocks(merged_block, right_block)
+                else:
+                    merged_blocks.append(merged_block)
+                    merged_block = right_block
+
+            merged_blocks.append(merged_block)
+
+        return merged_blocks
+
+    """同一行被断开的titile合并"""
+    title_blocks = merge_title_blocks(title_blocks)
+
     """将所有区块的bbox整理到一起"""
     # interline_equation_blocks参数不够准,后面切换到interline_equations上
     interline_equation_blocks = []

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -1,6 +1,6 @@
 weights:
   layoutlmv3: Layout/LayoutLMv3/model_final.pth
-  doclayout_yolo: Layout/YOLO/doclayout_yolo_ft.pt
+  doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
   yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
   unimernet_small: MFR/unimernet_small
   struct_eqtable: TabRec/StructEqTable

+ 1 - 1
setup.py

@@ -48,7 +48,7 @@ if __name__ == '__main__':
                      "struct-eqtable==0.3.2",  # 表格解析
                      "einops",  # struct-eqtable依赖
                      "accelerate",  # struct-eqtable依赖
-                     "doclayout_yolo==0.0.2",  # doclayout_yolo
+                     "doclayout_yolo==0.0.2b1",  # doclayout_yolo
                      "rapidocr-paddle",  # rapidocr-paddle
                      "rapidocr_onnxruntime",
                      "rapid_table==0.3.0",  # rapid_table