Browse Source

fix(para_split_v2): avoid index-out-of-range error when span_text is empty before reading its first character (#396)

Co-authored-by: liukaiwen <liukaiwen@pjlab.org.cn>
Kaiwen Liu 1 year ago
parent
commit
65c3ac66ae
1 changed file with 50 additions and 47 deletions
  1. 50 47
      magic_pdf/para/para_split_v2.py

+ 50 - 47
magic_pdf/para/para_split_v2.py

@@ -100,59 +100,62 @@ def __detect_list_lines(lines, new_layout_bboxes, lang):
 
     if lang != 'en':
         return lines, None
-    else:
-        total_lines = len(lines)
-        line_fea_encode = []
-        """
-        对每一行进行特征编码,编码规则如下:
-        1. 如果行顶格,且大写字母开头或者数字开头,编码为1
-        2. 如果顶格,其他非大写开头编码为4
-        3. 如果非顶格,首字符大写,编码为2
-        4. 如果非顶格,首字符非大写编码为3
-        """
-        if len(lines) > 0:
-            x_map_tag_dict, min_x_tag = cluster_line_x(lines)
-        for l in lines:
-            span_text = __get_span_text(l['spans'][0])
-            first_char = span_text[0]
-            layout = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)
-            if not layout:
-                line_fea_encode.append(0)
+
+    total_lines = len(lines)
+    line_fea_encode = []
+    """
+    对每一行进行特征编码,编码规则如下:
+    1. 如果行顶格,且大写字母开头或者数字开头,编码为1
+    2. 如果顶格,其他非大写开头编码为4
+    3. 如果非顶格,首字符大写,编码为2
+    4. 如果非顶格,首字符非大写编码为3
+    """
+    if len(lines) > 0:
+        x_map_tag_dict, min_x_tag = cluster_line_x(lines)
+    for l in lines:
+        span_text = __get_span_text(l['spans'][0])
+        if not span_text:
+            line_fea_encode.append(0)
+            continue
+        first_char = span_text[0]
+        layout = __find_layout_bbox_by_line(l['bbox'], new_layout_bboxes)
+        if not layout:
+            line_fea_encode.append(0)
+        else:
+            #
+            if x_map_tag_dict[round(l['bbox'][0])] == min_x_tag:
+                # if first_char.isupper() or first_char.isdigit() or not first_char.isalnum():
+                if not first_char.isalnum() or if_match_reference_list(span_text):
+                    line_fea_encode.append(1)
+                else:
+                    line_fea_encode.append(4)
             else:
-                #
-                if x_map_tag_dict[round(l['bbox'][0])] == min_x_tag:
-                    # if first_char.isupper() or first_char.isdigit() or not first_char.isalnum():
-                    if not first_char.isalnum() or if_match_reference_list(span_text):
-                        line_fea_encode.append(1)
-                    else:
-                        line_fea_encode.append(4)
+                if first_char.isupper():
+                    line_fea_encode.append(2)
                 else:
-                    if first_char.isupper():
-                        line_fea_encode.append(2)
-                    else:
-                        line_fea_encode.append(3)
+                    line_fea_encode.append(3)
 
-        # 然后根据编码进行分段, 选出来 1,2,3连续出现至少2次的行,认为是列表。
+    # 然后根据编码进行分段, 选出来 1,2,3连续出现至少2次的行,认为是列表。
 
-        list_indice, list_start_idx = find_repeating_patterns2(line_fea_encode)
-        if len(list_indice) > 0:
+    list_indice, list_start_idx = find_repeating_patterns2(line_fea_encode)
+    if len(list_indice) > 0:
+        if debug_able:
+            logger.info(f"发现了列表,列表行数:{list_indice}, {list_start_idx}")
+
+    # TODO check一下这个特列表里缩进的行左侧是不是对齐的。
+    segments = []
+    for start, end in list_indice:
+        for i in range(start, end + 1):
+            if i > 0:
+                if line_fea_encode[i] == 4:
+                    if debug_able:
+                        logger.info(f"列表行的第{i}行不是顶格的")
+                    break
+        else:
             if debug_able:
-                logger.info(f"发现了列表,列表行数:{list_indice}, {list_start_idx}")
-
-        # TODO check一下这个特列表里缩进的行左侧是不是对齐的。
-        segments = []
-        for start, end in list_indice:
-            for i in range(start, end + 1):
-                if i > 0:
-                    if line_fea_encode[i] == 4:
-                        if debug_able:
-                            logger.info(f"列表行的第{i}行不是顶格的")
-                        break
-            else:
-                if debug_able:
-                    logger.info(f"列表行的第{start}到第{end}行是列表")
+                logger.info(f"列表行的第{start}到第{end}行是列表")
 
-        return split_indices(total_lines, list_indice), list_start_idx
+    return split_indices(total_lines, list_indice), list_start_idx
 
 
 def cluster_line_x(lines: list) -> dict: