瀏覽代碼

refactor: streamline text span extraction and remove unused functions

myhloli 5 月之前
父節點
當前提交
a3ae57bf20
共有 2 個文件被更改,包括 3 次插入170 次删除
  1. 2 2
      mineru/backend/pipeline/model_json_to_middle_json.py
  2. 1 168
      mineru/utils/span_pre_proc.py

+ 2 - 2
mineru/backend/pipeline/model_json_to_middle_json.py

@@ -15,7 +15,7 @@ from mineru.utils.model_utils import clean_memory
 from mineru.utils.pipeline_magic_model import MagicModel
 from mineru.utils.span_block_fix import fill_spans_in_blocks, fix_discarded_block, fix_block_spans
 from mineru.utils.span_pre_proc import remove_outside_spans, remove_overlaps_low_confidence_spans, \
-    remove_overlaps_min_spans, txt_spans_extract_v3
+    remove_overlaps_min_spans, txt_spans_extract
 from mineru.version import __version__
 from mineru.utils.hash_utils import str_md5
 
@@ -113,7 +113,7 @@ def page_model_info_to_page_info(page_model_info, image_dict, page, image_writer
         pass
     else:
         """使用新版本的混合ocr方案."""
-        spans = txt_spans_extract_v3(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
+        spans = txt_spans_extract(page, spans, page_pil_img, scale, all_bboxes, all_discarded_blocks)
 
     """先处理不需要排版的discarded_blocks"""
     discarded_block_with_spans, spans = fill_spans_in_blocks(

+ 1 - 168
mineru/utils/span_pre_proc.py

@@ -118,118 +118,8 @@ def __replace_unicode(text: str):
     return re.sub('|'.join(map(re.escape, ligatures.keys())), lambda m: ligatures[m.group()], text)
 
 
-"""textpage.get_text_bounded方案"""
-def txt_spans_extract_v1(pdf_page, spans, pil_img, scale):
-
-    textpage = pdf_page.get_textpage()
-    width, height = pdf_page.get_size()
-    cropbox = pdf_page.get_cropbox()
-    need_ocr_spans = []
-    for span in spans:
-        if span['type'] in [ContentType.INTERLINE_EQUATION, ContentType.IMAGE, ContentType.TABLE]:
-            continue
-        span_bbox = span['bbox']
-        rect_box = [span_bbox[0] + cropbox[0],
-                    height - span_bbox[3] + cropbox[1],
-                    span_bbox[2] + cropbox[0],
-                    height - span_bbox[1] + cropbox[1]]
-        # logger.info(f"span bbox: {span_bbox}, rect_box: {rect_box}")
-        middle_height = (rect_box[1] + rect_box[3]) / 2
-        rect_box[1] = middle_height - 1
-        rect_box[3] = middle_height + 1
-        text = textpage.get_text_bounded(left=rect_box[0], top=rect_box[1],
-                                         right=rect_box[2], bottom=rect_box[3])
-        if text and len(text) > 0:
-            text = __replace_unicode(text)
-            text = __replace_ligatures(text)
-            span['content'] = text.strip()
-            span['score'] = 1.0
-        else:
-            need_ocr_spans.append(span)
-
-    if len(need_ocr_spans) > 0:
-
-        for span in need_ocr_spans:
-            # 对span的bbox截图再ocr
-            span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
-            span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
-            # 计算span的对比度,低于0.20的span不进行ocr
-            if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
-                spans.remove(span)
-                continue
-
-            span['content'] = ''
-            span['score'] = 1.0
-            span['np_img'] = span_img
-
-    return spans
-
-
-"""pdf_text dict方案 span级别"""
-def txt_spans_extract_v2(pdf_page, spans, pil_img, scale):
-
-    page_dict = get_page(pdf_page)
-
-    page_all_spans = []
-    for block in page_dict['blocks']:
-        for line in block['lines']:
-            if 0 < abs(line['rotation']) < 90:
-                # 旋转角度在0-90度之间的行,直接跳过
-                continue
-            for span in line['spans']:
-                page_all_spans.append(span)
-
-    need_ocr_spans = []
-    for span in spans:
-        if span['type'] in [ContentType.TEXT]:
-            span['sub_spans'] = []
-            matched_spans = []
-            for page_span in page_all_spans:
-                if calculate_overlap_area_in_bbox1_area_ratio(page_span['bbox'].bbox, span['bbox']) > 0.5:
-                    span['sub_spans'].append(page_span)
-                    matched_spans.append(page_span)
-
-            # 从page_all_spans中移除已匹配的元素
-            page_all_spans = [span for span in page_all_spans if span not in matched_spans]
-
-            # 对sub_spans按照bbox的x坐标进行排序
-            span['sub_spans'].sort(key=lambda x: x['bbox'].x_start)
-            # 对sub_spans的content进行拼接
-            span_content = ''.join([sub_span['text'] for sub_span in span['sub_spans']])
-
-            if span_content and len(span_content) > 0:
-                span_content = __replace_unicode(span_content)
-                span_content = __replace_ligatures(span_content)
-                span['content'] = span_content.strip()
-                span['score'] = 1.0
-            else:
-                need_ocr_spans.append(span)
-
-            # 移除span的sub_spans
-            span.pop('sub_spans', None)
-        else:
-            pass
-
-    if len(need_ocr_spans) > 0:
-
-        for span in need_ocr_spans:
-            # 对span的bbox截图再ocr
-            span_pil_img = get_crop_img(span['bbox'], pil_img, scale)
-            span_img = cv2.cvtColor(np.array(span_pil_img), cv2.COLOR_RGB2BGR)
-            # 计算span的对比度,低于0.20的span不进行ocr
-            if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
-                spans.remove(span)
-                continue
-
-            span['content'] = ''
-            span['score'] = 1.0
-            span['np_img'] = span_img
-
-    return spans
-
-
 """pdf_text dict方案 char级别"""
-def txt_spans_extract_v3(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
+def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded_blocks):
 
     page_dict = get_page(pdf_page)
 
@@ -385,9 +275,6 @@ def chars_to_content(span):
     if len(span['chars']) == 0:
         pass
     else:
-        # 先给chars按char['bbox']的中心点的x坐标排序
-        # span['chars'] = sorted(span['chars'], key=lambda x: (x['bbox'][0] + x['bbox'][2]) / 2)
-
         # 给chars按char_idx排序
         span['chars'] = sorted(span['chars'], key=lambda x: x['char_idx'])
 
@@ -396,9 +283,6 @@ def chars_to_content(span):
         # Calculate the median width
         median_width = statistics.median(char_widths)
 
-        # 通过x轴重叠比率移除一部分char
-        # span = remove_x_overlapping_chars(span, median_width)
-
         content = ''
         for char in span['chars']:
 
@@ -418,57 +302,6 @@ def chars_to_content(span):
     del span['chars']
 
 
-def remove_x_overlapping_chars(span, median_width):
-    """
-    Remove characters from a span that overlap significantly on the x-axis.
-
-    Args:
-        median_width:
-        span (dict): A span containing a list of chars, each with bbox coordinates
-                    in the format [x0, y0, x1, y1]
-
-    Returns:
-        dict: The span with overlapping characters removed
-    """
-    if 'chars' not in span or len(span['chars']) < 2:
-        return span
-
-    overlap_threshold = median_width * 0.3
-
-    i = 0
-    while i < len(span['chars']) - 1:
-        char1 = span['chars'][i]
-        char2 = span['chars'][i + 1]
-
-        # Calculate overlap width
-        x_left = max(char1['bbox'][0], char2['bbox'][0])
-        x_right = min(char1['bbox'][2], char2['bbox'][2])
-
-        if x_right > x_left:  # There is overlap
-            overlap_width = x_right - x_left
-
-            if overlap_width > overlap_threshold:
-                if char1['char'] == char2['char'] or char1['char'] == ' ' or char2['char'] == ' ':
-                    # Determine which character to remove
-                    width1 = char1['bbox'][2] - char1['bbox'][0]
-                    width2 = char2['bbox'][2] - char2['bbox'][0]
-                    if width1 < width2:
-                        # Remove the narrower character
-                        span['chars'].pop(i)
-                    else:
-                        span['chars'].pop(i + 1)
-                else:
-                    i += 1
-
-                # Don't increment i since we need to check the new pair
-            else:
-                i += 1
-        else:
-            i += 1
-
-    return span
-
-
 def calculate_contrast(img, img_mode) -> float:
     """
     计算给定图像的对比度。