#!/usr/bin/env python3 """ 方向1:对比度增强参数网格扫描。 不去水印,直接对原图做多种对比度增强,验证哪种参数组合能让水印 在视觉上"淡化"、正文保持清晰,从而使后续 OCR 不受水印干扰。 用法: cd ocr_platform/ocr_tools/watermark_lab # 单张图快速扫描 python contrast_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --quick # 全量扫描(更多参数组合 + 生成增强对比图) python contrast_sweep.py ../gan_experiments_lab/test_images/input/ # 同时跑 OCR 整页对比(det+rec,每方法 Top-1 组合) python contrast_sweep.py input.png --ocr --model-dir /path/to/models # 每方法 Top-3 组合跑 OCR 对比 python contrast_sweep.py input.png --ocr --ocr-top-n 3 输出: output// ├── sweep_report.json # 参数扫描结果汇总(含 OCR 对比结果) ├── sweep_summary.csv # CSV 表格 ├── quad_compare.png # 四宫格对比图 ├── text_restore_t60_bg248.png # 各组合增强结果图 ├── clahe_cl3.0_t8.png ├── gamma_g0.5.png └── ocr/ # OCR 对比结果(--ocr 时生成) ├── _original_ocr_spans.png # 原始图 OCR 可视化 ├── _original_ocr_spans.json # 原始图 OCR JSON ├── __ocr_spans.png # 各增强组合 OCR 可视化 ├── __ocr_spans.json # 各增强组合 OCR JSON └── ocr_comparison.json # OCR 差异汇总报告 """ from __future__ import annotations import argparse import json import sys import time from itertools import product from pathlib import Path from typing import Any, Dict, List, Optional, Sequence, Tuple import cv2 import numpy as np _repo_root = Path(__file__).resolve().parents[3] if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) from loguru import logger from ocr_utils.watermark.contrast import enhance_document_contrast _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"} # ── 参数网格 ──────────────────────────────────────────────────── def _build_param_grid(quick: bool = False) -> List[Dict[str, Any]]: """构建参数网格。 四个维度: 1. method: text_restore | clahe | gamma | linear 2. text_restore 专属: text_black_target + background_threshold 3. clahe 专属: clip_limit + tile_grid_size 4. gamma 专属: gamma """ grid: List[Dict[str, Any]] = [] # ── text_restore ── if quick: tbt = [40, 60, 85] bts = [240, 248] else: tbt = [40, 60, 80, 100, 120] bts = [235, 240, 248, 252] for target, bg_th in product(tbt, bts): grid.append({ "method": "text_restore", "text_black_target": target, "background_threshold": bg_th, }) # ── clahe ── if quick: cl = [1.0, 3.0, 5.0] ts = [8, 16] else: cl = [0.5, 1.0, 2.0, 3.0, 5.0, 8.0] ts = [4, 8, 16, 32] for clip, tile in product(cl, ts): grid.append({ "method": "clahe", "clip_limit": clip, "tile_grid_size": tile, }) # ── gamma ── if quick: gvs = [0.4, 0.55, 0.7, 0.85] else: gvs = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] for g in gvs: grid.append({"method": "gamma", "gamma": g}) # ── linear ── if quick: bps = [2.0, 5.0] wps = [95.0, 98.0] else: bps = [1.0, 2.0, 5.0, 8.0] wps = [92.0, 95.0, 98.0] for bp, wp in product(bps, wps): grid.append({"method": "linear", "black_percentile": bp, "white_percentile": wp}) return grid # ── 标签生成 ──────────────────────────────────────────────────── def _tag_from_cfg(cfg: Dict[str, Any]) -> str: m = cfg["method"] if m == "text_restore": return f"{m}_t{cfg['text_black_target']}_bg{cfg['background_threshold']}" if m == "clahe": return f"{m}_cl{cfg['clip_limit']}_t{cfg['tile_grid_size']}" if m == "gamma": return f"{m}_g{cfg['gamma']}" if m == "linear": return f"{m}_b{cfg['black_percentile']}_w{cfg['white_percentile']}" return m # ── 工具函数 ──────────────────────────────────────────────────── def _collect_images(path: Path) -> List[Path]: if path.is_file(): if path.suffix.lower() not in _IMAGE_SUFFIXES: raise ValueError(f"不支持的图像格式: {path}") return [path] if not path.is_dir(): raise FileNotFoundError(path) return sorted( p for p in path.iterdir() if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES ) def _compute_watermark_fade_score( original: np.ndarray, enhanced: np.ndarray, window: int = 31 ) -> float: """ 量化水印淡化程度。 原理:大核中值滤波估计背景,残差的方差越小 = 水印纹理越弱。 """ o_f = original.astype(np.float32) e_f = enhanced.astype(np.float32) k = max(3, window) | 1 o_bg = cv2.medianBlur(o_f.astype(np.uint8), k).astype(np.float32) e_bg = cv2.medianBlur(e_f.astype(np.uint8), k).astype(np.float32) o_res = cv2.absdiff(o_f, o_bg) e_res = cv2.absdiff(e_f, e_bg) return float(1.0 - np.var(e_res) / max(np.var(o_res), 1.0)) def _compute_text_sharpness_score( enhanced: np.ndarray, win: int = 3 ) -> float: """局部标准差均值,越大 = 文字越清晰。""" e_f = enhanced.astype(np.float32) kernel = np.ones((win, win), np.float32) / (win * win) mean = cv2.filter2D(e_f, -1, kernel) sq_mean = cv2.filter2D(e_f * e_f, -1, kernel) var = np.maximum(sq_mean - mean * mean, 0) return float(np.sqrt(var).mean()) # ── 对比图生成 ────────────────────────────────────────────────── def _make_quad_compare( original: np.ndarray, top_enhanced: List[Tuple[str, np.ndarray]], ) -> np.ndarray: """生成四宫格对比图:原图 | 最佳 text_restore | 最佳 clahe | 最佳 gamma。""" panels = [original] labels = ["Original"] for label, img in top_enhanced: panels.append(img) labels.append(label) # 全部转 BGR bgr_panels: List[np.ndarray] = [] for p in panels: if p.ndim == 2: bgr_panels.append(cv2.cvtColor(p, cv2.COLOR_GRAY2BGR)) else: bgr_panels.append(p) # 统一高度 h = max(p.shape[0] for p in bgr_panels) w = max(p.shape[1] for p in bgr_panels) resized: List[np.ndarray] = [] for p, label in zip(bgr_panels, labels): if p.shape[0] != h or p.shape[1] != w: p = cv2.resize(p, (w, h)) bar = np.ones((40, w, 3), dtype=np.uint8) * 240 cv2.putText(bar, label, (12, 28), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 2) resized.append(np.vstack([p, bar])) return np.hstack(resized) # ── OCR(整页对比)─────────────────────────────────────────────── def _poly_to_bbox(poly: List[List[float]]) -> List[int]: """四点 polygon 转轴对齐 bbox [x0,y0,x1,y1].""" xs = [p[0] for p in poly] ys = [p[1] for p in poly] return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))] def _ocr_full_page(engine: Any, img: np.ndarray) -> List[Dict[str, Any]]: """整页 OCR(det+rec),返回 spans 列表。 每个 span: {"poly": [[x,y],...], "bbox": [x0,y0,x1,y1], "text": "...", "confidence": 0.9} """ res = engine.ocr(img, det=True, rec=True) items = res[0] if res and res[0] is not None else [] spans: List[Dict[str, Any]] = [] for item in items: if not item or len(item) < 2: continue box, rec_part = item[0], item[1] text = str(rec_part[0] if isinstance(rec_part, (list, tuple)) else rec_part) conf = float(rec_part[1]) if isinstance(rec_part, (list, tuple)) and len(rec_part) > 1 else 0.0 poly = [[float(p[0]), float(p[1])] for p in box] if box else [] spans.append({ "poly": poly, "bbox": _poly_to_bbox(poly) if poly else [], "text": text.strip(), "confidence": round(conf, 4), }) from ocr_tools.universal_doc_parser.core.layout_utils import SpanMatcher spans = SpanMatcher.remove_duplicate_spans(spans) return spans def _save_ocr_debug_for_sweep( image: np.ndarray, spans: List[Dict[str, Any]], out_dir: Path, tag: str, ) -> Dict[str, str]: """保存 OCR 可视化图和 JSON。复用 module_debug_viz.draw_ocr_spans_cv2。""" from ocr_utils.module_debug_viz import draw_ocr_spans_cv2 ocr_dir = out_dir / "ocr" ocr_dir.mkdir(parents=True, exist_ok=True) # 可视化图 vis = draw_ocr_spans_cv2(image, spans) img_path = ocr_dir / f"{tag}_ocr_spans.png" cv2.imwrite(str(img_path), vis) # JSON json_data = { "tag": tag, "count": len(spans), "spans": [ { "bbox": s.get("bbox"), "poly": s.get("poly"), "text": s.get("text"), "confidence": s.get("confidence"), } for s in spans ], } json_path = ocr_dir / f"{tag}_ocr_spans.json" json_path.write_text(json.dumps(json_data, ensure_ascii=False, indent=2), encoding="utf-8") logger.info(f" OCR debug: {img_path}") return {"image": str(img_path), "json": str(json_path)} def _compare_ocr_results( orig_spans: List[Dict[str, Any]], enh_spans: List[Dict[str, Any]], iou_threshold: float = 0.5, ) -> Dict[str, Any]: """对比两组 OCR spans 的检测+识别差异。 Returns: { "detection": { orig_count, enh_count, matched, new, missing }, "recognition": { text_changed_count, char_diff_rate, details: [...] }, "summary": "一句话摘要" } """ def _bbox_iou(a: List[int], b: List[int]) -> float: if not a or not b: return 0.0 xa = max(a[0], b[0]) ya = max(a[1], b[1]) xb = min(a[2], b[2]) yb = min(a[3], b[3]) inter = max(0, xb - xa) * max(0, yb - ya) area_a = max(0, a[2] - a[0]) * max(0, a[3] - a[1]) area_b = max(0, b[2] - b[0]) * max(0, b[3] - b[1]) union = area_a + area_b - inter return inter / union if union > 0 else 0.0 # ── 检测对比 ── orig_boxes = [s.get("bbox", []) for s in orig_spans] enh_boxes = [s.get("bbox", []) for s in enh_spans] matched_orig_idxs: set = set() matched_enh_idxs: set = set() recognition_details: List[Dict[str, Any]] = [] for i, ob in enumerate(orig_boxes): if not ob: continue best_j, best_iou = -1, 0.0 for j, eb in enumerate(enh_boxes): if j in matched_enh_idxs or not eb: continue iou = _bbox_iou(ob, eb) if iou > best_iou: best_iou, best_j = iou, j if best_iou >= iou_threshold: matched_orig_idxs.add(i) matched_enh_idxs.add(best_j) orig_text = orig_spans[i].get("text", "") enh_text = enh_spans[best_j].get("text", "") orig_score = orig_spans[i].get("confidence", 0) enh_score = enh_spans[best_j].get("confidence", 0) rec_detail: Dict[str, Any] = { "orig_bbox": ob, "orig_text": orig_text, "orig_score": orig_score, "enh_text": enh_text, "enh_score": enh_score, "iou": round(best_iou, 4), } if orig_text != enh_text: rec_detail["text_changed"] = True else: rec_detail["text_changed"] = False recognition_details.append(rec_detail) new_boxes = len(enh_boxes) - len(matched_enh_idxs) missing_boxes = len(orig_boxes) - len(matched_orig_idxs) # 字符差异率 orig_concat = "".join(s.get("text", "") for s in orig_spans) enh_concat = "".join(s.get("text", "") for s in enh_spans) total_chars = max(len(orig_concat), len(enh_concat), 1) char_diff = sum(1 for a, b in zip(orig_concat, enh_concat) if a != b) + abs( len(orig_concat) - len(enh_concat) ) char_diff_rate = round(char_diff / total_chars, 4) detection = { "orig_count": len(orig_boxes), "enh_count": len(enh_boxes), "matched": len(matched_orig_idxs), "new": new_boxes, "missing": missing_boxes, } recognition = { "text_changed_count": len(recognition_details), "char_diff_rate": char_diff_rate, "details": recognition_details[:50], # 最多保存50条差异明细 } summary = ( f"检测: {detection['orig_count']}→{detection['enh_count']} (匹配{detection['matched']}, " f"新增{detection['new']}, 遗失{detection['missing']}); " f"识别: 文字变化{recognition['text_changed_count']}处, 字符差异率{char_diff_rate:.2%}" ) return {"detection": detection, "recognition": recognition, "summary": summary} def _load_paddle_engine(model_dir: Path, det_thresh: float = 0.3): from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR det = model_dir / "ch_PP-OCRv5_det_infer.pth" rec = model_dir / "ch_PP-OCRv4_rec_server_doc_infer.pth" return PytorchPaddleOCR( lang="ch", det_model_path=str(det) if det.exists() else None, rec_model_path=str(rec) if rec.exists() else None, det_db_box_thresh=det_thresh, ) # ── 扫描核心 ──────────────────────────────────────────────────── def run_sweep( input_path: Path, out_dir: Path, *, quick: bool = False, save_images: bool = True, ocr_enabled: bool = False, model_dir: Optional[Path] = None, ocr_top_n: int = 1, ocr_all: bool = False, ) -> Dict[str, Any]: bgr = cv2.imread(str(input_path)) if bgr is None: raise RuntimeError(f"无法读取图像: {input_path}") gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) stem = input_path.stem img_out = out_dir / stem img_out.mkdir(parents=True, exist_ok=True) grid = _build_param_grid(quick=quick) logger.info(f" {stem}: {len(grid)} 组参数组合") engine = None baseline_spans: List[Dict[str, Any]] = [] if ocr_enabled: try: md = model_dir or Path( "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/" "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch" ) engine = _load_paddle_engine(md) logger.info(" OCR 引擎已加载") # 基线 OCR(原始灰度图) baseline_spans = _ocr_full_page(engine, gray) logger.info(f" 基线 OCR: {len(baseline_spans)} 个文本块") _save_ocr_debug_for_sweep(bgr, baseline_spans, img_out, f"{stem}_original") except Exception as e: logger.warning(f" OCR 引擎加载失败: {e}") results: List[Dict[str, Any]] = [] # 按 method 分组, 便于后面取各类别最优 method_groups: Dict[str, List[Dict[str, Any]]] = {} for cfg in grid: tag = _tag_from_cfg(cfg) t0 = time.perf_counter() try: enhanced = enhance_document_contrast(gray, **cfg) except Exception as e: logger.warning(f" [{tag}] 增强失败: {e}") continue elapsed = time.perf_counter() - t0 fade = _compute_watermark_fade_score(gray, enhanced) sharpness = _compute_text_sharpness_score(enhanced) # 综合分:fade(水印淡化) 和 sharpness(文字清晰度) 同等权重 combined = round(fade * 0.5 + sharpness / max(sharpness, 0.01) * 0.5, 4) row: Dict[str, Any] = { "tag": tag, **cfg, "fade_score": round(fade, 6), "sharpness_score": round(sharpness, 4), "combined_score": round(combined, 4), "time_ms": round(elapsed * 1000, 1), } if save_images: out_path = img_out / f"{tag}.png" cv2.imwrite(str(out_path), enhanced) row["image_path"] = str(out_path) results.append(row) method = cfg["method"] method_groups.setdefault(method, []).append(row) # ── 排序 ── results.sort(key=lambda r: -r["combined_score"]) for mname, entries in method_groups.items(): entries.sort(key=lambda r: -r["combined_score"]) # Top 各方法最优 tops: List[Tuple[str, str, float]] = [] for mname, entries in method_groups.items(): if entries: top = entries[0] tops.append((mname, top["tag"], top["combined_score"])) logger.info(f" [{mname}] Top: {top['tag']} combined={top['combined_score']:.4f}") logger.info(f" 全局 Top1: {results[0]['tag']} combined={results[0]['combined_score']:.4f}") # ── 阶段二:OCR 对比(整页)───────────────────────────────── ocr_comparisons: List[Dict[str, Any]] = [] if engine and baseline_spans: # 选择要跑 OCR 的组合列表 if ocr_all: ocr_candidates = results else: ocr_candidates: List[Dict[str, Any]] = [] for mname, entries in method_groups.items(): for r in entries[:ocr_top_n]: if r not in ocr_candidates: ocr_candidates.append(r) logger.info(f" OCR 对比 {len(ocr_candidates)} 个组合(每方法 Top-{ocr_top_n})") for r in ocr_candidates: tag = r["tag"] enhanced_path = r.get("image_path") if enhanced_path: enhanced_bgr = cv2.imread(enhanced_path) if enhanced_bgr is None: logger.warning(f" [{tag}] 无法读取增强图") continue enhanced_gray = cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2GRAY) else: # 从 raws 重新生成 enhanced_gray = enhance_document_contrast(gray, **{ k: v for k, v in r.items() if k not in ("tag", "image_path", "fade_score", "sharpness_score", "combined_score", "time_ms") }) enhanced_bgr = cv2.cvtColor(enhanced_gray, cv2.COLOR_GRAY2BGR) try: enh_spans = _ocr_full_page(engine, enhanced_gray) except Exception as e: logger.warning(f" [{tag}] OCR 失败: {e}") continue _save_ocr_debug_for_sweep(enhanced_bgr, enh_spans, img_out, f"{stem}_{tag}") cmp = _compare_ocr_results(baseline_spans, enh_spans) cmp["tag"] = tag cmp["method"] = r["method"] r["ocr_comparison"] = cmp ocr_comparisons.append(cmp) for k, v in cmp["detection"].items(): r[f"ocr_det_{k}"] = v for k, v in cmp["recognition"].items(): if not isinstance(v, list): r[f"ocr_rec_{k}"] = v logger.info(f" [{tag}] {cmp['summary']}") # ── 四宫格对比图 ── if save_images and len(tops) >= 3: selected_labels = [] selected_imgs = [] seen_methods = set() for r in results: m = r["method"] if m in seen_methods: continue seen_methods.add(m) path = r.get("image_path") if path: img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) if img is not None: selected_labels.append(r["tag"]) selected_imgs.append(img) if len(selected_imgs) >= 3: break if selected_imgs: quad = _make_quad_compare(gray, list(zip(selected_labels, selected_imgs))) quad_path = img_out / "quad_compare.png" cv2.imwrite(str(quad_path), quad) logger.info(f" 四宫格对比图: {quad_path}") # ── 报告 ── report: Dict[str, Any] = { "input": str(input_path), "output_dir": str(img_out), "n_configs_tested": len(results), "top_overall": results[0] if results else None, "top_by_method": { m: e[0] for m, e in method_groups.items() if e }, } if engine: baseline_text = "".join(s.get("text", "") for s in baseline_spans) report["baseline_ocr"] = { "span_count": len(baseline_spans), "full_text": baseline_text, } report["ocr_comparisons"] = { "n_compared": len(ocr_comparisons), "results": ocr_comparisons, } report_path = img_out / "contrast_report.json" report_path.write_text( json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8" ) # 单独的 OCR 对比汇总报告(含完整检测+识别对比数据) if engine: ocr_summary_path = img_out / "ocr" / "ocr_comparison.json" ocr_summary_path.parent.mkdir(parents=True, exist_ok=True) ocr_summary = { "input": str(input_path), "baseline_spans": len(baseline_spans), "compared": ocr_comparisons, } ocr_summary_path.write_text( json.dumps(ocr_summary, ensure_ascii=False, indent=2), encoding="utf-8" ) # CSV if results: csv_keys = [k for k in results[0].keys() if not k.endswith("_path") and k != "ocr_comparison"] lines = [",".join(csv_keys)] for r in results: lines.append(",".join(str(r.get(k, "")) for k in csv_keys)) (img_out / "contrast_summary.csv").write_text("\n".join(lines), encoding="utf-8") logger.info(f" 报告: {report_path}") return report # ── CLI ────────────────────────────────────────────────────────── def _build_arg_parser() -> argparse.ArgumentParser: p = argparse.ArgumentParser( description="对比度增强参数网格扫描(不去水印,直接增强前后对比)", ) p.add_argument("input", type=Path, help="单张图片路径或图片目录") p.add_argument("-o", "--output", type=Path, default=None, help="输出根目录,默认 input 同级 contrast_out/") p.add_argument("--quick", action="store_true", help="缩小参数网格") p.add_argument("--no-save-images", action="store_true", help="不写出增强结果图") p.add_argument("--ocr", action="store_true", help="启用整页 OCR 对比(det+rec):基线 OCR + Top-N 增强图 OCR,输出 spans 可视化和 JSON") p.add_argument("--ocr-top-n", type=int, default=1, help="OCR 对比时每方法取 Top-N 组合(默认 1)") p.add_argument("--ocr-all", action="store_true", help="对所有参数组合跑 OCR 对比(覆盖 --ocr-top-n)") p.add_argument("--model-dir", type=Path, default=None, help="PaddleOCR 模型目录") return p def main(argv: Optional[Sequence[str]] = None) -> None: args = _build_arg_parser().parse_args(argv) images = _collect_images(args.input) if not images: raise SystemExit("未找到可扫描的图像") if args.output is not None: out_root = args.output elif args.input.is_file(): out_root = args.input.parent / "contrast_out" else: out_root = args.input / "contrast_out" out_root.mkdir(parents=True, exist_ok=True) logger.info(f"扫描 {len(images)} 张图 -> {out_root}") logger.info(f" quick={args.quick} ocr={args.ocr} ocr_top_n={args.ocr_top_n} ocr_all={args.ocr_all}") summary: List[Dict[str, Any]] = [] for img_path in images: logger.info(f"\n=== {img_path.name} ===") report = run_sweep( img_path, out_root, quick=args.quick, save_images=not args.no_save_images, ocr_enabled=args.ocr, model_dir=args.model_dir, ocr_top_n=args.ocr_top_n, ocr_all=args.ocr_all, ) to = report.get("top_overall") summary.append({ "input": report["input"], "n_tested": report["n_configs_tested"], "top_tag": to["tag"] if to else None, "top_combined": to["combined_score"] if to else None, "report": str(Path(report["output_dir"]) / "contrast_report.json"), }) index_path = out_root / "contrast_index.json" index_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8") logger.info(f"\n全部完成。索引: {index_path}") for s in summary: logger.info(f" {Path(s['input']).name}: Top={s['top_tag']} combined={s['top_combined']}") if __name__ == "__main__": # python contrast_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --ocr --ocr-top-n 1 if len(sys.argv) == 1: print("ℹ️ 未提供命令行参数,使用默认配置运行...") default_config = { "input": "../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png", "output": "./output/彭_广东兴宁农村商业银行_page_002/contrast_sweep", "ocr": True, "ocr_top_n": 3, "quick": True, } sys.argv = [sys.argv[0], default_config["input"]] for key, value in default_config.items(): if key == "input": continue flag = f"--{key.replace('_', '-')}" if isinstance(value, bool) and value: sys.argv.append(flag) elif not isinstance(value, bool): sys.argv.extend([flag, str(value)]) sys.exit(main())