| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733 |
- #!/usr/bin/env python3
- """
- 方向1:对比度增强参数网格扫描。
- 不去水印,直接对原图做多种对比度增强,验证哪种参数组合能让水印
- 在视觉上"淡化"、正文保持清晰,从而使后续 OCR 不受水印干扰。
- 用法:
- cd ocr_platform/ocr_tools/watermark_lab
- # 单张图快速扫描
- python contrast_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --quick
- # 全量扫描(更多参数组合 + 生成增强对比图)
- python contrast_sweep.py ../gan_experiments_lab/test_images/input/
- # 同时跑 OCR 整页对比(det+rec,每方法 Top-1 组合)
- python contrast_sweep.py input.png --ocr --model-dir /path/to/models
- # 每方法 Top-3 组合跑 OCR 对比
- python contrast_sweep.py input.png --ocr --ocr-top-n 3
- 输出:
- output/<stem>/
- ├── contrast_report.json # sweep results summary (includes OCR comparison results)
- ├── contrast_summary.csv # per-configuration scores as a CSV table
- ├── quad_compare.png # 四宫格对比图
- ├── text_restore_t60_bg248.png # 各组合增强结果图
- ├── clahe_cl3.0_t8.png
- ├── gamma_g0.5.png
- └── ocr/ # OCR 对比结果(--ocr 时生成)
- ├── <stem>_original_ocr_spans.png # 原始图 OCR 可视化
- ├── <stem>_original_ocr_spans.json # 原始图 OCR JSON
- ├── <stem>_<tag>_ocr_spans.png # 各增强组合 OCR 可视化
- ├── <stem>_<tag>_ocr_spans.json # 各增强组合 OCR JSON
- └── ocr_comparison.json # OCR 差异汇总报告
- """
- from __future__ import annotations
- import argparse
- import json
- import sys
- import time
- from itertools import product
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Sequence, Tuple
- import cv2
- import numpy as np
- _repo_root = Path(__file__).resolve().parents[3]
- if str(_repo_root) not in sys.path:
- sys.path.insert(0, str(_repo_root))
- from loguru import logger
- from ocr_utils.watermark.contrast import enhance_document_contrast
- _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
- # ── 参数网格 ────────────────────────────────────────────────────
- def _build_param_grid(quick: bool = False) -> List[Dict[str, Any]]:
- """构建参数网格。
- 四个维度:
- 1. method: text_restore | clahe | gamma | linear
- 2. text_restore 专属: text_black_target + background_threshold
- 3. clahe 专属: clip_limit + tile_grid_size
- 4. gamma 专属: gamma
- """
- grid: List[Dict[str, Any]] = []
- # ── text_restore ──
- if quick:
- tbt = [40, 60, 85]
- bts = [240, 248]
- else:
- tbt = [40, 60, 80, 100, 120]
- bts = [235, 240, 248, 252]
- for target, bg_th in product(tbt, bts):
- grid.append({
- "method": "text_restore",
- "text_black_target": target,
- "background_threshold": bg_th,
- })
- # ── clahe ──
- if quick:
- cl = [1.0, 3.0, 5.0]
- ts = [8, 16]
- else:
- cl = [0.5, 1.0, 2.0, 3.0, 5.0, 8.0]
- ts = [4, 8, 16, 32]
- for clip, tile in product(cl, ts):
- grid.append({
- "method": "clahe",
- "clip_limit": clip,
- "tile_grid_size": tile,
- })
- # ── gamma ──
- if quick:
- gvs = [0.4, 0.55, 0.7, 0.85]
- else:
- gvs = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
- for g in gvs:
- grid.append({"method": "gamma", "gamma": g})
- # ── linear ──
- if quick:
- bps = [2.0, 5.0]
- wps = [95.0, 98.0]
- else:
- bps = [1.0, 2.0, 5.0, 8.0]
- wps = [92.0, 95.0, 98.0]
- for bp, wp in product(bps, wps):
- grid.append({"method": "linear", "black_percentile": bp, "white_percentile": wp})
- return grid
- # ── 标签生成 ────────────────────────────────────────────────────
- def _tag_from_cfg(cfg: Dict[str, Any]) -> str:
- m = cfg["method"]
- if m == "text_restore":
- return f"{m}_t{cfg['text_black_target']}_bg{cfg['background_threshold']}"
- if m == "clahe":
- return f"{m}_cl{cfg['clip_limit']}_t{cfg['tile_grid_size']}"
- if m == "gamma":
- return f"{m}_g{cfg['gamma']}"
- if m == "linear":
- return f"{m}_b{cfg['black_percentile']}_w{cfg['white_percentile']}"
- return m
- # ── 工具函数 ────────────────────────────────────────────────────
def _collect_images(path: Path) -> List[Path]:
    """Resolve ``path`` to the list of image files to process.

    Args:
        path: A single image file or a directory of images.

    Returns:
        ``[path]`` for a file; for a directory, its direct children whose
        suffix (case-insensitive) is in ``_IMAGE_SUFFIXES``, sorted.

    Raises:
        ValueError: ``path`` is a file with an unsupported suffix.
        FileNotFoundError: ``path`` is neither a file nor a directory.
    """

    def _has_image_suffix(candidate: Path) -> bool:
        return candidate.suffix.lower() in _IMAGE_SUFFIXES

    if path.is_dir():
        # Non-recursive: only direct children are considered.
        return sorted(
            child
            for child in path.iterdir()
            if child.is_file() and _has_image_suffix(child)
        )
    if not path.is_file():
        raise FileNotFoundError(path)
    if not _has_image_suffix(path):
        raise ValueError(f"不支持的图像格式: {path}")
    return [path]
def _compute_watermark_fade_score(
    original: np.ndarray, enhanced: np.ndarray, window: int = 31
) -> float:
    """Score how much the residual texture was reduced by the enhancement.

    For each image a median-blurred copy serves as the background estimate;
    the variance of ``|image - background|`` measures the remaining texture.

    Args:
        original: Image before enhancement.
        enhanced: Image after enhancement.
        window: Median filter size; forced to an odd value of at least 3.

    Returns:
        ``1 - var(enhanced residual) / max(var(original residual), 1.0)``.
        Larger values mean less residual texture relative to the original.
    """
    ksize = max(3, window) | 1  # median kernel must be odd

    def _residual_variance(img: np.ndarray):
        as_float = img.astype(np.float32)
        # The median blur runs on a uint8 copy of the float image.
        background = cv2.medianBlur(as_float.astype(np.uint8), ksize).astype(np.float32)
        return np.var(cv2.absdiff(as_float, background))

    original_var = _residual_variance(original)
    enhanced_var = _residual_variance(enhanced)
    # The denominator is floored at 1.0 to avoid dividing by ~0 on flat images.
    return float(1.0 - enhanced_var / max(original_var, 1.0))
def _compute_text_sharpness_score(
    enhanced: np.ndarray, win: int = 3
) -> float:
    """Return the mean local standard deviation of ``enhanced``.

    The local statistics are computed with a ``win`` x ``win`` box filter;
    a larger result means more local contrast.
    """
    pixels = enhanced.astype(np.float32)
    box = np.ones((win, win), np.float32) / (win * win)

    local_mean = cv2.filter2D(pixels, -1, box)
    local_mean_of_squares = cv2.filter2D(pixels * pixels, -1, box)

    # E[x^2] - E[x]^2 can dip slightly below zero numerically; clamp it.
    local_variance = np.maximum(local_mean_of_squares - local_mean * local_mean, 0)
    local_std = np.sqrt(local_variance)
    return float(local_std.mean())
- # ── 对比图生成 ──────────────────────────────────────────────────
def _make_quad_compare(
    original: np.ndarray,
    top_enhanced: List[Tuple[str, np.ndarray]],
) -> np.ndarray:
    """Build a side-by-side comparison strip: the original plus each enhanced image.

    Args:
        original: The source image, shown first with the caption ``Original``.
        top_enhanced: ``(caption, image)`` pairs appended after the original.

    Returns:
        A BGR image in which every panel has a 40-pixel caption bar below it
        and all panels are stacked horizontally.
    """
    captioned = [("Original", original)]
    captioned.extend((caption, image) for caption, image in top_enhanced)

    # Convert single-channel panels to BGR so they can be stacked together.
    panels = [
        (caption, cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) if image.ndim == 2 else image)
        for caption, image in captioned
    ]

    # Every panel is resized to the largest height and width among the panels.
    target_h = max(image.shape[0] for _, image in panels)
    target_w = max(image.shape[1] for _, image in panels)

    columns: List[np.ndarray] = []
    for caption, image in panels:
        if image.shape[0] != target_h or image.shape[1] != target_w:
            image = cv2.resize(image, (target_w, target_h))
        caption_bar = np.ones((40, target_w, 3), dtype=np.uint8) * 240
        cv2.putText(
            caption_bar, caption, (12, 28), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (0, 0, 0), 2
        )
        columns.append(np.vstack([image, caption_bar]))
    return np.hstack(columns)
- # ── OCR(整页对比)───────────────────────────────────────────────
- def _poly_to_bbox(poly: List[List[float]]) -> List[int]:
- """四点 polygon 转轴对齐 bbox [x0,y0,x1,y1]."""
- xs = [p[0] for p in poly]
- ys = [p[1] for p in poly]
- return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]
def _ocr_full_page(engine: Any, img: np.ndarray) -> List[Dict[str, Any]]:
    """Run detection + recognition on a whole page and return its spans.

    Args:
        engine: OCR engine exposing ``ocr(img, det=True, rec=True)``. Only the
            first element of its result is read; each item in it is expected
            to be ``[box, rec]`` where ``rec`` is ``(text, confidence)`` or a
            bare text value.
        img: Page image passed to the engine unchanged.

    Returns:
        A list of span dicts with keys ``poly`` (list of ``[x, y]`` floats),
        ``bbox`` (``[x0, y0, x1, y1]`` ints, or ``[]`` without a polygon),
        ``text`` (stripped) and ``confidence`` (rounded to 4 decimals), after
        passing through ``SpanMatcher.remove_duplicate_spans``.
    """
    res = engine.ocr(img, det=True, rec=True)
    items = res[0] if res and res[0] is not None else []

    spans: List[Dict[str, Any]] = []
    for item in items:
        if item is None or len(item) < 2:
            continue
        box, rec_part = item[0], item[1]

        if isinstance(rec_part, (list, tuple)):
            # Guard against an empty rec tuple instead of raising IndexError.
            text = str(rec_part[0]) if len(rec_part) > 0 else ""
            conf = float(rec_part[1]) if len(rec_part) > 1 else 0.0
        else:
            text, conf = str(rec_part), 0.0

        # Use an explicit length check: a numpy polygon has no unambiguous
        # truth value, so `if box` would raise ValueError for it.
        has_box = box is not None and len(box) > 0
        poly = [[float(p[0]), float(p[1])] for p in box] if has_box else []

        spans.append({
            "poly": poly,
            "bbox": _poly_to_bbox(poly) if poly else [],
            "text": text.strip(),
            "confidence": round(conf, 4),
        })

    # Imported lazily so the module can be used without the parser package
    # when OCR is not requested.
    from ocr_tools.universal_doc_parser.core.layout_utils import SpanMatcher

    return SpanMatcher.remove_duplicate_spans(spans)
def _save_ocr_debug_for_sweep(
    image: np.ndarray,
    spans: List[Dict[str, Any]],
    out_dir: Path,
    tag: str,
) -> Dict[str, str]:
    """Write the OCR visualisation image and spans JSON for one image variant.

    Files are written to ``out_dir / "ocr"`` as ``<tag>_ocr_spans.png`` and
    ``<tag>_ocr_spans.json``. The visualisation is produced by
    ``module_debug_viz.draw_ocr_spans_cv2``.

    Args:
        image: Image the spans are drawn on.
        spans: Span dicts; only ``bbox``, ``poly``, ``text`` and
            ``confidence`` are stored in the JSON.
        out_dir: Per-image output directory.
        tag: File name prefix identifying the variant.

    Returns:
        ``{"image": <png path>, "json": <json path>}`` as strings.
    """
    from ocr_utils.module_debug_viz import draw_ocr_spans_cv2

    ocr_dir = out_dir / "ocr"
    ocr_dir.mkdir(parents=True, exist_ok=True)

    # Visualisation image.
    vis = draw_ocr_spans_cv2(image, spans)
    img_path = ocr_dir / f"{tag}_ocr_spans.png"
    # cv2.imwrite reports failure through its return value rather than by
    # raising, so surface it instead of silently logging a missing file.
    if cv2.imwrite(str(img_path), vis):
        logger.info(f" OCR debug: {img_path}")
    else:
        logger.warning(f" OCR debug 图写入失败: {img_path}")

    # JSON with the stable subset of span fields.
    exported_fields = ("bbox", "poly", "text", "confidence")
    json_data = {
        "tag": tag,
        "count": len(spans),
        "spans": [{field: s.get(field) for field in exported_fields} for s in spans],
    }
    json_path = ocr_dir / f"{tag}_ocr_spans.json"
    json_path.write_text(json.dumps(json_data, ensure_ascii=False, indent=2), encoding="utf-8")

    return {"image": str(img_path), "json": str(json_path)}
- def _compare_ocr_results(
- orig_spans: List[Dict[str, Any]],
- enh_spans: List[Dict[str, Any]],
- iou_threshold: float = 0.5,
- ) -> Dict[str, Any]:
- """对比两组 OCR spans 的检测+识别差异。
- Returns:
- {
- "detection": { orig_count, enh_count, matched, new, missing },
- "recognition": { text_changed_count, char_diff_rate, details: [...] },
- "summary": "一句话摘要"
- }
- """
- def _bbox_iou(a: List[int], b: List[int]) -> float:
- if not a or not b:
- return 0.0
- xa = max(a[0], b[0])
- ya = max(a[1], b[1])
- xb = min(a[2], b[2])
- yb = min(a[3], b[3])
- inter = max(0, xb - xa) * max(0, yb - ya)
- area_a = max(0, a[2] - a[0]) * max(0, a[3] - a[1])
- area_b = max(0, b[2] - b[0]) * max(0, b[3] - b[1])
- union = area_a + area_b - inter
- return inter / union if union > 0 else 0.0
- # ── 检测对比 ──
- orig_boxes = [s.get("bbox", []) for s in orig_spans]
- enh_boxes = [s.get("bbox", []) for s in enh_spans]
- matched_orig_idxs: set = set()
- matched_enh_idxs: set = set()
- recognition_details: List[Dict[str, Any]] = []
- for i, ob in enumerate(orig_boxes):
- if not ob:
- continue
- best_j, best_iou = -1, 0.0
- for j, eb in enumerate(enh_boxes):
- if j in matched_enh_idxs or not eb:
- continue
- iou = _bbox_iou(ob, eb)
- if iou > best_iou:
- best_iou, best_j = iou, j
- if best_iou >= iou_threshold:
- matched_orig_idxs.add(i)
- matched_enh_idxs.add(best_j)
- orig_text = orig_spans[i].get("text", "")
- enh_text = enh_spans[best_j].get("text", "")
- orig_score = orig_spans[i].get("confidence", 0)
- enh_score = enh_spans[best_j].get("confidence", 0)
- rec_detail: Dict[str, Any] = {
- "orig_bbox": ob,
- "orig_text": orig_text,
- "orig_score": orig_score,
- "enh_text": enh_text,
- "enh_score": enh_score,
- "iou": round(best_iou, 4),
- }
- if orig_text != enh_text:
- rec_detail["text_changed"] = True
- else:
- rec_detail["text_changed"] = False
- recognition_details.append(rec_detail)
- new_boxes = len(enh_boxes) - len(matched_enh_idxs)
- missing_boxes = len(orig_boxes) - len(matched_orig_idxs)
- # 字符差异率
- orig_concat = "".join(s.get("text", "") for s in orig_spans)
- enh_concat = "".join(s.get("text", "") for s in enh_spans)
- total_chars = max(len(orig_concat), len(enh_concat), 1)
- char_diff = sum(1 for a, b in zip(orig_concat, enh_concat) if a != b) + abs(
- len(orig_concat) - len(enh_concat)
- )
- char_diff_rate = round(char_diff / total_chars, 4)
- detection = {
- "orig_count": len(orig_boxes),
- "enh_count": len(enh_boxes),
- "matched": len(matched_orig_idxs),
- "new": new_boxes,
- "missing": missing_boxes,
- }
- recognition = {
- "text_changed_count": len(recognition_details),
- "char_diff_rate": char_diff_rate,
- "details": recognition_details[:50], # 最多保存50条差异明细
- }
- summary = (
- f"检测: {detection['orig_count']}→{detection['enh_count']} (匹配{detection['matched']}, "
- f"新增{detection['new']}, 遗失{detection['missing']}); "
- f"识别: 文字变化{recognition['text_changed_count']}处, 字符差异率{char_diff_rate:.2%}"
- )
- return {"detection": detection, "recognition": recognition, "summary": summary}
def _load_paddle_engine(model_dir: Path, det_thresh: float = 0.3):
    """Create the ``PytorchPaddleOCR`` engine for Chinese text.

    Args:
        model_dir: Directory searched for the detection and recognition
            weight files. A weight file that is absent is passed as ``None``.
        det_thresh: Forwarded as ``det_db_box_thresh``.

    Returns:
        The constructed ``PytorchPaddleOCR`` instance.
    """
    from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR

    def _path_if_present(filename: str):
        candidate = model_dir / filename
        if candidate.exists():
            return str(candidate)
        return None

    det_weights = _path_if_present("ch_PP-OCRv5_det_infer.pth")
    rec_weights = _path_if_present("ch_PP-OCRv4_rec_server_doc_infer.pth")
    return PytorchPaddleOCR(
        lang="ch",
        det_model_path=det_weights,
        rec_model_path=rec_weights,
        det_db_box_thresh=det_thresh,
    )
- # ── 扫描核心 ────────────────────────────────────────────────────
# Keys of a result row that are bookkeeping/metrics, not enhancement parameters.
_SWEEP_ROW_META_KEYS = (
    "tag", "image_path", "fade_score", "sharpness_score", "combined_score", "time_ms",
)

# NOTE(review): machine-specific fallback used when no model directory is
# given; kept unchanged for backward compatibility.
_DEFAULT_OCR_MODEL_DIR = Path(
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)


def _score_param_grid(
    gray: np.ndarray,
    grid: List[Dict[str, Any]],
    img_out: Path,
    save_images: bool,
) -> List[Dict[str, Any]]:
    """Apply every config in ``grid`` to ``gray`` and return one scored row each."""
    # Sharpness is reported relative to the un-enhanced image so that the
    # combined score actually reacts to it (1.0 == as sharp as the original).
    baseline_sharpness = max(_compute_text_sharpness_score(gray), 0.01)

    rows: List[Dict[str, Any]] = []
    for cfg in grid:
        tag = _tag_from_cfg(cfg)
        started = time.perf_counter()
        try:
            enhanced = enhance_document_contrast(gray, **cfg)
        except Exception as e:
            # Best effort: one failing combination must not abort the sweep.
            logger.warning(f" [{tag}] 增强失败: {e}")
            continue
        elapsed = time.perf_counter() - started

        fade = _compute_watermark_fade_score(gray, enhanced)
        sharpness = _compute_text_sharpness_score(enhanced)
        # Equal weights for watermark fading and relative text sharpness.
        combined = 0.5 * fade + 0.5 * (sharpness / baseline_sharpness)

        row: Dict[str, Any] = {
            "tag": tag,
            **cfg,
            "fade_score": round(fade, 6),
            "sharpness_score": round(sharpness, 4),
            "combined_score": round(combined, 4),
            "time_ms": round(elapsed * 1000, 1),
        }
        if save_images:
            out_path = img_out / f"{tag}.png"
            # Only record the path when the file was really written; later
            # stages regenerate the image when no path is present.
            if cv2.imwrite(str(out_path), enhanced):
                row["image_path"] = str(out_path)
            else:
                logger.warning(f" [{tag}] 增强图写入失败: {out_path}")
        rows.append(row)
    return rows


def _select_ocr_candidates(
    results: List[Dict[str, Any]],
    method_groups: Dict[str, List[Dict[str, Any]]],
    ocr_top_n: int,
    ocr_all: bool,
) -> List[Dict[str, Any]]:
    """Pick the rows to OCR: everything, or the top-N rows of each method."""
    if ocr_all:
        return results
    picked: List[Dict[str, Any]] = []
    for entries in method_groups.values():
        for row in entries[:max(ocr_top_n, 0)]:
            if row not in picked:
                picked.append(row)
    return picked


def _run_ocr_comparisons(
    engine: Any,
    gray: np.ndarray,
    baseline_spans: List[Dict[str, Any]],
    candidates: List[Dict[str, Any]],
    img_out: Path,
    stem: str,
) -> List[Dict[str, Any]]:
    """OCR each candidate's enhanced image and compare it with the baseline.

    Each candidate row is updated in place with ``ocr_comparison`` and the
    flattened ``ocr_det_*`` / ``ocr_rec_*`` values.
    """
    comparisons: List[Dict[str, Any]] = []
    for row in candidates:
        tag = row["tag"]
        enhanced_path = row.get("image_path")
        if enhanced_path:
            enhanced_bgr = cv2.imread(enhanced_path)
            if enhanced_bgr is None:
                logger.warning(f" [{tag}] 无法读取增强图")
                continue
            enhanced_gray = cv2.cvtColor(enhanced_bgr, cv2.COLOR_BGR2GRAY)
        else:
            # No image on disk: regenerate it from the row's parameters.
            params = {k: v for k, v in row.items() if k not in _SWEEP_ROW_META_KEYS}
            enhanced_gray = enhance_document_contrast(gray, **params)
            enhanced_bgr = cv2.cvtColor(enhanced_gray, cv2.COLOR_GRAY2BGR)

        try:
            enh_spans = _ocr_full_page(engine, enhanced_gray)
        except Exception as e:
            logger.warning(f" [{tag}] OCR 失败: {e}")
            continue

        _save_ocr_debug_for_sweep(enhanced_bgr, enh_spans, img_out, f"{stem}_{tag}")
        comparison = _compare_ocr_results(baseline_spans, enh_spans)
        comparison["tag"] = tag
        comparison["method"] = row["method"]
        row["ocr_comparison"] = comparison
        comparisons.append(comparison)

        for k, v in comparison["detection"].items():
            row[f"ocr_det_{k}"] = v
        for k, v in comparison["recognition"].items():
            if not isinstance(v, list):
                row[f"ocr_rec_{k}"] = v
        logger.info(f" [{tag}] {comparison['summary']}")
    return comparisons


def _save_quad_compare(
    gray: np.ndarray, results: List[Dict[str, Any]], img_out: Path
) -> None:
    """Write ``quad_compare.png``: the original plus up to three best results.

    ``results`` must be sorted best-first; at most one image per method is used.
    """
    picked: List[Tuple[str, np.ndarray]] = []
    seen_methods = set()
    for row in results:
        method = row["method"]
        if method in seen_methods:
            continue
        seen_methods.add(method)
        path = row.get("image_path")
        if path:
            img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            if img is not None:
                picked.append((row["tag"], img))
        if len(picked) >= 3:
            break
    if not picked:
        return
    quad = _make_quad_compare(gray, picked)
    quad_path = img_out / "quad_compare.png"
    cv2.imwrite(str(quad_path), quad)
    logger.info(f" 四宫格对比图: {quad_path}")


def _write_summary_csv(results: List[Dict[str, Any]], csv_path: Path) -> None:
    """Write one CSV line per result row.

    Columns are the union of the keys of all rows (first-seen order), because
    each method contributes different parameter names. Path columns and the
    nested ``ocr_comparison`` dict are left out; absent values are empty.
    """
    import csv

    columns = list(dict.fromkeys(
        key
        for row in results
        for key in row
        if not key.endswith("_path") and key != "ocr_comparison"
    ))
    with csv_path.open("w", encoding="utf-8", newline="") as fh:
        writer = csv.DictWriter(
            fh,
            fieldnames=columns,
            restval="",
            extrasaction="ignore",
            lineterminator="\n",
        )
        writer.writeheader()
        writer.writerows(results)


def run_sweep(
    input_path: Path,
    out_dir: Path,
    *,
    quick: bool = False,
    save_images: bool = True,
    ocr_enabled: bool = False,
    model_dir: Optional[Path] = None,
    ocr_top_n: int = 1,
    ocr_all: bool = False,
) -> Dict[str, Any]:
    """Sweep the contrast-enhancement parameter grid over one image.

    Every configuration is applied to the grayscale image and scored;
    optionally the best configurations are run through full-page OCR and
    compared with the OCR of the original. Outputs are written to
    ``out_dir / <input stem>``: ``contrast_report.json``,
    ``contrast_summary.csv``, the enhanced images and ``quad_compare.png``
    (when ``save_images``), and ``ocr/`` (when OCR ran).

    The combined score is ``0.5 * fade_score + 0.5 * (sharpness / sharpness
    of the original image)``.

    Args:
        input_path: Image to process.
        out_dir: Root output directory.
        quick: Use the reduced parameter grid.
        save_images: Write each enhanced image to disk.
        ocr_enabled: Run baseline OCR and OCR comparisons.
        model_dir: OCR model directory; a built-in default is used when None.
        ocr_top_n: Number of best configurations per method to OCR.
        ocr_all: OCR every configuration (overrides ``ocr_top_n``).

    Returns:
        The report dict that is also written to ``contrast_report.json``.

    Raises:
        RuntimeError: The input image cannot be read.
    """
    bgr = cv2.imread(str(input_path))
    if bgr is None:
        raise RuntimeError(f"无法读取图像: {input_path}")
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)

    stem = input_path.stem
    img_out = out_dir / stem
    img_out.mkdir(parents=True, exist_ok=True)

    grid = _build_param_grid(quick=quick)
    logger.info(f" {stem}: {len(grid)} 组参数组合")

    # ── Optional OCR baseline on the original grayscale image ──
    engine = None
    baseline_spans: List[Dict[str, Any]] = []
    if ocr_enabled:
        try:
            engine = _load_paddle_engine(model_dir or _DEFAULT_OCR_MODEL_DIR)
            logger.info(" OCR 引擎已加载")
            baseline_spans = _ocr_full_page(engine, gray)
            logger.info(f" 基线 OCR: {len(baseline_spans)} 个文本块")
            _save_ocr_debug_for_sweep(bgr, baseline_spans, img_out, f"{stem}_original")
        except Exception as e:
            logger.warning(f" OCR 引擎加载失败: {e}")
            # Treat OCR as unavailable for the rest of the run so the report
            # does not claim an OCR pass that did not complete.
            engine = None
            baseline_spans = []

    # ── Phase 1: score every configuration ──
    results = _score_param_grid(gray, grid, img_out, save_images)

    # Group by method (in grid order) so the best of each method can be taken.
    method_groups: Dict[str, List[Dict[str, Any]]] = {}
    for row in results:
        method_groups.setdefault(row["method"], []).append(row)

    results.sort(key=lambda r: -r["combined_score"])
    for entries in method_groups.values():
        entries.sort(key=lambda r: -r["combined_score"])

    for mname, entries in method_groups.items():
        top = entries[0]
        logger.info(f" [{mname}] Top: {top['tag']} combined={top['combined_score']:.4f}")
    if results:
        logger.info(f" 全局 Top1: {results[0]['tag']} combined={results[0]['combined_score']:.4f}")
    else:
        # Every configuration failed; keep going so an (empty) report exists.
        logger.warning(f" {stem}: 所有参数组合均增强失败")

    # ── Phase 2: full-page OCR comparison ──
    ocr_comparisons: List[Dict[str, Any]] = []
    if engine is not None and baseline_spans:
        ocr_candidates = _select_ocr_candidates(results, method_groups, ocr_top_n, ocr_all)
        logger.info(f" OCR 对比 {len(ocr_candidates)} 个组合(每方法 Top-{ocr_top_n})")
        ocr_comparisons = _run_ocr_comparisons(
            engine, gray, baseline_spans, ocr_candidates, img_out, stem
        )

    # ── Comparison strip (needs at least three methods with results) ──
    if save_images and len(method_groups) >= 3:
        _save_quad_compare(gray, results, img_out)

    # ── Report ──
    report: Dict[str, Any] = {
        "input": str(input_path),
        "output_dir": str(img_out),
        "n_configs_tested": len(results),
        "top_overall": results[0] if results else None,
        "top_by_method": {m: e[0] for m, e in method_groups.items() if e},
    }
    if engine is not None:
        report["baseline_ocr"] = {
            "span_count": len(baseline_spans),
            "full_text": "".join(s.get("text", "") for s in baseline_spans),
        }
        report["ocr_comparisons"] = {
            "n_compared": len(ocr_comparisons),
            "results": ocr_comparisons,
        }

    report_path = img_out / "contrast_report.json"
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    # Separate OCR comparison summary with the full detection/recognition data.
    if engine is not None:
        ocr_summary_path = img_out / "ocr" / "ocr_comparison.json"
        ocr_summary_path.parent.mkdir(parents=True, exist_ok=True)
        ocr_summary = {
            "input": str(input_path),
            "baseline_spans": len(baseline_spans),
            "compared": ocr_comparisons,
        }
        ocr_summary_path.write_text(
            json.dumps(ocr_summary, ensure_ascii=False, indent=2), encoding="utf-8"
        )

    if results:
        _write_summary_csv(results, img_out / "contrast_summary.csv")

    logger.info(f" 报告: {report_path}")
    return report
- # ── CLI ──────────────────────────────────────────────────────────
- def _build_arg_parser() -> argparse.ArgumentParser:
- p = argparse.ArgumentParser(
- description="对比度增强参数网格扫描(不去水印,直接增强前后对比)",
- )
- p.add_argument("input", type=Path, help="单张图片路径或图片目录")
- p.add_argument("-o", "--output", type=Path, default=None,
- help="输出根目录,默认 input 同级 contrast_out/<stem>")
- p.add_argument("--quick", action="store_true",
- help="缩小参数网格")
- p.add_argument("--no-save-images", action="store_true",
- help="不写出增强结果图")
- p.add_argument("--ocr", action="store_true",
- help="启用整页 OCR 对比(det+rec):基线 OCR + Top-N 增强图 OCR,输出 spans 可视化和 JSON")
- p.add_argument("--ocr-top-n", type=int, default=1,
- help="OCR 对比时每方法取 Top-N 组合(默认 1)")
- p.add_argument("--ocr-all", action="store_true",
- help="对所有参数组合跑 OCR 对比(覆盖 --ocr-top-n)")
- p.add_argument("--model-dir", type=Path, default=None,
- help="PaddleOCR 模型目录")
- return p
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: sweep every input image and write an index.

    The output root is ``--output`` when given, otherwise a ``contrast_out``
    directory next to the input file (or inside the input directory). After
    all images are processed, ``contrast_index.json`` in the output root
    lists the best configuration per image.

    Args:
        argv: Argument list; ``None`` lets argparse read ``sys.argv``.

    Raises:
        SystemExit: No image was found under the input path.
    """
    args = _build_arg_parser().parse_args(argv)

    images = _collect_images(args.input)
    if not images:
        raise SystemExit("未找到可扫描的图像")

    if args.output is not None:
        out_root = args.output
    else:
        anchor = args.input.parent if args.input.is_file() else args.input
        out_root = anchor / "contrast_out"
    out_root.mkdir(parents=True, exist_ok=True)

    logger.info(f"扫描 {len(images)} 张图 -> {out_root}")
    logger.info(f" quick={args.quick} ocr={args.ocr} ocr_top_n={args.ocr_top_n} ocr_all={args.ocr_all}")

    sweep_options = {
        "quick": args.quick,
        "save_images": not args.no_save_images,
        "ocr_enabled": args.ocr,
        "model_dir": args.model_dir,
        "ocr_top_n": args.ocr_top_n,
        "ocr_all": args.ocr_all,
    }

    summary: List[Dict[str, Any]] = []
    for img_path in images:
        logger.info(f"\n=== {img_path.name} ===")
        report = run_sweep(img_path, out_root, **sweep_options)

        best = report.get("top_overall")
        if best:
            top_tag, top_combined = best["tag"], best["combined_score"]
        else:
            top_tag, top_combined = None, None
        summary.append({
            "input": report["input"],
            "n_tested": report["n_configs_tested"],
            "top_tag": top_tag,
            "top_combined": top_combined,
            "report": str(Path(report["output_dir"]) / "contrast_report.json"),
        })

    index_path = out_root / "contrast_index.json"
    index_path.write_text(json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8")

    logger.info(f"\n全部完成。索引: {index_path}")
    for entry in summary:
        logger.info(f" {Path(entry['input']).name}: Top={entry['top_tag']} combined={entry['top_combined']}")
if __name__ == "__main__":
    # Example invocation:
    #   python contrast_sweep.py ../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png --ocr --ocr-top-n 1
    if len(sys.argv) == 1:
        # No arguments given: synthesise a command line from a built-in
        # default configuration so the script can be run directly.
        print("ℹ️ 未提供命令行参数,使用默认配置运行...")
        default_config = {
            "input": "../gan_experiments_lab/test_images/input/彭_广东兴宁农村商业银行_page_002.png",
            "output": "./output/彭_广东兴宁农村商业银行_page_002/contrast_sweep",
            "ocr": True,
            "ocr_top_n": 3,
            "quick": True,
        }
        synthesized_args = [default_config["input"]]
        for option, setting in default_config.items():
            if option == "input":
                continue
            cli_flag = "--" + option.replace("_", "-")
            if isinstance(setting, bool):
                # Boolean options are bare switches, emitted only when true.
                if setting:
                    synthesized_args.append(cli_flag)
            else:
                synthesized_args += [cli_flag, str(setting)]
        sys.argv = [sys.argv[0], *synthesized_args]
    sys.exit(main())
|