Prechádzať zdrojové kódy

feat(新增模块级Debug可视化): 添加ocr_utils/module_debug_viz.py模块,提供布局和OCR调试图的绘制与保存功能,支持JSON输出,增强调试过程的可视化和审计能力。

zhch158_admin 5 dní pred
rodič
commit
9624e032a1
1 zmenil súbory, kde vykonal 213 pridanie a 0 odobranie
  1. 213 0
      ocr_utils/module_debug_viz.py

+ 213 - 0
ocr_utils/module_debug_viz.py

@@ -0,0 +1,213 @@
+"""
+模块级 Debug 可视化(Layout / OCR)
+
+用于 debug_comparison/ 下基于 inference_image 的调试图;
+用户审计图由 VisualizationUtils + original_image 负责,不在此模块。
+"""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import cv2
+import numpy as np
+from loguru import logger
+from PIL import Image
+
+LAYOUT_CATEGORY_COLORS_BGR = {
+    'table_body': (0, 0, 255),
+    'table_caption': (0, 0, 200),
+    'table_footnote': (0, 0, 150),
+    'text': (255, 0, 0),
+    'title': (0, 255, 255),
+    'header': (255, 0, 255),
+    'footer': (0, 165, 255),
+    'image_body': (0, 255, 0),
+    'image_caption': (0, 200, 0),
+    'image_footnote': (0, 150, 0),
+    'abandon': (128, 128, 128),
+}
+
+OCR_BOX_COLOR_BGR = (0, 255, 255)
+
+
+def _to_bgr(image: Union[np.ndarray, Image.Image]) -> np.ndarray:
+    if isinstance(image, Image.Image):
+        arr = np.array(image)
+    else:
+        arr = image.copy()
+    if arr.ndim == 2:
+        return cv2.cvtColor(arr, cv2.COLOR_GRAY2BGR)
+    if arr.shape[2] == 3:
+        return cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
+    return arr
+
+
+def draw_layout_boxes_cv2(
+    image: Union[np.ndarray, Image.Image],
+    layout_results: List[Dict[str, Any]],
+) -> np.ndarray:
+    """在 BGR 图像上绘制 layout 检测框,返回新图像。"""
+    vis = _to_bgr(image)
+    for result in layout_results:
+        bbox = result.get('bbox', [])
+        if not bbox or len(bbox) < 4:
+            continue
+        category = result.get('category', 'unknown')
+        color = LAYOUT_CATEGORY_COLORS_BGR.get(category, (128, 128, 128))
+        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+        cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
+        label = category
+        confidence = result.get('confidence', result.get('score', 0))
+        if confidence:
+            label += f":{float(confidence):.2f}"
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        font_scale = 0.4
+        text_thickness = 1
+        (text_width, text_height), baseline = cv2.getTextSize(
+            label, font, font_scale, text_thickness
+        )
+        text_y = max(y1 - baseline - 1, text_height + baseline)
+        cv2.rectangle(
+            vis,
+            (x1, text_y - text_height - baseline - 2),
+            (x1 + text_width, text_y),
+            color,
+            -1,
+        )
+        cv2.putText(
+            vis, label, (x1, text_y - baseline - 1),
+            font, font_scale, (255, 255, 255), text_thickness,
+        )
+    return vis
+
+
+def draw_ocr_spans_cv2(
+    image: Union[np.ndarray, Image.Image],
+    spans: List[Dict[str, Any]],
+    *,
+    max_label_chars: int = 12,
+) -> np.ndarray:
+    """在 BGR 图像上绘制 OCR span(poly 或 bbox)。"""
+    vis = _to_bgr(image)
+    for span in spans:
+        poly = span.get('poly')
+        bbox = span.get('bbox', [])
+        pts = None
+        if poly and len(poly) >= 4:
+            pts = np.array(poly, dtype=np.int32).reshape(-1, 2)
+        elif bbox and len(bbox) >= 4:
+            x0, y0, x1, y1 = map(int, bbox[:4])
+            pts = np.array(
+                [[x0, y0], [x1, y0], [x1, y1], [x0, y1]], dtype=np.int32
+            )
+        if pts is not None:
+            cv2.polylines(vis, [pts], True, OCR_BOX_COLOR_BGR, 1)
+        text = str(span.get('text', ''))[:max_label_chars]
+        if text and pts is not None:
+            x, y = int(pts[0][0]), int(pts[0][1])
+            cv2.putText(
+                vis, text, (x, max(y - 2, 10)),
+                cv2.FONT_HERSHEY_SIMPLEX, 0.35, OCR_BOX_COLOR_BGR, 1,
+            )
+    return vis
+
+
+def save_layout_debug(
+    image: Union[np.ndarray, Image.Image],
+    layout_results: List[Dict[str, Any]],
+    output_dir: Union[str, Path],
+    page_name: str,
+    *,
+    suffix: str = 'raw',
+    subdir: str = 'layout_detection',
+    image_format: str = 'jpg',
+    save_json: bool = True,
+) -> Optional[Dict[str, str]]:
+    """保存 layout 模块 debug 图与 JSON。"""
+    if not layout_results or not output_dir:
+        return None
+    try:
+        fmt = (image_format or 'jpg').lstrip('.')
+        debug_dir = Path(output_dir) / 'debug_comparison' / subdir
+        debug_dir.mkdir(parents=True, exist_ok=True)
+        vis = draw_layout_boxes_cv2(image, layout_results)
+        img_path = debug_dir / f'{page_name}_layout_{suffix}.{fmt}'
+        cv2.imwrite(str(img_path), vis)
+        paths: Dict[str, str] = {'image': str(img_path)}
+        logger.info(f"Saved layout detection image ({suffix}): {img_path}")
+        if save_json:
+            json_data = {
+                'page_name': page_name,
+                'suffix': suffix,
+                'count': len(layout_results),
+                'results': [
+                    {
+                        'category': r.get('category'),
+                        'bbox': r.get('bbox'),
+                        'confidence': r.get('confidence', r.get('score', 0.0)),
+                    }
+                    for r in layout_results
+                ],
+            }
+            json_path = debug_dir / f'{page_name}_layout_{suffix}.json'
+            json_path.write_text(
+                json.dumps(json_data, ensure_ascii=False, indent=2),
+                encoding='utf-8',
+            )
+            paths['json'] = str(json_path)
+            logger.info(f"Saved layout detection JSON ({suffix}): {json_path}")
+        return paths
+    except Exception as e:
+        logger.warning(f"Failed to save layout debug ({suffix}): {e}")
+        return None
+
+
+def save_ocr_debug(
+    image: Union[np.ndarray, Image.Image],
+    spans: List[Dict[str, Any]],
+    output_dir: Union[str, Path],
+    page_name: str,
+    *,
+    subdir: str = 'ocr_recognition',
+    image_format: str = 'png',
+    save_json: bool = True,
+) -> Optional[Dict[str, str]]:
+    """保存 OCR 模块 debug 图与 JSON。"""
+    if not output_dir:
+        return None
+    try:
+        fmt = (image_format or 'png').lstrip('.')
+        debug_dir = Path(output_dir) / 'debug_comparison' / subdir
+        debug_dir.mkdir(parents=True, exist_ok=True)
+        vis = draw_ocr_spans_cv2(image, spans or [])
+        img_path = debug_dir / f'{page_name}_ocr_spans.{fmt}'
+        cv2.imwrite(str(img_path), vis)
+        paths: Dict[str, str] = {'image': str(img_path)}
+        logger.info(f"Saved OCR debug image: {img_path}")
+        if save_json:
+            json_data = {
+                'page_name': page_name,
+                'count': len(spans or []),
+                'spans': [
+                    {
+                        'bbox': s.get('bbox'),
+                        'poly': s.get('poly'),
+                        'text': s.get('text'),
+                        'confidence': s.get('confidence'),
+                    }
+                    for s in (spans or [])
+                ],
+            }
+            json_path = debug_dir / f'{page_name}_ocr_spans.json'
+            json_path.write_text(
+                json.dumps(json_data, ensure_ascii=False, indent=2),
+                encoding='utf-8',
+            )
+            paths['json'] = str(json_path)
+            logger.info(f"Saved OCR debug JSON: {json_path}")
+        return paths
+    except Exception as e:
+        logger.warning(f"Failed to save OCR debug: {e}")
+        return None