| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971 |
- #!/usr/bin/env python3
- """
- 单元格裁剪图预处理参数扫描:去水印 / contrast(clahe/gamma/linear/text_restore)/ upscale / det 阈值 / OCR 模式。
- 支持 contrast 在放大前/后执行两种顺序对比。
- 默认从 **原图**(`*_raw.png`)出发,与 pipeline 二次 OCR 一致,避免对已预处理 debug 图二次去水印。
- 用法:
- python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out
- python cell_sweep.py cell.png --quick --no-save-images
- python cell_sweep.py cell.png --contrast-orders before_upscale,after_upscale
- OCR_DET_MODEL_PATH=... OCR_REC_MODEL_PATH=... python cell_sweep.py cell.png
- # 统计出的最优参数 tag: threshold_t150_cl_1.0_8_ob_u128_det0.5
- # 对目录下所有 *_raw.png 验证适配性
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only
- # 自定义最优参数
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only \
- --best-config threshold_t150_cl_1.0_8_ob_u128_det0.5
- # 指定目标文字,自动统计 HIT 命中率
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only -t "交易类型"
- """
- from __future__ import annotations
- import argparse
- import json
- import os
- import sys
- from itertools import product
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Sequence, Tuple
- import cv2
- import numpy as np
# Put the directory three levels above this file's folder (parents[3]) at the front of
# sys.path so the project packages imported below (ocr_utils, and ocr_tools inside
# _make_engine) can be resolved when this file is executed directly as a script.
_repo_root = Path(__file__).resolve().parents[3]
_repo_root_str = str(_repo_root)
if _repo_root_str not in sys.path:
    sys.path.insert(0, _repo_root_str)
- from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
- from ocr_utils.watermark.contrast import enhance_document_contrast
- _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
- _DEFAULT_MODEL_DIR = Path(
- "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
- "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
- )
- def _parse_csv_ints(s: str) -> List[Optional[int]]:
- out: List[Optional[int]] = []
- for part in s.split(","):
- part = part.strip()
- if not part or part.lower() in ("none", "d", "default"):
- out.append(None)
- else:
- out.append(int(part))
- return out
- def _parse_csv_floats(s: str) -> List[float]:
- return [float(x.strip()) for x in s.split(",") if x.strip()]
- def _parse_csv_bools(s: str) -> List[bool]:
- out: List[bool] = []
- for part in s.split(","):
- p = part.strip().lower()
- if p in ("1", "true", "yes", "on"):
- out.append(True)
- elif p in ("0", "false", "no", "off"):
- out.append(False)
- else:
- raise ValueError(f"无效的 bool 值: {part!r}")
- return out
- def _default_model_dir() -> Path:
- det = os.environ.get("OCR_DET_MODEL_PATH")
- if det:
- return Path(det).parent
- return _DEFAULT_MODEL_DIR
- def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
- h, w = img.shape[:2]
- if h >= min_side and w >= min_side:
- return img
- s = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
- return cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_CUBIC)
- # ── 对比度增强方法(clahe / gamma / linear / text_restore / none)──
- def _apply_contrast(
- gray: np.ndarray,
- *,
- method: str,
- clip_limit: float = 1.0,
- tile_grid_size: int = 8,
- gamma: float = 0.85,
- black_percentile: float = 2.0,
- white_percentile: float = 98.0,
- text_black_target: int = 85,
- background_threshold: int = 248,
- ) -> np.ndarray:
- """对灰度图应用对比度增强;method="none" 时原样返回。"""
- if method == "none":
- return gray
- if method == "text_restore":
- return enhance_document_contrast(
- gray, method="text_restore",
- text_black_target=text_black_target,
- background_threshold=background_threshold,
- )
- if method == "clahe":
- return enhance_document_contrast(
- gray, method="clahe",
- clip_limit=clip_limit, tile_grid_size=tile_grid_size,
- )
- if method == "gamma":
- return enhance_document_contrast(gray, method="gamma", gamma=gamma)
- if method == "linear":
- return enhance_document_contrast(
- gray, method="linear",
- black_percentile=black_percentile,
- white_percentile=white_percentile,
- )
- return gray
- def _contrast_tag(cfg: Dict[str, Any]) -> str:
- """生成 contrast 配置的短标签。"""
- m = cfg.get("method", "none")
- if m == "none":
- return "c0"
- if m == "text_restore":
- return f"tr_{cfg.get('text_black_target', 85)}"
- if m == "clahe":
- return f"cl_{cfg.get('clip_limit', 1.0)}_{cfg.get('tile_grid_size', 8)}"
- if m == "gamma":
- return f"gm_{cfg.get('gamma', 0.85)}"
- if m == "linear":
- return f"ln_{cfg.get('black_percentile', 2.0)}_{cfg.get('white_percentile', 98.0)}"
- return m
- def _build_contrast_grid(quick: bool = False) -> List[Dict[str, Any]]:
- """构建 contrast 参数网格(对齐 contrast_sweep.py 的设计)。
- 返回列表,每个元素是一个 Dict,至少包含 "method" 字段。
- """
- grid: List[Dict[str, Any]] = [{"method": "none"}] # 对照组:不增强
- # text_restore
- if quick:
- tbt = [60, 85]
- bts = [240, 248]
- else:
- tbt = [60, 85, 100, 120]
- bts = [240, 248, 252]
- for target, bg_th in product(tbt, bts):
- grid.append({"method": "text_restore", "text_black_target": target, "background_threshold": bg_th})
- # clahe
- if quick:
- cl = [1.0, 2.0]
- ts = [4, 8]
- else:
- cl = [0.5, 1.0, 2.0, 3.0, 5.0]
- ts = [4, 8]
- for clip, tile in product(cl, ts):
- grid.append({"method": "clahe", "clip_limit": clip, "tile_grid_size": tile})
- # # gamma
- # if quick:
- # gvs = [0.5, 0.85]
- # else:
- # gvs = [0.4, 0.55, 0.7, 0.85]
- # for g in gvs:
- # grid.append({"method": "gamma", "gamma": g})
- # # linear
- # if quick:
- # bps = [2.0, 5.0]
- # wps = [95.0, 98.0]
- # else:
- # bps = [2.0, 5.0, 8.0]
- # wps = [95.0, 98.0]
- # for bp, wp in product(bps, wps):
- # grid.append({"method": "linear", "black_percentile": bp, "white_percentile": wp})
- return grid
- def _preprocess(
- raw: np.ndarray,
- *,
- method: str,
- thresh: Optional[int],
- contrast_cfg: Dict[str, Any],
- upscale: int,
- contrast_order: str = "before_upscale",
- ) -> np.ndarray:
- """预处理管线:去水印 → [contrast] → 放大(或去水印 → 放大 → contrast)。
- method="none" 时跳过去水印,直接从原图开始处理。
- """
- if method == "none":
- img = raw.copy() # 不处理水印,直接使用原图
- else:
- user: Dict[str, Any] = {"enabled": True, "method": method}
- if method == "threshold" and thresh is not None:
- user["threshold"] = thresh
- cfg = merge_watermark_config("cell", user)
- img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)
- contrast_method = contrast_cfg.get("method", "none")
- if contrast_method != "none" and contrast_order == "before_upscale":
- gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- gray = _apply_contrast(gray, **contrast_cfg)
- img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
- img = _upscale(img, upscale)
- if contrast_method != "none" and contrast_order == "after_upscale":
- gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- gray = _apply_contrast(gray, **contrast_cfg)
- img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
- return img
- def _parse_rec_pair(rec_part: Any) -> Tuple[str, float]:
- """从 OCR 返回的 (text, score) 或嵌套结构中解析识别结果。"""
- if rec_part is None:
- return "", 0.0
- if isinstance(rec_part, (list, tuple)) and len(rec_part) >= 2:
- if isinstance(rec_part[0], (list, tuple, dict)):
- return "", 0.0
- txt = str(rec_part[0] or "").strip()
- try:
- sc = float(rec_part[1] or 0.0)
- except (TypeError, ValueError):
- sc = 0.0
- return txt, sc if txt else 0.0
- if isinstance(rec_part, (list, tuple)) and len(rec_part) == 1:
- txt = str(rec_part[0] or "").strip()
- return txt, 0.0
- return "", 0.0
- def _aggregate_rec_score(boxes: List[Dict[str, Any]]) -> float:
- """按字符数加权平均识别分(与 pipeline aggregate_line_ocr 一致)。"""
- total_len = sum(len(b.get("text") or "") for b in boxes)
- if total_len <= 0:
- return 0.0
- weighted = sum(
- len(b.get("text") or "") * float(b.get("score") or 0.0) for b in boxes
- )
- return weighted / total_len
- def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
- empty: Dict[str, Any] = {
- "text": "",
- "score": 0.0,
- "boxes": [],
- "det": det,
- "rec": rec,
- "n_boxes": 0,
- }
- try:
- res = engine.ocr(img, det=det, rec=rec)
- items = res[0] if res and res[0] is not None else []
- boxes_out: List[Dict[str, Any]] = []
- if det:
- for item in items:
- if not item or len(item) < 2:
- continue
- text, score = _parse_rec_pair(item[1])
- bbox = item[0]
- if hasattr(bbox, "tolist"):
- bbox = bbox.tolist()
- entry: Dict[str, Any] = {
- "text": text,
- "score": round(score, 6),
- }
- if bbox is not None:
- entry["det_bbox"] = bbox
- boxes_out.append(entry)
- else:
- for item in items:
- text, score = _parse_rec_pair(item)
- if not text and isinstance(item, (list, tuple)) and len(item) >= 1:
- text, score = _parse_rec_pair(item[0])
- boxes_out.append({"text": text, "score": round(score, 6)})
- text = "".join(b["text"] for b in boxes_out if b.get("text")).strip()
- agg_score = _aggregate_rec_score(boxes_out)
- return {
- "text": text,
- "score": round(agg_score, 6),
- "boxes": boxes_out,
- "det": det,
- "rec": rec,
- "n_boxes": len(boxes_out),
- }
- except Exception as e:
- out = dict(empty)
- out["error"] = str(e)
- return out
- def _make_engine(det_thresh: float, model_dir: Path) -> Any:
- from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
- det_path = os.environ.get("OCR_DET_MODEL_PATH") or str(
- model_dir / "ch_PP-OCRv5_det_infer.pth"
- )
- rec_path = os.environ.get("OCR_REC_MODEL_PATH") or str(
- model_dir / "ch_PP-OCRv4_rec_server_doc_infer.pth"
- )
- return PytorchPaddleOCR(
- lang="ch",
- det_model_path=det_path,
- rec_model_path=rec_path,
- det_db_box_thresh=det_thresh,
- )
def resolve_input_image(path: Path, *, prefer_raw: bool) -> Path:
    """Swap *path* for its ``<stem>_raw<suffix>`` sibling when preferred and present.

    The ``*_raw`` image is the un-preprocessed original written next to pipeline debug
    crops. Nothing changes when *prefer_raw* is False, when *path* already ends in
    ``_raw``, or when no such sibling file exists. A note is printed when a swap happens.
    """
    already_raw = path.stem.endswith("_raw")
    if prefer_raw and not already_raw:
        candidate = path.parent / f"{path.stem}_raw{path.suffix}"
        if candidate.is_file():
            print(f" 使用原图: {candidate.name}(跳过 {path.name})")
            return candidate
    return path
def collect_inputs(path: Path, *, prefer_raw: bool) -> List[Path]:
    """Expand a file or directory argument into the list of images to process.

    A single file must have an image suffix and goes through ``resolve_input_image``.
    For a directory, image files (case-insensitive suffix) are sorted and then:

      * with *prefer_raw*, if any ``*_raw`` images exist, only those are returned;
      * otherwise the non-``_raw`` images are returned (skipping any that have a
        ``_raw`` sibling when *prefer_raw* is set), falling back to all images if that
        selection is empty.

    Raises:
        ValueError: a file argument with an unsupported suffix.
        FileNotFoundError: the path does not exist, or a directory holds no images.
    """
    if path.is_file():
        if path.suffix.lower() not in _IMAGE_SUFFIXES:
            raise ValueError(f"不支持的图像格式: {path}")
        return [resolve_input_image(path, prefer_raw=prefer_raw)]
    if not path.is_dir():
        raise FileNotFoundError(path)

    images = sorted(
        entry
        for entry in path.iterdir()
        if entry.is_file() and entry.suffix.lower() in _IMAGE_SUFFIXES
    )
    if not images:
        raise FileNotFoundError(f"目录内无图像: {path}")

    raw_images = [entry for entry in images if entry.stem.endswith("_raw")]
    if prefer_raw and raw_images:
        return raw_images

    def _keep(entry: Path) -> bool:
        if entry.stem.endswith("_raw"):
            return False
        raw_sibling = entry.parent / f"{entry.stem}_raw{entry.suffix}"
        return not (prefer_raw and raw_sibling.is_file())

    selected = [entry for entry in images if _keep(entry)]
    return selected or images
- def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
- if not text:
- return None
- if not target:
- return "nonempty"
- if target in text:
- return "full"
- if len(target) >= 6 and target.isdigit() and len(text) >= 6 and text.isdigit():
- return "partial"
- return None
def run_sweep(
    input_path: Path,
    out_dir: Path,
    *,
    prefer_raw: bool,
    target: Optional[str],
    model_dir: Path,
    methods: Sequence[str],
    thresholds: Sequence[Optional[int]],
    contrast_grid: List[Dict[str, Any]],
    contrast_orders: Sequence[str],
    upscales: Sequence[int],
    det_threshs: Sequence[float],
    save_images: bool,
    run_baseline: bool,
    baseline_upscale: int,
) -> Dict[str, Any]:
    """Sweep preprocessing × OCR settings for one cell image and write ``sweep_report.json``.

    Each combination of (watermark method, threshold) × contrast config × contrast order ×
    upscale × det_db_box_thresh is preprocessed with ``_preprocess``. Each result is OCR'd in
    two modes: "det_rec" (detection + recognition) and "whole_rec" (recognition only on the
    whole image). The "threshold" watermark method is tried with every value in
    *thresholds*. Every other method takes no threshold and is tried once per remaining
    combination, with thresh=None.

    When *run_baseline* is set, a baseline (upscale only, no watermark removal or contrast)
    is OCR'd too, for every det threshold.

    Outputs go to ``out_dir / <stem>``, where stem is the input stem minus a ``_raw``
    suffix. That directory gets one PNG per preprocessed variant (when *save_images*)
    plus ``sweep_report.json``.

    Returns:
        A report dict with these keys:
          * input, input_requested, output_dir, target;
          * total_trials: number of grid OCR calls (baseline calls are not counted);
          * hits: rows matched by ``_match_hit``, including baseline rows;
          * all_results.

    Raises:
        RuntimeError: if the image cannot be read.
    """
    resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
    raw = cv2.imread(str(resolved))
    if raw is None:
        raise RuntimeError(f"无法读取图像: {resolved}")
    stem = resolved.stem.removesuffix("_raw")
    cell_out = out_dir / stem
    cell_out.mkdir(parents=True, exist_ok=True)

    ocr_modes: List[Tuple[str, bool, bool]] = [
        ("det_rec", True, True),
        ("whole_rec", False, True),
    ]
    results: List[Dict[str, Any]] = []
    hits: List[Dict[str, Any]] = []
    engines: Dict[float, Any] = {}
    total = 0

    def _engine_for(det_th: float, *, announce: bool) -> Any:
        # one engine per det threshold, created on first use
        if det_th not in engines:
            if announce:
                print(f" [{stem}] 加载 OCR det_db_box_thresh={det_th} ...")
            engines[det_th] = _make_engine(det_th, model_dir)
        return engines[det_th]

    def _record(row: Dict[str, Any]) -> Optional[str]:
        # store the row; tag it (in place) and collect it as a hit when it matches target
        results.append(row)
        match = _match_hit(row.get("text", ""), target)
        if match:
            row["match"] = match
            hits.append(row)
        return match

    # Only the "threshold" method consumes a threshold value. Other methods run exactly
    # once with thresh=None, so they are not silently dropped when `thresholds` lacks
    # None (e.g. --quick uses [150, 155]).
    method_thresh_pairs: List[Tuple[str, Optional[int]]] = []
    for method in methods:
        if method == "threshold":
            method_thresh_pairs.extend((method, t) for t in thresholds)
        else:
            method_thresh_pairs.append((method, None))

    for (method, thresh), contrast_cfg, c_order, upscale, det_th in product(
        method_thresh_pairs, contrast_grid, contrast_orders, upscales, det_threshs
    ):
        engine = _engine_for(det_th, announce=True)
        img = _preprocess(
            raw,
            method=method,
            thresh=thresh,
            contrast_cfg=contrast_cfg,
            upscale=upscale,
            contrast_order=c_order,
        )
        c_tag = _contrast_tag(contrast_cfg)
        o_tag = "b" if c_order == "before_upscale" else "a"
        tag = f"{method}_t{thresh or 'd'}_{c_tag}_o{o_tag}_u{upscale}_det{det_th}"
        if save_images:
            cv2.imwrite(str(cell_out / f"{tag}.png"), img)
        for mode_name, det, rec in ocr_modes:
            total += 1
            ocr = _ocr(engine, img, det=det, rec=rec)
            row: Dict[str, Any] = {
                "tag": tag,
                "method": method,
                "threshold": thresh,
                "contrast_method": contrast_cfg.get("method", "none"),
                "contrast_order": c_order,
                "contrast_cfg": contrast_cfg,
                "upscale": upscale,
                "det_db_box_thresh": det_th,
                "ocr_mode": mode_name,
                **ocr,
            }
            m = _record(row)
            if m:
                print(
                    f" HIT [{m}] {mode_name} {tag} "
                    f"score={row.get('score')} -> {row.get('text')!r}"
                )

    if run_baseline:
        # The baseline image does not depend on the det threshold: build/save it once.
        base_tag = f"baseline_upscale{baseline_upscale}"
        base_img = _upscale(raw, baseline_upscale)
        if save_images:
            cv2.imwrite(str(cell_out / f"{base_tag}.png"), base_img)
        for det_th in det_threshs:
            engine = _engine_for(det_th, announce=False)
            for mode_name, det, rec in ocr_modes:
                ocr = _ocr(engine, base_img, det=det, rec=rec)
                _record(
                    {
                        "tag": base_tag,
                        "det_db_box_thresh": det_th,
                        "ocr_mode": mode_name,
                        **ocr,
                    }
                )

    report = {
        "input": str(resolved),
        "input_requested": str(input_path),
        "output_dir": str(cell_out),
        "target": target,
        "total_trials": total,
        "hits": hits,
        "all_results": results,
    }
    report_path = cell_out / "sweep_report.json"
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    # Conclusions report: ranked by OCR score, grouped comparisons.
    _print_conclusions(stem, results, target)
    return report
- def _print_conclusions(
- stem: str,
- results: List[Dict[str, Any]],
- target: Optional[str],
- ) -> None:
- """打印实验结论:按 OCR score 排序,分组展示最优组合。"""
- if not results:
- return
- print(f"\n{'='*70}")
- print(f" 实验结论: {stem}")
- if target:
- print(f" 目标文字: {target}")
- print(f"{'='*70}")
- # 取 det_rec 模式的结果(优先用检测+识别完整结果)
- dr_results = [r for r in results if r.get("ocr_mode") == "det_rec" and r.get("text")]
- if not dr_results:
- dr_results = [r for r in results if r.get("text")]
- if not dr_results:
- print(" (无有效 OCR 结果)")
- return
- # ── 1. 全局 Top-5 ──
- scored = sorted(dr_results, key=lambda r: -(r.get("score") or 0))
- print("\n 全局 OCR 得分 Top-5:")
- for i, r in enumerate(scored[:5], 1):
- print(f" {i}. score={r.get('score', 0):.4f} text={r.get('text', '')!r}")
- print(f" tag={r.get('tag', '')}")
- # ── 2. 按 contrast 方法分组最佳 ──
- print("\n 按 contrast 方法分组最优(score 最高):")
- groups: Dict[str, List[Dict[str, Any]]] = {}
- for r in scored:
- cm = r.get("contrast_method", "?")
- groups.setdefault(cm, []).append(r)
- for cm in sorted(groups.keys()):
- best = groups[cm][0]
- wm = best.get("method", "?")
- print(f" [{cm}] 最佳: score={best.get('score', 0):.4f} "
- f"wm={wm} upscale={best.get('upscale')} "
- f"text={best.get('text', '')!r}")
- # ── 3. 有 watermark 处理 vs 无 watermark 处理对比 ──
- print("\n 去水印开关对比(同 contrast 方法,最高 score):")
- wm_groups: Dict[str, Dict[str, Any]] = {}
- for r in scored:
- cm = r.get("contrast_method", "?")
- wm = r.get("method", "?") if r.get("method") != "none" else "无去水印"
- key = f"{cm}|{wm}"
- cur_score = r.get("score") or 0
- prev_score = (wm_groups.get(key) or {}).get("score") or 0
- if key not in wm_groups or cur_score > prev_score:
- wm_groups[key] = r
- for cm in sorted(set(r.get("contrast_method", "?") for r in scored)):
- wm_rows = [r for k, r in wm_groups.items() if k.startswith(cm + "|")]
- if wm_rows:
- best_row = max(wm_rows, key=lambda r: r.get("score") or 0)
- wm_label = "无去水印" if best_row.get("method") == "none" else best_row.get("method", "?")
- print(f" [{cm}] 最优: wm={wm_label} score={best_row.get('score', 0):.4f} "
- f"text={best_row.get('text', '')!r}")
- # ── 4. 放大顺序对比 ──
- print("\n 放大前/后对比(同方法,最高 score):")
- order_data: Dict[str, Dict[str, Any]] = {}
- for r in scored:
- cm = r.get("contrast_method", "?")
- co = r.get("contrast_order", "?")
- key = f"{cm}|{co}"
- cur_score = r.get("score") or 0
- prev_score = (order_data.get(key) or {}).get("score") or 0
- if key not in order_data or cur_score > prev_score:
- order_data[key] = r
- for cm in sorted(set(r.get("contrast_method", "?") for r in scored)):
- b_score = (order_data.get(f"{cm}|before_upscale") or {}).get("score") or 0
- a_score = (order_data.get(f"{cm}|after_upscale") or {}).get("score") or 0
- better = "放大前" if b_score > a_score else ("放大后" if a_score > b_score else "持平")
- if b_score or a_score:
- print(f" [{cm}] 放大前={b_score:.4f} 放大后={a_score:.4f} 更优: {better}")
- # ── 5. HIT 命中率统计 ──
- if target:
- hit_count = sum(1 for r in results if r.get("match"))
- hit_by_cm: Dict[str, int] = {}
- for r in results:
- if r.get("match"):
- cm = r.get("contrast_method", "?")
- hit_by_cm[cm] = hit_by_cm.get(cm, 0) + 1
- print(f"\n HIT 命中率 (target={target}): {hit_count}/{len(results)}")
- for cm in sorted(hit_by_cm.keys()):
- print(f" [{cm}] HIT={hit_by_cm[cm]}")
- print(f"{'='*70}\n")
- def _parse_best_config(tag: str) -> Dict[str, Any]:
- """解析最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5。
- tag 格式: {method}_t{thresh}_{c_tag}_o{b|a}_u{upscale}_det{det_th}
- """
- import re
- cfg: Dict[str, Any] = {}
- tag = tag.strip()
- # 解析 method: threshold | masked_adaptive | none
- m = re.match(r"(threshold|masked_adaptive|none)_t(\w+?)_(.+?)_o([ba])_u(\d+)_det([\d.]+)$", tag)
- if not m:
- raise ValueError(f"无法解析 best-config tag: {tag!r}")
- method, thresh_str, c_part, order_char, upscale, det_th = m.groups()
- cfg["method"] = method
- cfg["threshold"] = int(thresh_str) if thresh_str.isdigit() else None
- cfg["contrast_order"] = "before_upscale" if order_char == "b" else "after_upscale"
- cfg["upscale"] = int(upscale)
- cfg["det_db_box_thresh"] = float(det_th)
- # 解析 contrast 部分: cl_1.0_8 | tr_85 | gm_0.85 | ln_2.0_98.0 | c0
- if c_part == "c0":
- cfg["contrast_cfg"] = {"method": "none"}
- elif c_part.startswith("cl_"):
- parts = c_part.split("_")
- cfg["contrast_cfg"] = {"method": "clahe", "clip_limit": float(parts[1]), "tile_grid_size": int(parts[2])}
- elif c_part.startswith("tr_"):
- parts = c_part.split("_")
- cfg["contrast_cfg"] = {"method": "text_restore", "text_black_target": int(parts[1])}
- elif c_part.startswith("gm_"):
- parts = c_part.split("_")
- cfg["contrast_cfg"] = {"method": "gamma", "gamma": float(parts[1])}
- elif c_part.startswith("ln_"):
- parts = c_part.split("_")
- cfg["contrast_cfg"] = {"method": "linear", "black_percentile": float(parts[1]), "white_percentile": float(parts[2])}
- else:
- raise ValueError(f"无法解析 contrast tag: {c_part!r} (in {tag})")
- return cfg
def run_best_config(
    input_path: Path,
    out_dir: Path,
    *,
    prefer_raw: bool,
    best_cfg: Dict[str, Any],
    model_dir: Path,
    save_images: bool,
) -> Dict[str, Any]:
    """OCR a single image once, in det_rec mode, with one fixed parameter set.

    *best_cfg* is shaped like ``_parse_best_config`` output. The optional "_tag" key names
    the saved PNG and the result row, defaulting to "best". Writes
    ``out_dir/<stem>/best_result.json`` (and ``<tag>.png`` when *save_images*), where stem
    drops a trailing ``_raw``.

    Returns:
        {"input", "input_requested", "output_dir", "result"}, where result is the OCR row.

    Raises:
        RuntimeError: if the image cannot be read.
    """
    source = resolve_input_image(input_path, prefer_raw=prefer_raw)
    image = cv2.imread(str(source))
    if image is None:
        raise RuntimeError(f"无法读取图像: {source}")
    cell_dir = out_dir / source.stem.removesuffix("_raw")
    cell_dir.mkdir(parents=True, exist_ok=True)

    engine = _make_engine(best_cfg["det_db_box_thresh"], model_dir)
    method = best_cfg["method"]
    threshold = best_cfg.get("threshold")
    contrast_cfg = best_cfg["contrast_cfg"]
    upscale = best_cfg["upscale"]
    contrast_order = best_cfg["contrast_order"]
    processed = _preprocess(
        image,
        method=method,
        thresh=threshold,
        contrast_cfg=contrast_cfg,
        upscale=upscale,
        contrast_order=contrast_order,
    )

    tag = best_cfg.get("_tag", "best")
    if save_images:
        cv2.imwrite(str(cell_dir / f"{tag}.png"), processed)

    ocr = _ocr(engine, processed, det=True, rec=True)
    row: Dict[str, Any] = dict(
        tag=tag,
        method=method,
        threshold=threshold,
        contrast_method=contrast_cfg.get("method", "none"),
        contrast_order=contrast_order,
        contrast_cfg=contrast_cfg,
        upscale=upscale,
        det_db_box_thresh=best_cfg["det_db_box_thresh"],
        ocr_mode="det_rec",
    )
    row.update(ocr)

    report = {
        "input": str(source),
        "input_requested": str(input_path),
        "output_dir": str(cell_dir),
        "result": row,
    }
    (cell_dir / "best_result.json").write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    return report
- def _build_arg_parser() -> argparse.ArgumentParser:
- p = argparse.ArgumentParser(
- description="单元格图预处理 + OCR 参数网格扫描(对齐 pipeline 格级二次 OCR)",
- )
- p.add_argument(
- "input",
- type=Path,
- help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)",
- )
- p.add_argument(
- "-o",
- "--output",
- type=Path,
- default=None,
- help="输出目录,默认 <input_dir|input_parent>/sweep_out/<stem>",
- )
- p.add_argument(
- "-t",
- "--target",
- default=None,
- help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
- )
- p.add_argument(
- "--model-dir",
- type=Path,
- default=None,
- help="PaddleOCR torch 模型目录(含 det/rec .pth),也可用 OCR_*_MODEL_PATH",
- )
- p.add_argument(
- "--no-prefer-raw",
- action="store_true",
- help="不自动选用同名的 *_raw.png",
- )
- p.add_argument(
- "--quick",
- action="store_true",
- help="缩小网格(threshold 155,165 × upscale 128,192 × det 0.5 × contrast 精简)",
- )
- p.add_argument(
- "--methods",
- default="threshold,masked_adaptive,none",
- help="去水印方式,逗号分隔;none=不去水印",
- )
- p.add_argument(
- "--thresholds",
- default="155,165,none",
- help="threshold 法的阈值;none=预设默认",
- )
- p.add_argument(
- "--contrast-orders",
- default="before_upscale,after_upscale",
- help="contrast 执行顺序: before_upscale(放大前), after_upscale(放大后), 逗号组合",
- )
- p.add_argument(
- "--upscales",
- default="128,192",
- help="最短边放大目标,逗号分隔整数",
- )
- p.add_argument(
- "--det-threshs",
- # default="0.2,0.3,0.4,0.5",
- default="0.5",
- help="det_db_box_thresh,逗号分隔",
- )
- p.add_argument(
- "--no-save-images",
- action="store_true",
- help="不写出中间预处理 png(仅报告)",
- )
- p.add_argument(
- "--no-baseline",
- action="store_true",
- help="跳过「仅放大、不去水印」对照组",
- )
- p.add_argument(
- "--baseline-upscale",
- type=int,
- default=192,
- help="baseline 对照组的最短边放大",
- )
- p.add_argument(
- "--best-only",
- action="store_true",
- help="不跑参数网格,对目录下所有图用 --best-config 指定参数跑一次,验证适配性",
- )
- p.add_argument(
- "--best-config",
- default="threshold_t150_cl_1.0_8_ob_u128_det0.5",
- help="最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5",
- )
- return p
def _run_best_only(
    args: argparse.Namespace, inputs: List[Path], out_root: Path, model_dir: Path
) -> None:
    """--best-only mode: OCR every input once with --best-config; write best_summary.json."""
    best_cfg = _parse_best_config(args.best_config)
    best_cfg["_tag"] = args.best_config
    print(f"最佳参数验证模式: {args.best_config}")
    print(f" 解析: method={best_cfg['method']} contrast={best_cfg['contrast_cfg'].get('method')} "
          f"upscale={best_cfg['upscale']} order={best_cfg['contrast_order']}")
    print(f" 共 {len(inputs)} 张图")
    all_texts: List[Dict[str, Any]] = []
    hit_count = 0
    for img_path in inputs:
        report = run_best_config(
            img_path, out_root,
            prefer_raw=not args.no_prefer_raw,
            best_cfg=best_cfg,
            model_dir=model_dir,
            save_images=not args.no_save_images,
        )
        result = report["result"]
        text = result.get("text", "")
        score = result.get("score", 0)
        all_texts.append({
            "input": img_path.name,
            "text": text,
            "score": score,
            "report": str(Path(report["output_dir"]) / "best_result.json"),
        })
        m = _match_hit(text, args.target)
        hit_info = f" [HIT: {m}]" if m else ""
        print(f" {img_path.name}: score={score:.4f} text={text!r}{hit_info}")
        if m:
            hit_count += 1
    # Summary across all inputs.
    summary_path = out_root / "best_summary.json"
    summary_data = {
        "best_config": args.best_config,
        "total": len(all_texts),
        "hits": hit_count,
        "target": args.target,
        "results": all_texts,
    }
    summary_path.write_text(json.dumps(summary_data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\n汇总: {hit_count}/{len(all_texts)} HIT -> {summary_path}")


def _run_grid(
    args: argparse.Namespace, inputs: List[Path], out_root: Path, model_dir: Path
) -> None:
    """Grid-sweep mode: run_sweep on every input; write sweep_index.json into *out_root*.

    Raises:
        SystemExit: if --contrast-orders contains an unsupported value. Such a value would
            otherwise silently disable the contrast step in _preprocess.
    """
    methods = [m.strip() for m in args.methods.split(",") if m.strip()]
    contrast_orders = [o.strip() for o in args.contrast_orders.split(",") if o.strip()]
    invalid_orders = [o for o in contrast_orders if o not in ("before_upscale", "after_upscale")]
    if invalid_orders:
        raise SystemExit(
            f"无效的 --contrast-orders: {invalid_orders}(可选: before_upscale, after_upscale)"
        )
    thresholds: List[Optional[int]]
    if args.quick:
        # --quick deliberately overrides --thresholds / --upscales / --det-threshs.
        thresholds = [150, 155]
        upscales = [128, 192]
        det_threshs = [0.5]
    else:
        thresholds = _parse_csv_ints(args.thresholds)
        upscales = [int(x) for x in args.upscales.split(",") if x.strip()]
        det_threshs = _parse_csv_floats(args.det_threshs)
    contrast_grid = _build_contrast_grid(quick=args.quick)

    print(f"扫描 {len(inputs)} 张图 -> {out_root}")
    print(f" methods={methods} thresholds={thresholds} upscales={upscales}")
    print(f" contrast_methods={len(contrast_grid)} orders={contrast_orders}")
    if args.target:
        print(f" target={args.target!r}")

    summary: List[Dict[str, Any]] = []
    for img_path in inputs:
        print(f"\n=== {img_path.name} ===")
        report = run_sweep(
            img_path,
            out_root,
            prefer_raw=not args.no_prefer_raw,
            target=args.target,
            model_dir=model_dir,
            methods=methods,
            thresholds=thresholds,
            contrast_grid=contrast_grid,
            contrast_orders=contrast_orders,
            upscales=upscales,
            det_threshs=det_threshs,
            save_images=not args.no_save_images,
            run_baseline=not args.no_baseline,
            baseline_upscale=args.baseline_upscale,
        )
        summary.append(
            {
                "input": report["input"],
                "hits": len(report["hits"]),
                "report": str(Path(report["output_dir"]) / "sweep_report.json"),
            }
        )
    index_path = out_root / "sweep_index.json"
    index_path.write_text(
        json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(f"\n全部完成,索引: {index_path}")
    for s in summary:
        print(f" {s['input']}: {s['hits']} hits -> {s['report']}")


def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: collect inputs, pick the output root, then run one of the modes.

    The output root is --output, or ``<input parent>/sweep_out`` for a file input, or
    ``<input dir>/sweep_out`` for a directory input.

    Raises:
        SystemExit: if no inputs are found, or (grid mode) --contrast-orders is invalid.
    """
    args = _build_arg_parser().parse_args(argv)
    inputs = collect_inputs(args.input, prefer_raw=not args.no_prefer_raw)
    if not inputs:
        raise SystemExit("未找到可扫描的图像")
    if args.output is not None:
        out_root = args.output
    elif args.input.is_file():
        out_root = args.input.parent / "sweep_out"
    else:
        out_root = args.input / "sweep_out"
    out_root.mkdir(parents=True, exist_ok=True)
    model_dir = args.model_dir or _default_model_dir()
    if args.best_only:
        # Fitness check: run every image once with the chosen best parameters.
        _run_best_only(args, inputs, out_root, model_dir)
    else:
        # Regular parameter-grid sweep.
        _run_grid(args, inputs, out_root, model_dir)
- if __name__ == "__main__":
- if len(sys.argv) == 1:
- print("ℹ️ 未提供命令行参数,使用默认配置运行...")
- default_config = {
- # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty_raw.png",
- # "output": "./output/彭_广东兴宁农村商业银行/cell219_sweep",
- # "target": "ATM存折取款",
- # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell007_whole_longer_易型交类_raw.png",
- # "output": "./output/彭_广东兴宁农村商业银行/cell007_sweep",
- # "target": "交易类型",
- # "quick": True,
- # "input": "/Users/zhch158/workspace/data/流水分析/钟_广东陆丰农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/钟_广东陆丰农村商业银行_page_001_0/cell217_empty_empty_raw.png",
- # "output": "./output/钟_广东陆丰农村商业银行/cell217_sweep",
- # "target": "专项资金",
- # "quick": True,
- # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0",
- # "output": "./output/彭_广东兴宁农村商业银行",
- # "best-config": "threshold_t150_cl_1.0_8_ob_u128_det0.5",
- # "best-only": True,
- "input": "/Users/zhch158/workspace/data/流水分析/钟_广东陆丰农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/钟_广东陆丰农村商业银行_page_001_0",
- "output": "./output/钟_广东陆丰农村商业银行",
- # "best-config": "threshold_t150_cl_1.0_8_ob_u128_det0.5",
- "best-config": "threshold_t150_cl_1.0_4_ob_u128_det0.5",
- "best-only": True,
- }
- sys.argv = [sys.argv[0], default_config["input"]]
- for key, value in default_config.items():
- if key == "input":
- continue
- flag = f"--{key.replace('_', '-')}"
- if isinstance(value, bool) and value:
- sys.argv.append(flag)
- elif not isinstance(value, bool):
- sys.argv.extend([flag, str(value)])
- sys.exit(main())
|