Explorar el Código

feat: 改进文档处理流程,支持从 PDF 提取文本并与 OCR 结果对比,添加调试模式

zhch158_admin hace 4 días
padre
commit
dd92babb27
Se han modificado 1 ficheros con 165 adiciones y 14 borrados
  1. 165 14
      ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

+ 165 - 14
ocr_tools/universal_doc_parser/core/pipeline_manager_v2.py

@@ -217,7 +217,7 @@ class EnhancedDocPipeline:
         try:
             # 1. 加载文档并分类
             dpi = self.config.get('input', {}).get('dpi', 200)
-            images, pdf_type, pdf_doc = PDFUtils.load_and_classify_document(
+            images, pdf_type, pdf_doc, renderer_used = PDFUtils.load_and_classify_document(
                 doc_path, dpi=dpi, page_range=page_range
             )
             results['metadata']['pdf_type'] = pdf_type
@@ -359,21 +359,58 @@ class EnhancedDocPipeline:
         page_result['layout_raw'] = layout_results
         
         # 3. 整页 OCR 获取所有 text spans(关键改进)
-        all_ocr_spans = []
-        try:
-            all_ocr_spans = self.ocr_recognizer.recognize_text(detection_image)
-            # 去除重复 spans
-            all_ocr_spans = SpanMatcher.remove_duplicate_spans(all_ocr_spans)
-            # 按坐标排序(从上到下,从左到右),方便人工检查缺失字符
-            all_ocr_spans = self._sort_spans_by_position(all_ocr_spans)
-            logger.info(f"📝 Page {page_idx}: OCR detected {len(all_ocr_spans)} text spans")
-        except Exception as e:
-            logger.warning(f"⚠️ Full-page OCR failed: {e}")
+        all_text_spans = []
+        should_run_ocr = True
+        text_source = 'ocr'
+        
+        if pdf_type == 'txt' and pdf_doc is not None:
+            # 文字 PDF:直接从 PDF 提取文本块
+            try:
+                pdf_text_blocks = PDFUtils.extract_all_text_blocks(
+                    pdf_doc, page_idx, scale=scale
+                )
+                # 将 PDF 文本块转换为 span 格式
+                all_text_spans = self._convert_pdf_blocks_to_spans(
+                    pdf_text_blocks, detection_image.shape
+                )
+                text_source = 'pdf'
+                logger.info(f"📝 Page {page_idx}: PDF extracted {len(all_text_spans)} text blocks")
+            except Exception as e:
+                logger.warning(f"⚠️ PDF text extraction failed, fallback to OCR: {e}")
+                pdf_type = 'ocr'  # Fallback to OCR
+
+        # OCR PDF 或 PDF 提取失败时使用 OCR
+        elif pdf_type == 'ocr':
+            should_run_ocr = True
+            if self.debug_mode and text_source == 'pdf':
+                # 调试模式:同时运行 OCR 对比
+                should_run_ocr = True
+                
+            if should_run_ocr:
+                try:
+                    all_text_spans = self.ocr_recognizer.recognize_text(detection_image)
+                    all_text_spans = SpanMatcher.remove_duplicate_spans(all_text_spans)
+                    all_text_spans = self._sort_spans_by_position(all_text_spans)
+                    text_source = 'ocr'
+                    logger.info(f"📝 Page {page_idx}: OCR detected {len(all_text_spans)} text spans")
+                except Exception as e:
+                    logger.warning(f"⚠️ Full-page OCR failed: {e}")                
+                # 3.1 调试模式:对比 OCR 和 PDF 提取结果
+                if self.debug_mode and pdf_type == 'txt' and pdf_doc is not None:
+                    self._compare_ocr_and_pdf_text(
+                        page_idx, pdf_doc, all_ocr_spans, detection_image, output_dir, page_name, scale
+                    )
+        else:
+            raise ValueError(f"Unknown pdf_type: {pdf_type}")
         
         # 4. 将 OCR spans 匹配到 layout blocks
-        matched_spans = SpanMatcher.match_spans_to_blocks(
-            all_ocr_spans, layout_results, overlap_threshold=0.5
-        )
+        matched_spans = {}
+        if all_text_spans:
+            matched_spans = SpanMatcher.match_spans_to_blocks(
+                all_text_spans, layout_results, overlap_threshold=0.5
+            )
+        # 记录文本来源
+        page_result['text_source'] = text_source  # 'ocr' 或 'pdf'
         
         # 5. 分类元素
         classified_elements = self._classify_elements(layout_results, page_idx)
@@ -413,6 +450,120 @@ class EnhancedDocPipeline:
         page_result['elements'] = sorted_elements
         page_result['discarded_blocks'] = sorted_discarded
         return page_result
+
+    @staticmethod
+    def _convert_pdf_blocks_to_spans(
+        pdf_text_blocks: List[Dict[str, Any]],
+        image_shape: tuple
+    ) -> List[Dict[str, Any]]:
+        """
+        将 PDF 文本块转换为 OCR span 格式
+        
+        Args:
+            pdf_text_blocks: PDF 提取的文本块 [{'text': str, 'bbox': [x1,y1,x2,y2]}, ...]
+            image_shape: 图像尺寸 (height, width, channels)
+            
+        Returns:
+            OCR span 格式的列表
+        """
+        spans = []
+        
+        for block in pdf_text_blocks:
+            text = block.get('text', '').strip()
+            bbox = block.get('bbox')
+            
+            if not text or not bbox or len(bbox) < 4:
+                continue
+            
+            # 确保 bbox 在图像范围内
+            x1, y1, x2, y2 = bbox
+            h, w = image_shape[:2]
+            
+            x1 = max(0, min(x1, w))
+            y1 = max(0, min(y1, h))
+            x2 = max(0, min(x2, w))
+            y2 = max(0, min(y2, h))
+            
+            if x2 <= x1 or y2 <= y1:
+                continue
+            
+            # 转换为 OCR span 格式
+            span = {
+                'text': text,
+                'bbox': [x1, y1, x2, y2],  # 或者转为 poly 格式
+                'score': 1.0,  # PDF 提取的置信度设为 1.0
+                'source': 'pdf'  # 标记来源
+            }
+            
+            spans.append(span)
+        
+        return spans
+        
+    def _compare_ocr_and_pdf_text(
+        self, 
+        page_idx: int, 
+        pdf_doc: Any, 
+        ocr_spans: List[Dict[str, Any]], 
+        image: np.ndarray,
+        output_dir: Optional[str],
+        page_name: str,
+        scale: float
+    ):
+        """
+        对比 OCR 和 PDF 提取结果,并输出调试信息
+        """
+        if not output_dir:
+            return
+
+        try:
+            import cv2
+            import json
+            
+            # 获取 PDF 文本
+            pdf_text_blocks = PDFUtils.extract_all_text_blocks(pdf_doc, page_idx, scale=scale)
+            
+            # 准备可视化图像
+            vis_image = image.copy()
+            
+            # 绘制 PDF 文本框 (蓝色)
+            for block in pdf_text_blocks:
+                bbox = [int(x) for x in block['bbox']]
+                cv2.rectangle(vis_image, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 0, 0), 2)
+                
+            # 绘制 OCR 文本框 (红色)
+            for span in ocr_spans:
+                bbox = span.get('bbox')
+                if bbox:
+                    if isinstance(bbox[0], list): # poly
+                        pts = np.array(bbox, np.int32)
+                        pts = pts.reshape((-1, 1, 2))
+                        cv2.polylines(vis_image, [pts], True, (0, 0, 255), 2)
+                    else: # bbox
+                        cv2.rectangle(vis_image, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), (0, 0, 255), 2)
+            
+            # 保存对比图像
+            debug_dir = Path(output_dir) / "debug_comparison"
+            debug_dir.mkdir(parents=True, exist_ok=True)
+            output_path = debug_dir / f"{page_name}_comparison.jpg"
+            cv2.imwrite(str(output_path), vis_image)
+            
+            # 保存对比 JSON
+            comparison_data = {
+                'page_idx': page_idx,
+                'ocr_count': len(ocr_spans),
+                'pdf_text_count': len(pdf_text_blocks),
+                'ocr_spans': [{'text': s['text'], 'bbox': s.get('bbox')} for s in ocr_spans],
+                'pdf_blocks': pdf_text_blocks
+            }
+            
+            json_path = debug_dir / f"{page_name}_comparison.json"
+            with open(json_path, 'w', encoding='utf-8') as f:
+                json.dump(comparison_data, f, ensure_ascii=False, indent=2)
+                
+            logger.info(f"📝 Saved debug comparison to {debug_dir}")
+            
+        except Exception as e:
+            logger.warning(f"⚠️ Debug comparison failed: {e}")
     
     # ==================== OCR Spans 排序 ====================