Browse Source

增加ocr版本解析功能

赵小蒙 1 năm trước cách đây
mục cha
commit
701f384994

+ 29 - 0
demo/ocr_demo.py

@@ -0,0 +1,29 @@
+import os
+
+from loguru import logger
+
+from magic_pdf.dict2md.ocr_mkcontent import mk_nlp_markdown
+from magic_pdf.pdf_parse_by_ocr import parse_pdf_by_ocr
+
+
+def save_markdown(markdown_text, input_filepath):
+    # 获取输入文件的目录
+    directory = os.path.dirname(input_filepath)
+    # 获取输入文件的文件名(不带扩展名)
+    base_name = os.path.basename(input_filepath)
+    file_name_without_ext = os.path.splitext(base_name)[0]
+    # 定义输出文件的路径
+    output_filepath = os.path.join(directory, f"{file_name_without_ext}.md")
+
+    # 将Markdown文本写入.md文件
+    with open(output_filepath, 'w', encoding='utf-8') as file:
+        file.write(markdown_text)
+
+
+if __name__ == '__main__':
+    ocr_json_file_path = r"D:\project\20231108code-clean\ocr\new\demo_4\ocr_0.json"
+    pdf_info_dict = parse_pdf_by_ocr(ocr_json_file_path)
+    markdown_text = mk_nlp_markdown(pdf_info_dict)
+    logger.info(markdown_text)
+    save_markdown(markdown_text, ocr_json_file_path)
+

+ 21 - 0
magic_pdf/dict2md/ocr_mkcontent.py

@@ -0,0 +1,21 @@
+def mk_nlp_markdown(pdf_info_dict: dict):
+
+    markdown = []
+
+    for _, page_info in pdf_info_dict.items():
+        blocks = page_info.get("preproc_blocks")
+        if not blocks:
+            continue
+        for block in blocks:
+            for line in block['lines']:
+                line_text = ''
+                for span in line['spans']:
+                    content = span['content'].replace('$', '\$')  # 转义$
+                    if span['type'] == 'inline_equation':
+                        content = f"${content}$"
+                    elif span['type'] == 'displayed_equation':
+                        content = f"$$\n{content}\n$$"
+                    line_text += content + ' '
+                # 在行末添加两个空格以强制换行
+                markdown.append(line_text.strip() + '  ')
+    return '\n'.join(markdown)

+ 33 - 1
magic_pdf/libs/boxbase.py

@@ -119,6 +119,20 @@ def _is_left_overlap(box1, box2,):
     return x0_1<=x0_2<=x1_1 and vertical_overlap_cond
 
 
+def __is_overlaps_y_exceeds_threshold(bbox1, bbox2, overlap_ratio_threshold=0.8):
+    """检查两个bbox在y轴上是否有重叠,并且该重叠区域的高度占两个bbox高度更低的那个超过80%"""
+    _, y0_1, _, y1_1 = bbox1
+    _, y0_2, _, y1_2 = bbox2
+
+    overlap = max(0, min(y1_1, y1_2) - max(y0_1, y0_2))
+    height1, height2 = y1_1 - y0_1, y1_2 - y0_2
+    max_height = max(height1, height2)
+    min_height = min(height1, height2)
+
+    return (overlap / min_height) > overlap_ratio_threshold
+
+
+
 def calculate_iou(bbox1, bbox2):
     # Determine the coordinates of the intersection rectangle
     x_left = max(bbox1[0], bbox2[0])
@@ -163,7 +177,25 @@ def calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2):
     else:
         return intersection_area / min_box_area
 
-    
+
+def get_minbox_if_overlap_by_ratio(bbox1, bbox2, ratio):
+    """
+    通过calculate_overlap_area_2_minbox_area_ratio计算两个bbox重叠的面积占最小面积的box的比例
+    如果比例大于ratio,则返回小的那个bbox,
+    否则返回None
+    """
+    x1_min, y1_min, x1_max, y1_max = bbox1
+    x2_min, y2_min, x2_max, y2_max = bbox2
+    area1 = (x1_max - x1_min) * (y1_max - y1_min)
+    area2 = (x2_max - x2_min) * (y2_max - y2_min)
+    overlap_ratio = calculate_overlap_area_2_minbox_area_ratio(bbox1, bbox2)
+    if overlap_ratio > ratio and area1 < area2:
+        return bbox1
+    elif overlap_ratio > ratio and area2 < area1:
+        return bbox2
+    else:
+        return None
+
 def get_bbox_in_boundry(bboxes:list, boundry:tuple)-> list:
     x0, y0, x1, y1 = boundry
     new_boxes = [box for box in bboxes if box[0] >= x0 and box[1] >= y0 and box[2] <= x1 and box[3] <= y1]

+ 46 - 0
magic_pdf/libs/ocr_dict_merge.py

@@ -0,0 +1,46 @@
+from magic_pdf.libs.boxbase import __is_overlaps_y_exceeds_threshold
+
+
+def merge_spans(spans):
+    # 按照y0坐标排序
+    spans.sort(key=lambda span: span['bbox'][1])
+
+    lines = []
+    current_line = [spans[0]]
+    for span in spans[1:]:
+        # 如果当前的span类型为"displayed_equation" 或者 当前行中已经有"displayed_equation"
+        if span['type'] == "displayed_equation" or any(s['type'] == "displayed_equation" for s in current_line):
+            # 则开始新行
+            lines.append(current_line)
+            current_line = [span]
+            continue
+
+        # 如果当前的span与当前行的最后一个span在y轴上重叠,则添加到当前行
+        if __is_overlaps_y_exceeds_threshold(span['bbox'], current_line[-1]['bbox']):
+            current_line.append(span)
+        else:
+            # 否则,开始新行
+            lines.append(current_line)
+            current_line = [span]
+
+    # 添加最后一行
+    if current_line:
+        lines.append(current_line)
+
+    # 计算每行的边界框,并对每行中的span按照x0进行排序
+    line_objects = []
+    for line in lines:
+        # 按照x0坐标排序
+        line.sort(key=lambda span: span['bbox'][0])
+        line_bbox = [
+            min(span['bbox'][0] for span in line),  # x0
+            min(span['bbox'][1] for span in line),  # y0
+            max(span['bbox'][2] for span in line),  # x1
+            max(span['bbox'][3] for span in line),  # y1
+        ]
+        line_objects.append({
+            "bbox": line_bbox,
+            "spans": line,
+        })
+
+    return line_objects

+ 85 - 0
magic_pdf/pdf_parse_by_ocr.py

@@ -0,0 +1,85 @@
+import json
+
+from magic_pdf.libs.boxbase import get_minbox_if_overlap_by_ratio
+from magic_pdf.libs.ocr_dict_merge import merge_spans
+
+
+def read_json_file(file_path):
+    with open(file_path, 'r') as f:
+        data = json.load(f)
+    return data
+
+
+def construct_page_component(page_id, text_blocks_preproc):
+    return_dict = {
+        'preproc_blocks': text_blocks_preproc,
+        'page_idx': page_id
+    }
+    return return_dict
+
+
+def parse_pdf_by_ocr(
+    ocr_json_file_path,
+    start_page_id=0,
+    end_page_id=None,
+):
+    ocr_pdf_info = read_json_file(ocr_json_file_path)
+    pdf_info_dict = {}
+    end_page_id = end_page_id if end_page_id else len(ocr_pdf_info) - 1
+    for page_id in range(start_page_id, end_page_id + 1):
+        ocr_page_info = ocr_pdf_info[page_id]
+        layout_dets = ocr_page_info['layout_dets']
+        spans = []
+        for layout_det in layout_dets:
+            category_id = layout_det['category_id']
+            allow_category_id_list = [13, 14, 15]
+            if category_id in allow_category_id_list:
+                x0, y0, _, _, x1, y1, _, _ = layout_det['poly']
+                bbox = [int(x0), int(y0), int(x1), int(y1)]
+                #  13: 'embedding',     # 嵌入公式
+                #  14: 'isolated',      # 单行公式
+                #  15: 'ocr_text',      # ocr识别文本
+                span = {
+                    'bbox': bbox,
+                }
+                if category_id == 13:
+                    span['content'] = layout_det['latex']
+                    span['type'] = 'inline_equation'
+                elif category_id == 14:
+                    span['content'] = layout_det['latex']
+                    span['type'] = 'displayed_equation'
+                elif category_id == 15:
+                    span['content'] = layout_det['text']
+                    span['type'] = 'text'
+                # print(span)
+                spans.append(span)
+            else:
+                continue
+
+        # 合并重叠的spans
+        for span1 in spans.copy():
+            for span2 in spans.copy():
+                if span1 != span2:
+                    overlap_box = get_minbox_if_overlap_by_ratio(span1['bbox'], span2['bbox'], 0.8)
+                    if overlap_box is not None:
+                        bbox_to_remove = next((span for span in spans if span['bbox'] == overlap_box), None)
+                        if bbox_to_remove is not None:
+                            spans.remove(bbox_to_remove)
+
+        # 将spans合并成line
+        lines = merge_spans(spans)
+
+        # 目前不做block拼接,先做个结构,每个block中只有一个line,block的bbox就是line的bbox
+        blocks = []
+        for line in lines:
+            blocks.append({
+                "bbox": line['bbox'],
+                "lines": [line],
+            })
+
+        # 构造pdf_info_dict
+        page_info = construct_page_component(page_id, blocks)
+        pdf_info_dict[f"page_{page_id}"] = page_info
+
+    return pdf_info_dict
+