| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554 |
- #!/usr/bin/env python3
- """
- 单元格裁剪图预处理参数扫描:去水印 / threshold / contrast / upscale / det 阈值 / OCR 模式。
- 默认从 **原图**(`*_raw.png`)出发,与 pipeline 二次 OCR 一致,避免对已预处理 debug 图二次去水印。
- 用法:
- python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out
- python cell_sweep.py cell.png --quick --no-save-images
- OCR_DET_MODEL_PATH=... OCR_REC_MODEL_PATH=... python cell_sweep.py cell.png
- """
- from __future__ import annotations
- import argparse
- import json
- import os
- import sys
- from itertools import product
- from pathlib import Path
- from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
- import cv2
- import numpy as np
- _repo_root = Path(__file__).resolve().parents[2]
- if str(_repo_root) not in sys.path:
- sys.path.insert(0, str(_repo_root))
- from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
- from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
# File extensions accepted as input images; callers compare against
# ``Path.suffix.lower()``.
_IMAGE_SUFFIXES = {
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".tif",
    ".tiff",
    ".webp",
}

# Fallback model directory, used when neither --model-dir nor the
# OCR_DET_MODEL_PATH environment variable points somewhere else.
_DEFAULT_MODEL_DIR = Path(
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)
def _parse_csv_ints(s: str) -> List[Optional[int]]:
    """Parse a comma-separated list of integers, allowing "use default" markers.

    Each item is stripped of surrounding whitespace. An empty item, or one of
    ``none`` / ``d`` / ``default`` (case-insensitive), becomes ``None``; every
    other item is converted with ``int`` and raises ``ValueError`` if it is not
    a valid integer.
    """
    default_markers = {"", "none", "d", "default"}
    tokens = (piece.strip() for piece in s.split(","))
    return [
        None if token.lower() in default_markers else int(token)
        for token in tokens
    ]
def _parse_csv_floats(s: str) -> List[float]:
    """Parse a comma-separated list of floats, ignoring blank items."""
    values: List[float] = []
    for chunk in s.split(","):
        token = chunk.strip()
        if token:
            values.append(float(token))
    return values
def _parse_csv_bools(s: str) -> List[bool]:
    """Parse a comma-separated list of boolean words.

    Accepted spellings (case-insensitive, surrounding whitespace ignored) are
    ``1/true/yes/on`` and ``0/false/no/off``. Anything else, including an
    empty item, raises ``ValueError``.
    """
    spellings = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    flags: List[bool] = []
    for part in s.split(","):
        key = part.strip().lower()
        if key not in spellings:
            # The message shows the item exactly as the user typed it.
            raise ValueError(f"无效的 bool 值: {part!r}")
        flags.append(spellings[key])
    return flags
def _default_model_dir() -> Path:
    """Return the model directory to use when none is given on the command line.

    If ``OCR_DET_MODEL_PATH`` is set and non-empty, the directory containing
    that file is returned; otherwise the module-level ``_DEFAULT_MODEL_DIR``.
    """
    det_model = os.environ.get("OCR_DET_MODEL_PATH")
    return Path(det_model).parent if det_model else _DEFAULT_MODEL_DIR
def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
    """Enlarge *img* so that both sides are at least *min_side* pixels.

    The image is returned unchanged (same object) when it is already large
    enough. Otherwise it is resized with one uniform scale factor using cubic
    interpolation, so the aspect ratio is kept.
    """
    height, width = img.shape[:2]
    if min(height, width) >= min_side:
        return img
    # max(..., 1) guards the division against a zero-sized dimension.
    scale = max(min_side / max(height, 1), min_side / max(width, 1), 1.0)
    return cv2.resize(
        img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC
    )
def _preprocess(
    raw: np.ndarray,
    *,
    method: str,
    thresh: Optional[int],
    contrast: bool,
    upscale: int,
    text_black_target: int,
) -> np.ndarray:
    """Apply watermark removal, optional contrast enhancement and upscaling.

    Args:
        raw: Source cell image (passed to ``cv2.cvtColor`` as BGR when
            *contrast* is set).
        method: Watermark-removal method name placed in the ``cell`` config.
        thresh: Explicit threshold; only forwarded when *method* is
            ``"threshold"`` and the value is not ``None``.
        contrast: Whether to run the contrast-enhancement step.
        upscale: Minimum side length passed to ``_upscale``.
        text_black_target: Value written to the contrast config's
            ``text_black_target`` key.

    Returns:
        The processed image after ``_upscale``.
    """
    overrides: Dict[str, Any] = {"enabled": True, "method": method}
    if thresh is not None and method == "threshold":
        overrides["threshold"] = thresh

    cfg = merge_watermark_config("cell", overrides)
    processor = WatermarkProcessor(cfg, scope="cell")
    cleaned, _ = processor.process(raw, force=True)

    if not contrast:
        return _upscale(cleaned, upscale)

    # Start from the merged config's contrast section and force it on.
    contrast_cfg = dict(cfg.get("contrast_enhancement") or {})
    contrast_cfg.update(enabled=True, text_black_target=text_black_target)

    # The enhancement step works on a grayscale image; convert back afterwards.
    gray = cv2.cvtColor(cleaned, cv2.COLOR_BGR2GRAY)
    enhanced = apply_contrast_enhancement_config(gray, contrast_cfg)
    return _upscale(cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR), upscale)
def _parse_rec_pair(rec_part: Any) -> Tuple[str, float]:
    """Extract ``(text, score)`` from one OCR recognition result.

    Accepted shapes are a list/tuple ``(text, score, ...)`` or a one-element
    list/tuple ``(text,)``. Anything else, or a pair whose first element is
    itself a list/tuple/dict, yields ``("", 0.0)``. The score is forced to
    ``0.0`` when the stripped text is empty or the score cannot be converted
    to ``float``.
    """
    if not isinstance(rec_part, (list, tuple)):
        return "", 0.0

    size = len(rec_part)
    if size == 0:
        return "", 0.0

    head = rec_part[0]
    if size == 1:
        return str(head or "").strip(), 0.0

    # A nested container in first position is not a (text, score) pair.
    if isinstance(head, (list, tuple, dict)):
        return "", 0.0

    text = str(head or "").strip()
    try:
        score = float(rec_part[1] or 0.0)
    except (TypeError, ValueError):
        score = 0.0
    if not text:
        score = 0.0
    return text, score
def _aggregate_rec_score(boxes: List[Dict[str, Any]]) -> float:
    """Return the recognition score averaged over boxes, weighted by text length.

    Missing or falsy ``text`` counts as length 0 and missing or falsy ``score``
    as 0.0. Returns 0.0 when no box contributes any characters.
    """
    lengths = [len(box.get("text") or "") for box in boxes]
    char_total = sum(lengths)
    if char_total <= 0:
        return 0.0
    weighted = sum(
        length * float(box.get("score") or 0.0)
        for length, box in zip(lengths, boxes)
    )
    return weighted / char_total
def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
    """Run ``engine.ocr`` on *img* and normalise the result into a flat dict.

    The returned dict always has the keys ``text``, ``score``, ``boxes``,
    ``det``, ``rec`` and ``n_boxes``. Any exception raised while calling the
    engine or parsing its output is caught; the dict then holds empty results
    plus an ``error`` key with the exception message.

    With ``det=True`` each item is read as ``[bbox, rec_part]`` and the box is
    kept under ``det_bbox``. With ``det=False`` each item is read as a
    recognition pair, with one level of nesting tolerated.
    """
    try:
        res = engine.ocr(img, det=det, rec=rec)
        page = res[0] if res and res[0] is not None else []

        parsed: List[Dict[str, Any]] = []
        for item in page:
            if det:
                if not item or len(item) < 2:
                    continue
                text, score = _parse_rec_pair(item[1])
                bbox = item[0]
                if hasattr(bbox, "tolist"):
                    # e.g. numpy arrays: make the box JSON-serialisable.
                    bbox = bbox.tolist()
                box: Dict[str, Any] = {"text": text, "score": round(score, 6)}
                if bbox is not None:
                    box["det_bbox"] = bbox
            else:
                text, score = _parse_rec_pair(item)
                if not text and isinstance(item, (list, tuple)) and len(item) >= 1:
                    # Retry one level down for nested results.
                    text, score = _parse_rec_pair(item[0])
                box = {"text": text, "score": round(score, 6)}
            parsed.append(box)

        joined = "".join(b["text"] for b in parsed if b.get("text")).strip()
        return {
            "text": joined,
            "score": round(_aggregate_rec_score(parsed), 6),
            "boxes": parsed,
            "det": det,
            "rec": rec,
            "n_boxes": len(parsed),
        }
    except Exception as exc:
        # Best effort: one failing trial must not abort the whole sweep.
        return {
            "text": "",
            "score": 0.0,
            "boxes": [],
            "det": det,
            "rec": rec,
            "n_boxes": 0,
            "error": str(exc),
        }
def _make_engine(det_thresh: float, model_dir: Path) -> Any:
    """Build a ``PytorchPaddleOCR`` engine with the given detection box threshold.

    Model files come from ``OCR_DET_MODEL_PATH`` / ``OCR_REC_MODEL_PATH`` when
    those environment variables are set and non-empty; otherwise the default
    file names inside *model_dir* are used.
    """
    # Imported here so the OCR package is only needed once an engine is built.
    from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR

    def _weights(env_var: str, filename: str) -> str:
        return os.environ.get(env_var) or str(model_dir / filename)

    return PytorchPaddleOCR(
        lang="ch",
        det_model_path=_weights("OCR_DET_MODEL_PATH", "ch_PP-OCRv5_det_infer.pth"),
        rec_model_path=_weights(
            "OCR_REC_MODEL_PATH", "ch_PP-OCRv4_rec_server_doc_infer.pth"
        ),
        det_db_box_thresh=det_thresh,
    )
def resolve_input_image(path: Path, *, prefer_raw: bool) -> Path:
    """Return the ``*_raw`` sibling of *path* when it exists and is preferred.

    *path* is returned unchanged when *prefer_raw* is false, when its stem
    already ends with ``_raw``, or when no ``<stem>_raw<suffix>`` file exists
    next to it. A short notice is printed when the sibling is substituted.
    """
    if prefer_raw and not path.stem.endswith("_raw"):
        candidate = path.parent / f"{path.stem}_raw{path.suffix}"
        if candidate.is_file():
            print(f" 使用原图: {candidate.name}(跳过 {path.name})")
            return candidate
    return path
def collect_inputs(path: Path, *, prefer_raw: bool) -> List[Path]:
    """Expand *path* into the list of images to sweep.

    A single file must have a supported image suffix and is passed through
    ``resolve_input_image``. For a directory, its direct children with a
    supported suffix are considered, sorted by path:

    * with *prefer_raw* and at least one ``*_raw`` image present, only the
      ``*_raw`` images are returned;
    * otherwise the non-raw images are returned, falling back to all images
      if that selection is empty.

    Raises:
        ValueError: *path* is a file with an unsupported suffix.
        FileNotFoundError: *path* does not exist, or the directory has no
            supported images.
    """

    def is_image(p: Path) -> bool:
        return p.suffix.lower() in _IMAGE_SUFFIXES

    def is_raw(p: Path) -> bool:
        return p.stem.endswith("_raw")

    if path.is_file():
        if not is_image(path):
            raise ValueError(f"不支持的图像格式: {path}")
        return [resolve_input_image(path, prefer_raw=prefer_raw)]

    if not path.is_dir():
        raise FileNotFoundError(path)

    images = sorted(p for p in path.iterdir() if p.is_file() and is_image(p))
    if not images:
        raise FileNotFoundError(f"目录内无图像: {path}")

    if prefer_raw:
        raw_images = [p for p in images if is_raw(p)]
        if raw_images:
            return raw_images

    # With prefer_raw set we only reach this point when the listing held no
    # *_raw image, so the sibling check below is just a safety net.
    non_raw = [
        p
        for p in images
        if not is_raw(p)
        and not (prefer_raw and (p.parent / f"{p.stem}_raw{p.suffix}").is_file())
    ]
    return non_raw or images
def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
    """Classify an OCR result against the expected text.

    Returns:
        ``None`` when *text* is empty or nothing matches;
        ``"nonempty"`` when no *target* was given and *text* is non-empty;
        ``"full"`` when *target* is a substring of *text*;
        ``"partial"`` when both *target* and *text* are all-digit strings of
        at least six characters.
    """
    if not text:
        return None
    if not target:
        return "nonempty"
    if target in text:
        return "full"
    both_long_digits = all(
        len(value) >= 6 and value.isdigit() for value in (target, text)
    )
    return "partial" if both_long_digits else None
def run_sweep(
    input_path: Path,
    out_dir: Path,
    *,
    prefer_raw: bool,
    target: Optional[str],
    model_dir: Path,
    methods: Sequence[str],
    thresholds: Sequence[Optional[int]],
    contrasts: Sequence[bool],
    upscales: Sequence[int],
    det_threshs: Sequence[float],
    text_black_target: int,
    save_images: bool,
    run_baseline: bool,
    baseline_upscale: int,
) -> Dict[str, Any]:
    """Run the preprocessing x OCR parameter grid for one cell image.

    Every combination of method, threshold, contrast, upscale and detection
    threshold is preprocessed once and OCR'd in two modes (``det_rec`` and
    ``whole_rec``). Results are written to ``<out_dir>/<stem>/sweep_report.json``
    and also returned.

    Args:
        input_path: Image to sweep; may be swapped for its ``*_raw`` sibling.
        out_dir: Root output directory; a ``<stem>`` subdirectory is created.
        prefer_raw: Forwarded to ``resolve_input_image``.
        target: Expected OCR text used to flag hits; ``None`` flags any
            non-empty result.
        model_dir: Directory holding the OCR model files.
        methods: Watermark-removal methods to try.
        thresholds: Threshold values for the ``"threshold"`` method
            (``None`` means "use the preset default"). Other methods ignore
            this list and are run exactly once per remaining combination.
        contrasts: Contrast-enhancement on/off values to try.
        upscales: Minimum-side upscale targets to try.
        det_threshs: ``det_db_box_thresh`` values; one engine is built per value.
        text_black_target: Forwarded to ``_preprocess``.
        save_images: Whether to write each preprocessed image as PNG.
        run_baseline: Whether to add an "upscale only" control group.
        baseline_upscale: Minimum side used for the control group.

    Returns:
        The report dict. ``total_trials`` counts sweep trials only; baseline
        rows appear in ``all_results`` and ``hits`` but are not counted.

    Raises:
        RuntimeError: The image could not be read.
    """
    resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
    raw = cv2.imread(str(resolved))
    if raw is None:
        raise RuntimeError(f"无法读取图像: {resolved}")

    stem = resolved.stem.removesuffix("_raw")
    cell_out = out_dir / stem
    cell_out.mkdir(parents=True, exist_ok=True)

    ocr_modes: List[Tuple[str, bool, bool]] = [
        ("det_rec", True, True),
        ("whole_rec", False, True),
    ]
    results: List[Dict[str, Any]] = []
    hits: List[Dict[str, Any]] = []
    engines: Dict[float, Any] = {}

    def engine_for(det_th: float) -> Any:
        """Return the engine for *det_th*, building it on first use."""
        if det_th not in engines:
            print(f" [{stem}] 加载 OCR det_db_box_thresh={det_th} ...")
            engines[det_th] = _make_engine(det_th, model_dir)
        return engines[det_th]

    def record(row: Dict[str, Any]) -> None:
        """Store one result row and report it when it matches the target."""
        results.append(row)
        match = _match_hit(row.get("text", ""), target)
        if match:
            row["match"] = match
            hits.append(row)
            print(
                f" HIT [{match}] {row['ocr_mode']} {row['tag']} "
                f"score={row.get('score')} -> {row.get('text')!r}"
            )

    total = 0
    for method in methods:
        # Only the "threshold" method consumes a threshold value. Other
        # methods get a single pass with the default; previously they were
        # skipped whenever the threshold list had no None entry.
        method_thresholds: Sequence[Optional[int]] = (
            thresholds if method == "threshold" else (None,)
        )
        for thresh, contrast, upscale in product(
            method_thresholds, contrasts, upscales
        ):
            # Preprocessing does not depend on the detection threshold, so it
            # is done once per combination rather than once per det_th.
            img = _preprocess(
                raw,
                method=method,
                thresh=thresh,
                contrast=contrast,
                upscale=upscale,
                text_black_target=text_black_target,
            )
            # "is None" keeps an explicit threshold of 0 distinct from default.
            thresh_tag = "d" if thresh is None else thresh
            for det_th in det_threshs:
                engine = engine_for(det_th)
                tag = (
                    f"{method}_t{thresh_tag}_c{int(contrast)}"
                    f"_u{upscale}_det{det_th}"
                )
                if save_images:
                    cv2.imwrite(str(cell_out / f"{tag}.png"), img)
                for mode_name, det, rec in ocr_modes:
                    total += 1
                    record(
                        {
                            "tag": tag,
                            "method": method,
                            "threshold": thresh,
                            "contrast": contrast,
                            "upscale": upscale,
                            "det_db_box_thresh": det_th,
                            "ocr_mode": mode_name,
                            **_ocr(engine, img, det=det, rec=rec),
                        }
                    )

    if run_baseline and det_threshs:
        # Control group: upscale only, no watermark removal. The image is the
        # same for every det_th, so build and save it once.
        base_tag = f"baseline_upscale{baseline_upscale}"
        base_img = _upscale(raw, baseline_upscale)
        if save_images:
            cv2.imwrite(str(cell_out / f"{base_tag}.png"), base_img)
        for det_th in det_threshs:
            engine = engine_for(det_th)
            for mode_name, det, rec in ocr_modes:
                record(
                    {
                        "tag": base_tag,
                        "det_db_box_thresh": det_th,
                        "ocr_mode": mode_name,
                        **_ocr(engine, base_img, det=det, rec=rec),
                    }
                )

    report = {
        "input": str(resolved),
        "input_requested": str(input_path),
        "output_dir": str(cell_out),
        "target": target,
        "total_trials": total,
        "hits": hits,
        "all_results": results,
    }
    report_path = cell_out / "sweep_report.json"
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    return report
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the sweep tool.

    Arguments are declared in one table and registered in order, so the order
    shown by ``--help`` follows the table.
    """
    parser = argparse.ArgumentParser(
        description="单元格图预处理 + OCR 参数网格扫描(对齐 pipeline 格级二次 OCR)",
    )
    # (flags, keyword arguments for add_argument)
    specs: List[Tuple[Tuple[str, ...], Dict[str, Any]]] = [
        (
            ("input",),
            {"type": Path, "help": "单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)"},
        ),
        (
            ("-o", "--output"),
            {
                "type": Path,
                "default": None,
                "help": "输出目录,默认 <input_dir|input_parent>/sweep_out/<stem>",
            },
        ),
        (
            ("-t", "--target"),
            {
                "default": None,
                "help": "期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
            },
        ),
        (
            ("--model-dir",),
            {
                "type": Path,
                "default": None,
                "help": "PaddleOCR torch 模型目录(含 det/rec .pth),也可用 OCR_*_MODEL_PATH",
            },
        ),
        (
            ("--no-prefer-raw",),
            {"action": "store_true", "help": "不自动选用同名的 *_raw.png"},
        ),
        (
            ("--quick",),
            {
                "action": "store_true",
                "help": "缩小网格(threshold 170,175 × upscale 128,192 × det 0.3,0.5)",
            },
        ),
        (
            ("--methods",),
            {"default": "threshold,masked_adaptive", "help": "去水印方式,逗号分隔"},
        ),
        (
            ("--thresholds",),
            {
                "default": "155,165,170,175,180,none",
                "help": "threshold 法的阈值;none=预设默认",
            },
        ),
        (
            ("--contrasts",),
            {"default": "false,true", "help": "是否 contrast,逗号分隔 false,true"},
        ),
        (
            ("--upscales",),
            {"default": "64,96,128,192", "help": "最短边放大目标,逗号分隔整数"},
        ),
        (
            ("--det-threshs",),
            {"default": "0.2,0.3,0.4,0.5", "help": "det_db_box_thresh,逗号分隔"},
        ),
        (
            ("--text-black-target",),
            {"type": int, "default": 88, "help": "contrast text_restore 目标黑度"},
        ),
        (
            ("--no-save-images",),
            {"action": "store_true", "help": "不写出中间预处理 png(仅报告)"},
        ),
        (
            ("--no-baseline",),
            {"action": "store_true", "help": "跳过「仅放大、不去水印」对照组"},
        ),
        (
            ("--baseline-upscale",),
            {"type": int, "default": 128, "help": "baseline 对照组的最短边放大"},
        ),
    ]
    for flags, options in specs:
        parser.add_argument(*flags, **options)
    return parser
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: sweep every selected image and write an index.

    Parses *argv* (``sys.argv[1:]`` when ``None``), resolves the input images,
    runs ``run_sweep`` on each, and writes ``sweep_index.json`` with one
    summary entry per image into the output root.
    """
    args = _build_arg_parser().parse_args(argv)
    prefer_raw = not args.no_prefer_raw

    inputs = collect_inputs(args.input, prefer_raw=prefer_raw)
    if not inputs:
        raise SystemExit("未找到可扫描的图像")

    # Default output root sits next to the input file, or inside the input dir.
    out_root = args.output
    if out_root is None:
        base_dir = args.input.parent if args.input.is_file() else args.input
        out_root = base_dir / "sweep_out"
    out_root.mkdir(parents=True, exist_ok=True)

    model_dir = args.model_dir or _default_model_dir()
    methods = [name.strip() for name in args.methods.split(",") if name.strip()]

    if args.quick:
        # Reduced fixed grid; the corresponding CLI lists are ignored.
        thresholds = [170, 175]
        upscales = [128, 192]
        det_threshs = [0.3, 0.5]
        contrasts = [False, True]
    else:
        thresholds = _parse_csv_ints(args.thresholds)
        upscales = [int(item) for item in args.upscales.split(",") if item.strip()]
        det_threshs = _parse_csv_floats(args.det_threshs)
        contrasts = _parse_csv_bools(args.contrasts)

    print(f"扫描 {len(inputs)} 张图 -> {out_root}")
    print(f" methods={methods} thresholds={thresholds} upscales={upscales}")
    if args.target:
        print(f" target={args.target!r}")

    summary: List[Dict[str, Any]] = []
    for img_path in inputs:
        print(f"\n=== {img_path.name} ===")
        report = run_sweep(
            img_path,
            out_root,
            prefer_raw=prefer_raw,
            target=args.target,
            model_dir=model_dir,
            methods=methods,
            thresholds=thresholds,
            contrasts=contrasts,
            upscales=upscales,
            det_threshs=det_threshs,
            text_black_target=args.text_black_target,
            save_images=not args.no_save_images,
            run_baseline=not args.no_baseline,
            baseline_upscale=args.baseline_upscale,
        )
        report_file = Path(report["output_dir"]) / "sweep_report.json"
        summary.append(
            {
                "input": report["input"],
                "hits": len(report["hits"]),
                "report": str(report_file),
            }
        )

    index_path = out_root / "sweep_index.json"
    index_path.write_text(
        json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
    )

    print(f"\n全部完成,索引: {index_path}")
    for entry in summary:
        print(f" {entry['input']}: {entry['hits']} hits -> {entry['report']}")
if __name__ == "__main__":
    if len(sys.argv) == 1:
        # No CLI arguments: fall back to a built-in sample configuration and
        # translate it into an equivalent argument vector.
        print("ℹ️ 未提供命令行参数,使用默认配置运行...")
        default_config = {
            "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty_raw.png",
            "output": "./output/彭_广东兴宁农村商业银行/cell219_sweep",
            "target": "ATM存折取款",
        }
        cli_args = [sys.argv[0], default_config["input"]]
        for key, value in default_config.items():
            if key == "input":
                continue
            flag = f"--{key.replace('_', '-')}"
            if isinstance(value, bool):
                # Boolean entries map to bare switches, emitted only when true.
                if value:
                    cli_args.append(flag)
            else:
                cli_args.extend([flag, str(value)])
        sys.argv = cli_args
    sys.exit(main())
|