Browse Source

feat(优化OCR框颜色管理): 在module_debug_viz.py中新增OCR框和印章框颜色获取函数,更新draw_ocr_spans_cv2函数以支持按类别动态着色,提升可视化效果;在output_formatter_v2.py中调整OCR框和单元格框颜色引用,确保一致性;在visualization_utils.py中完善颜色定义,增强模块间的可维护性。

zhch158_admin 1 month ago
parent
commit
0d0e8d028a
3 changed files with 109 additions and 101 deletions
  1. 27 5
      ocr_utils/module_debug_viz.py
  2. 2 2
      ocr_utils/output_formatter_v2.py
  3. 80 94
      ocr_utils/visualization_utils.py

+ 27 - 5
ocr_utils/module_debug_viz.py

@@ -54,13 +54,30 @@ LAYOUT_HIGHLIGHT_LINE_THICKNESS = 4
 LAYOUT_HIGHLIGHT_OUTLINE_BGR = (0, 0, 0)
 LAYOUT_DEFAULT_LINE_THICKNESS = 2
 
-# 亮蓝(BGR),在白底/浅灰流水上比黄色更易辨认;与 layout 红色框区分
-OCR_BOX_COLOR_BGR = (255, 0, 0)
+# OCR 框线宽 (不受配色统一影响)
 OCR_BOX_LINE_THICKNESS = 2
 OCR_BOX_DASH_LENGTH = 8
 OCR_BOX_DASH_GAP = 6
 
 
+def _ocr_box_color_bgr() -> tuple:
+    """亮蓝 OCR 框 (BGR),派生自 VisualizationUtils.COLOR_MAP['ocr_box']。"""
+    from ocr_utils.visualization_utils import VisualizationUtils
+    return VisualizationUtils.rgb_to_bgr(VisualizationUtils.COLOR_MAP['ocr_box'])
+
+
+def _seal_ocr_box_color_bgr() -> tuple:
+    """印章 OCR 框 (BGR),派生自 VisualizationUtils.COLOR_MAP['seal_ocr_box']。"""
+    from ocr_utils.visualization_utils import VisualizationUtils
+    return VisualizationUtils.rgb_to_bgr(VisualizationUtils.COLOR_MAP['seal_ocr_box'])
+
+
+def ocr_box_color_rgb() -> tuple:
+    """OCR 亮蓝 (RGB),供 PIL / Plotly 使用。"""
+    from ocr_utils.visualization_utils import VisualizationUtils
+    return VisualizationUtils.COLOR_MAP['ocr_box']
+
+
 def _to_bgr(image: Union[np.ndarray, Image.Image]) -> np.ndarray:
     if isinstance(image, Image.Image):
         arr = np.array(image)
@@ -194,7 +211,10 @@ def draw_ocr_spans_cv2(
     *,
     max_label_chars: int = 12,
 ) -> np.ndarray:
-    """在 BGR 图像上绘制 OCR span(poly 或 bbox);无文字用虚线框。"""
+    """在 BGR 图像上绘制 OCR span(poly 或 bbox);无文字用虚线框。
+    
+    span 可带 category='seal' 使用印章专用亮橙色,否则使用亮蓝。
+    """
     vis = _to_bgr(image)
     for span in spans:
         poly = span.get('poly')
@@ -209,19 +229,21 @@ def draw_ocr_spans_cv2(
             )
         if pts is not None:
             text_raw = str(span.get('text', '') or '').strip()
+            color = _seal_ocr_box_color_bgr() if span.get('category') == 'seal' else _ocr_box_color_bgr()
             _draw_span_outline(
                 vis,
                 pts,
-                OCR_BOX_COLOR_BGR,
+                color,
                 OCR_BOX_LINE_THICKNESS,
                 dashed=not text_raw,
             )
         text = str(span.get('text', '')).strip()[:max_label_chars]
         if text and pts is not None:
+            color = _seal_ocr_box_color_bgr() if span.get('category') == 'seal' else _ocr_box_color_bgr()
             x, y = int(pts[0][0]), int(pts[0][1])
             cv2.putText(
                 vis, text, (x, max(y - 2, 10)),
-                cv2.FONT_HERSHEY_SIMPLEX, 0.35, OCR_BOX_COLOR_BGR, 1, cv2.LINE_AA,
+                cv2.FONT_HERSHEY_SIMPLEX, 0.35, color, 1, cv2.LINE_AA,
             )
     return vis
 

+ 2 - 2
ocr_utils/output_formatter_v2.py

@@ -59,8 +59,8 @@ class OutputFormatterV2:
     
     # 颜色映射(导出供其他模块使用)
     COLOR_MAP = VisualizationUtils.COLOR_MAP
-    OCR_BOX_COLOR = VisualizationUtils.OCR_BOX_COLOR
-    CELL_BOX_COLOR = VisualizationUtils.CELL_BOX_COLOR
+    OCR_BOX_COLOR = VisualizationUtils.COLOR_MAP['ocr_box']
+    CELL_BOX_COLOR = VisualizationUtils.COLOR_MAP['cell_box']
     
     def __init__(self, output_dir: str):
         """

+ 80 - 94
ocr_utils/visualization_utils.py

@@ -71,9 +71,25 @@ class VisualizationUtils:
         
         # 错误
         'error': (255, 0, 0),               # 红色
+        
+        # --- 通用工具颜色(非元素类别,供 module_debug_viz / ocr_validator 引用) ---
+        
+        # OCR 文字框:亮蓝(白底/浅灰上比黄/红色易辨认)
+        'ocr_box': (0, 0, 255),
+        # 印章 OCR 框:亮橙(独立管线,与 layout seal 颜色一致,审计时区分)
+        'seal_ocr_box': (255, 140, 0),
+        # 表格单元格框:与 ocr_box 同色
+        'cell_box': (0, 0, 255),
+        # 丢弃/废弃元素框
+        'discard': (128, 128, 128),
     }
     
-    # OCR 框颜色(与 module_debug_viz.OCR_BOX_COLOR_BGR 一致:亮蓝 BGR→RGB)
+    @staticmethod
+    def rgb_to_bgr(rgb: tuple) -> tuple:
+        """RGB → BGR(供 OpenCV 模块使用)。"""
+        return tuple(rgb[i] for i in (2, 1, 0)) if len(rgb) >= 3 else rgb
+    
+    # --- 向后兼容别名(推荐使用 COLOR_MAP['ocr_box'] 等) ---
     OCR_BOX_COLOR = (0, 0, 255)
     CELL_BOX_COLOR = (0, 0, 255)
     DISCARD_COLOR = (128, 128, 128)  # 灰色
@@ -242,18 +258,18 @@ class VisualizationUtils:
                 # 半透明填充
                 overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
                 overlay_draw = ImageDraw.Draw(overlay)
-                overlay_draw.rectangle([x0, y0, x1, y1], fill=(*VisualizationUtils.DISCARD_COLOR, 30))
+                overlay_draw.rectangle([x0, y0, x1, y1], fill=(*VisualizationUtils.COLOR_MAP['discard'], 30))
                 image = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
                 draw = ImageDraw.Draw(image)
                 
                 # 灰色边框
-                draw.rectangle([x0, y0, x1, y1], outline=VisualizationUtils.DISCARD_COLOR, width=1)
+                draw.rectangle([x0, y0, x1, y1], outline=VisualizationUtils.COLOR_MAP['discard'], width=1)
                 
                 # 类型标签
                 if draw_type_label:
                     label = f"D:{original_category}"
                     bbox_label = draw.textbbox((x0 + 2, y0 + 2), label, font=font)
-                    draw.rectangle(bbox_label, fill=VisualizationUtils.DISCARD_COLOR)
+                    draw.rectangle(bbox_label, fill=VisualizationUtils.COLOR_MAP['discard'])
                     draw.text((x0 + 2, y0 + 2), label, fill='white', font=font)
             
             # 根据输入类型决定命名
@@ -276,117 +292,87 @@ class VisualizationUtils:
         is_pdf: bool = True
     ) -> List[str]:
         """
-        保存 OCR 可视化图片
-        
+        保存 OCR 可视化图片(与 *_page_001.json 同源同构)。
+
+        数据源为 JSONFormatters._element_to_cell_bbox_format 转换后的扁平格式
+        (与 save_page_jsons 输出的 JSON 一致);
+        绘制样式与 debug/ocr_recognition 一致:亮蓝实线=有文字,虚线=仅框无字。
+
         命名规则:
         - PDF输入: 文件名_page_001_ocr.png
         - 图片输入(单页): 文件名_ocr.png
-        
-        Args:
-            results: 处理结果
-            output_dir: 输出目录
-            doc_name: 文档名称
-            is_pdf: 是否为 PDF 输入
-            
-        Returns:
-            保存的图片路径列表
         """
+        from ocr_utils.json_formatters import JSONFormatters
+        from ocr_utils.module_debug_viz import draw_ocr_spans_cv2
+
         ocr_paths = []
         total_pages = len(results.get('pages', []))
-        
+
         for page in results.get('pages', []):
             page_idx = page.get('page_idx', 0)
             processed_image = page.get('original_image')
             if processed_image is None:
                 processed_image = page.get('processed_image')
-            
+
             if processed_image is None:
                 logger.warning(f"Page {page_idx}: No image data found for OCR visualization")
                 continue
-            
-            if isinstance(processed_image, np.ndarray):
-                image = Image.fromarray(processed_image).convert('RGB')
-            elif isinstance(processed_image, Image.Image):
-                image = processed_image.convert('RGB')
-            else:
-                continue
-            
-            draw = ImageDraw.Draw(image)
-            font = VisualizationUtils._get_font(10)
-            
-            for element in page.get('elements', []):
-                content = element.get('content', {})
-                
-                # OCR 文本框
-                ocr_details = content.get('ocr_details', [])
-                for ocr_item in ocr_details:
-                    ocr_bbox = ocr_item.get('bbox', [])
-                    if ocr_bbox:
-                        VisualizationUtils._draw_polygon(
-                            draw, ocr_bbox, VisualizationUtils.OCR_BOX_COLOR, width=1
-                        )
-                
-                # 表格单元格
-                cells = content.get('cells', [])
-                for cell in cells:
-                    cell_bbox = cell.get('bbox', [])
-                    if cell_bbox and len(cell_bbox) >= 4:
-                        x0, y0, x1, y1 = map(int, cell_bbox[:4])
-                        draw.rectangle(
-                            [x0, y0, x1, y1], 
-                            outline=VisualizationUtils.CELL_BOX_COLOR, 
-                            width=2
-                        )
-                        
-                        cell_text = cell.get('text', '')[:10]
-                        if cell_text:
-                            draw.text(
-                                (x0 + 2, y0 + 2), 
-                                cell_text, 
-                                fill=VisualizationUtils.CELL_BOX_COLOR, 
-                                font=font
-                            )
-                
-                # OCR 框
-                ocr_boxes = content.get('ocr_boxes', [])
-                for ocr_box in ocr_boxes:
-                    bbox = ocr_box.get('bbox', [])
-                    if bbox:
-                        VisualizationUtils._draw_polygon(
-                            draw, bbox, VisualizationUtils.OCR_BOX_COLOR, width=1
-                        )
-            
-            # 绘制丢弃元素的 OCR 框
-            for element in page.get('discarded_blocks', []):
-                bbox = element.get('bbox', [0, 0, 0, 0])
-                content = element.get('content', {})
-                
-                if len(bbox) >= 4:
-                    x0, y0, x1, y1 = map(int, bbox[:4])
-                    draw.rectangle(
-                        [x0, y0, x1, y1], 
-                        outline=VisualizationUtils.DISCARD_COLOR, 
-                        width=1
-                    )
-                    
-                    ocr_details = content.get('ocr_details', [])
-                    for ocr_item in ocr_details:
-                        ocr_bbox = ocr_item.get('bbox', [])
-                        if ocr_bbox:
-                            VisualizationUtils._draw_polygon(
-                                draw, ocr_bbox, VisualizationUtils.DISCARD_COLOR, width=1
-                            )
-            
-            # 根据输入类型决定命名
+
+            page_rotation_angle = float(page.get('angle', 0))
+
+            flat_elements = []
+            for element in (page.get('elements') or []):
+                converted = JSONFormatters._element_to_cell_bbox_format(
+                    element, page_idx, page_rotation_angle
+                )
+                if converted:
+                    flat_elements.append(converted)
+            for element in (page.get('discarded_blocks') or []):
+                converted = JSONFormatters._element_to_cell_bbox_format(
+                    element, page_idx, page_rotation_angle
+                )
+                if converted:
+                    flat_elements.append(converted)
+
+            spans = []
+            for elem in flat_elements:
+                bbox = elem.get('bbox', [])
+                if not bbox or len(bbox) < 4:
+                    continue
+                elem_type = elem.get('type', '')
+                if 'table_cells' in elem:
+                    for cell in elem['table_cells']:
+                        cell_bbox = cell.get('bbox', [])
+                        if cell_bbox and len(cell_bbox) >= 4:
+                            spans.append({
+                                'bbox': cell_bbox[:4],
+                                'text': cell.get('text', '').strip(),
+                            })
+                elif elem.get('text') is not None:
+                    spans.append({
+                        'bbox': bbox[:4],
+                        'text': str(elem.get('text', '')).strip(),
+                        'category': 'seal' if elem_type == 'seal' else None,
+                    })
+                else:
+                    spans.append({
+                        'bbox': bbox[:4],
+                        'text': '',
+                    })
+
+            vis_bgr = draw_ocr_spans_cv2(processed_image, spans)
+            vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB)
+            image = Image.fromarray(vis_rgb)
+
             if is_pdf or total_pages > 1:
                 ocr_path = output_dir / f"{doc_name}_page_{page_idx + 1:03d}_ocr.png"
             else:
                 ocr_path = output_dir / f"{doc_name}_ocr.png"
-            
+
             image.save(ocr_path)
             ocr_paths.append(str(ocr_path))
             logger.info(f"🖼️ OCR image saved: {ocr_path}")
-        
+
         return ocr_paths
     
     @staticmethod