Преглед изворни кода

feat: 增强 OCR 处理逻辑,支持 PDF 旋转角度提取与 OCR 结果对比

zhch158_admin пре 4 дана
родитељ
комит
6791737004
1 измењених фајлова са 36 додато и 28 уклоњено
  1. 36 28
      ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

+ 36 - 28
ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

@@ -360,21 +360,28 @@ class EnhancedDocPipeline:
         
         # 3. 整页 OCR 获取所有 text spans(关键改进)
         all_text_spans = []
+        all_ocr_spans = []
         should_run_ocr = True
-        text_source = 'ocr'
         
         if pdf_type == 'txt' and pdf_doc is not None:
             # 文字 PDF:直接从 PDF 提取文本块
             try:
-                pdf_text_blocks = PDFUtils.extract_all_text_blocks(
+                pdf_text_blocks, rotation_angle = PDFUtils.extract_all_text_blocks(
                     pdf_doc, page_idx, scale=scale
                 )
+                # 保存rotation角度
+                page_result['angle'] = rotation_angle
+                if rotation_angle != 0:
+                    logger.info(f"📐 Page {page_idx}: PDF rotation {rotation_angle}°")
+
                 # 将 PDF 文本块转换为 span 格式
                 all_text_spans = self._convert_pdf_blocks_to_spans(
                     pdf_text_blocks, detection_image.shape
                 )
-                text_source = 'pdf'
                 logger.info(f"📝 Page {page_idx}: PDF extracted {len(all_text_spans)} text blocks")
+                if self.debug_mode:
+                    # 调试模式:同时运行 OCR 对比
+                    should_run_ocr = True
             except Exception as e:
                 logger.warning(f"⚠️ PDF text extraction failed, fallback to OCR: {e}")
                 pdf_type = 'ocr'  # Fallback to OCR
@@ -382,27 +389,26 @@ class EnhancedDocPipeline:
         # OCR PDF 或 PDF 提取失败时使用 OCR
         elif pdf_type == 'ocr':
             should_run_ocr = True
-            if self.debug_mode and text_source == 'pdf':
-                # 调试模式:同时运行 OCR 对比
-                should_run_ocr = True
-                
-            if should_run_ocr:
-                try:
-                    all_text_spans = self.ocr_recognizer.recognize_text(detection_image)
-                    all_text_spans = SpanMatcher.remove_duplicate_spans(all_text_spans)
-                    all_text_spans = self._sort_spans_by_position(all_text_spans)
-                    text_source = 'ocr'
-                    logger.info(f"📝 Page {page_idx}: OCR detected {len(all_text_spans)} text spans")
-                except Exception as e:
-                    logger.warning(f"⚠️ Full-page OCR failed: {e}")                
-                # 3.1 调试模式:对比 OCR 和 PDF 提取结果
-                if self.debug_mode and pdf_type == 'txt' and pdf_doc is not None:
-                    self._compare_ocr_and_pdf_text(
-                        page_idx, pdf_doc, all_ocr_spans, detection_image, output_dir, page_name, scale
-                    )
         else:
             raise ValueError(f"Unknown pdf_type: {pdf_type}")
+
+        if should_run_ocr:
+            try:
+                all_ocr_spans = self.ocr_recognizer.recognize_text(detection_image)
+                all_ocr_spans = SpanMatcher.remove_duplicate_spans(all_ocr_spans)
+                all_ocr_spans = self._sort_spans_by_position(all_ocr_spans)
+                logger.info(f"📝 Page {page_idx}: OCR detected {len(all_ocr_spans)} text spans")
+            except Exception as e:
+                logger.warning(f"⚠️ Full-page OCR failed: {e}")                
+            # 3.1 调试模式:对比 OCR 和 PDF 提取结果
+            if self.debug_mode and pdf_type == 'txt' and pdf_doc is not None:
+                self._compare_ocr_and_pdf_text(
+                    page_idx, pdf_doc, all_ocr_spans, detection_image, output_dir, page_name, scale
+                )
         
+        if pdf_type == 'ocr':
+            all_text_spans = all_ocr_spans
+            
         # 4. 将 OCR spans 匹配到 layout blocks
         matched_spans = {}
         if all_text_spans:
@@ -410,7 +416,7 @@ class EnhancedDocPipeline:
                 all_text_spans, layout_results, overlap_threshold=0.5
             )
         # 记录文本来源
-        page_result['text_source'] = text_source  # 'ocr' 或 'pdf'
+        page_result['text_source'] = pdf_type  # 'ocr' 或 'pdf'
         
         # 5. 分类元素
         classified_elements = self._classify_elements(layout_results, page_idx)
@@ -491,7 +497,7 @@ class EnhancedDocPipeline:
             span = {
                 'text': text,
                 'bbox': [x1, y1, x2, y2],  # 或者转为 poly 格式
-                'score': 1.0,  # PDF 提取的置信度设为 1.0
+                'confidence': 1.0,  # PDF 提取的置信度设为 1.0
                 'source': 'pdf'  # 标记来源
             }
             
@@ -519,8 +525,8 @@ class EnhancedDocPipeline:
             import cv2
             import json
             
-            # 获取 PDF 文本
-            pdf_text_blocks = PDFUtils.extract_all_text_blocks(pdf_doc, page_idx, scale=scale)
+            # 获取 PDF 文本(包含rotation处理)
+            pdf_text_blocks, rotation = PDFUtils.extract_all_text_blocks(pdf_doc, page_idx, scale=scale)
             
             # 准备可视化图像
             vis_image = image.copy()
@@ -550,6 +556,7 @@ class EnhancedDocPipeline:
             # 保存对比 JSON
             comparison_data = {
                 'page_idx': page_idx,
+                'rotation': rotation,
                 'ocr_count': len(ocr_spans),
                 'pdf_text_count': len(pdf_text_blocks),
                 'ocr_spans': [{'text': s['text'], 'bbox': s.get('bbox')} for s in ocr_spans],
@@ -564,7 +571,8 @@ class EnhancedDocPipeline:
             
         except Exception as e:
             logger.warning(f"⚠️ Debug comparison failed: {e}")
-    
+
+
     # ==================== OCR Spans 排序 ====================
     
     @staticmethod
@@ -759,7 +767,7 @@ class EnhancedDocPipeline:
                 processed_elements.append(ElementProcessors.create_error_element(item, str(e)))
         
         # 处理表格主体
-        for item in classified_elements['table_body']:
+        for idx, item in enumerate(classified_elements['table_body']):
             try:
                 spans = get_matched_spans_for_item(item)
                 
@@ -771,7 +779,7 @@ class EnhancedDocPipeline:
                     logger.info(f"🔷 Using wired UNet table recognition (configured)")
                     element = self.element_processors.process_table_element_wired(
                         detection_image, item, scale, pre_matched_spans=spans,
-                        output_dir=output_dir, basename=basename
+                        output_dir=output_dir, basename=f"{basename}_{idx}"
                     )
                     
                     # 如果有线识别失败(返回空 HTML),fallback 到 VLM