|
@@ -0,0 +1,554 @@
|
|
|
|
|
+#!/usr/bin/env python3
|
|
|
|
|
+"""
|
|
|
|
|
+单元格裁剪图预处理参数扫描:去水印 / threshold / contrast / upscale / det 阈值 / OCR 模式。
|
|
|
|
|
+
|
|
|
|
|
+默认从 **原图**(`*_raw.png`)出发,与 pipeline 二次 OCR 一致,避免对已预处理 debug 图二次去水印。
|
|
|
|
|
+
|
|
|
|
|
+用法:
|
|
|
|
|
+ python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
|
|
|
|
|
+ python cell_sweep.py /path/to/tablecell_ocr/ -o ./out
|
|
|
|
|
+ python cell_sweep.py cell.png --quick --no-save-images
|
|
|
|
|
+ OCR_DET_MODEL_PATH=... OCR_REC_MODEL_PATH=... python cell_sweep.py cell.png
|
|
|
|
|
+"""
|
|
|
|
|
+from __future__ import annotations
|
|
|
|
|
+
|
|
|
|
|
+import argparse
|
|
|
|
|
+import json
|
|
|
|
|
+import os
|
|
|
|
|
+import sys
|
|
|
|
|
+from itertools import product
|
|
|
|
|
+from pathlib import Path
|
|
|
|
|
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple
|
|
|
|
|
+
|
|
|
|
|
+import cv2
|
|
|
|
|
+import numpy as np
|
|
|
|
|
+
|
|
|
|
|
# Put the directory two levels above this file's folder on sys.path before the
# project imports that follow.
# NOTE(review): presumably that directory is the repository root containing the
# `ocr_utils` package - confirm against the repo layout. `parents[2]` raises
# IndexError if the file lives fewer than three levels below the filesystem root.
_repo_root = Path(__file__).resolve().parents[2]
if str(_repo_root) not in sys.path:
    sys.path.insert(0, str(_repo_root))
|
|
|
|
|
+
|
|
|
|
|
+from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
|
|
|
|
|
+from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
|
|
|
|
|
+
|
|
|
|
|
# Lower-case file extensions (with the leading dot) accepted as input images.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
# Fallback OCR model directory, used when neither --model-dir nor the
# OCR_DET_MODEL_PATH environment variable is supplied.
# NOTE(review): this is a machine-specific absolute path; on any other machine
# the model location must be provided explicitly.
_DEFAULT_MODEL_DIR = Path(
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _parse_csv_ints(s: str) -> List[Optional[int]]:
    """Parse a comma-separated list of integers that may contain "default" markers.

    Blank entries and the case-insensitive tokens ``none``, ``d`` and
    ``default`` are returned as ``None``; every other entry is converted
    with ``int``.

    Raises:
        ValueError: If an entry is neither a marker nor a valid integer.
    """
    default_markers = ("none", "d", "default")
    tokens = (piece.strip() for piece in s.split(","))
    return [
        None if not token or token.lower() in default_markers else int(token)
        for token in tokens
    ]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _parse_csv_floats(s: str) -> List[float]:
    """Parse a comma-separated list of floats, ignoring blank entries.

    Raises:
        ValueError: If a non-blank entry is not a valid float.
    """
    values: List[float] = []
    for piece in s.split(","):
        token = piece.strip()
        if token:
            values.append(float(token))
    return values
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _parse_csv_bools(s: str) -> List[bool]:
    """Parse a comma-separated list of boolean words.

    Accepted spellings (case-insensitive, surrounding whitespace ignored) are
    ``1/true/yes/on`` for True and ``0/false/no/off`` for False.

    Raises:
        ValueError: For any other entry, including a blank one; the message
            quotes the entry as written (before stripping).
    """
    spellings = {
        "1": True,
        "true": True,
        "yes": True,
        "on": True,
        "0": False,
        "false": False,
        "no": False,
        "off": False,
    }
    flags: List[bool] = []
    for part in s.split(","):
        key = part.strip().lower()
        if key not in spellings:
            raise ValueError(f"无效的 bool 值: {part!r}")
        flags.append(spellings[key])
    return flags
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _default_model_dir() -> Path:
    """Return the model directory to use when --model-dir is not given.

    When the ``OCR_DET_MODEL_PATH`` environment variable is set to a non-empty
    value, the directory containing that file is returned; otherwise the
    module-level ``_DEFAULT_MODEL_DIR`` is used.
    """
    det_model = os.environ.get("OCR_DET_MODEL_PATH")
    return Path(det_model).parent if det_model else _DEFAULT_MODEL_DIR
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
    """Enlarge ``img`` so that both its height and width reach ``min_side``.

    The input is returned unchanged (same object) when both dimensions are
    already at least ``min_side``. Otherwise a single uniform scale factor
    (never below 1.0) is applied with bicubic interpolation, so the aspect
    ratio is preserved.
    """
    height, width = img.shape[:2]
    if min(height, width) >= min_side:
        return img
    # max(dim, 1) guards the division against a zero-sized dimension.
    scale = max(min_side / max(height, 1), min_side / max(width, 1), 1.0)
    return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _preprocess(
    raw: np.ndarray,
    *,
    method: str,
    thresh: Optional[int],
    contrast: bool,
    upscale: int,
    text_black_target: int,
) -> np.ndarray:
    """Apply watermark removal, optional contrast enhancement, then upscaling.

    Args:
        raw: Source image as loaded by ``cv2.imread``.
        method: Watermark-removal method name passed to the watermark config.
        thresh: Explicit threshold; only forwarded when ``method`` is
            ``"threshold"`` and the value is not ``None``.
        contrast: Whether to run the contrast-enhancement step.
        upscale: Minimum side length handed to ``_upscale``.
        text_black_target: Value stored under ``text_black_target`` in the
            contrast-enhancement config.

    Returns:
        The processed image after upscaling.
    """
    overrides: Dict[str, Any] = {"enabled": True, "method": method}
    if thresh is not None and method == "threshold":
        overrides["threshold"] = thresh

    # NOTE(review): merge_watermark_config presumably layers these overrides on
    # top of the "cell" preset - confirm in ocr_utils.watermark.
    cfg = merge_watermark_config("cell", overrides)
    processed, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)
    if not contrast:
        return _upscale(processed, upscale)

    # Contrast enhancement is applied on a grayscale copy and converted back
    # to three channels afterwards.
    gray = cv2.cvtColor(processed, cv2.COLOR_BGR2GRAY)
    contrast_cfg = dict(cfg.get("contrast_enhancement") or {})
    contrast_cfg.update(enabled=True, text_black_target=text_black_target)
    gray = apply_contrast_enhancement_config(gray, contrast_cfg)
    return _upscale(cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR), upscale)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _parse_rec_pair(rec_part: Any) -> Tuple[str, float]:
    """Extract ``(text, score)`` from one OCR recognition result.

    Handled shapes:
      * list/tuple with two or more items: ``(text, score, ...)``; yields
        ``("", 0.0)`` if the first item is itself a list/tuple/dict, and a
        score of 0.0 when the text is empty or the score is not convertible
        to float.
      * list/tuple with exactly one item: the text, with score 0.0.
      * anything else (including ``None`` and plain strings): ``("", 0.0)``.
    """
    if not isinstance(rec_part, (list, tuple)):
        return "", 0.0

    size = len(rec_part)
    if size == 0:
        return "", 0.0

    head = rec_part[0]
    if size == 1:
        return str(head or "").strip(), 0.0

    # A nested container in the text slot means this is not a flat pair.
    if isinstance(head, (list, tuple, dict)):
        return "", 0.0

    text = str(head or "").strip()
    try:
        score = float(rec_part[1] or 0.0)
    except (TypeError, ValueError):
        score = 0.0
    if not text:
        score = 0.0
    return text, score
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _aggregate_rec_score(boxes: List[Dict[str, Any]]) -> float:
    """Return the recognition score averaged over boxes, weighted by text length.

    Missing or falsy ``text`` counts as length 0 and missing or falsy ``score``
    as 0.0. Returns 0.0 when the boxes contain no characters at all.
    (The original note says this matches the pipeline's aggregate_line_ocr.)
    """
    char_counts = [len(box.get("text") or "") for box in boxes]
    total_chars = sum(char_counts)
    if total_chars <= 0:
        return 0.0

    weighted = 0.0
    for n_chars, box in zip(char_counts, boxes):
        weighted += n_chars * float(box.get("score") or 0.0)
    return weighted / total_chars
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
    """Run ``engine.ocr`` on ``img`` and normalise the result into a flat dict.

    Args:
        engine: Object exposing ``ocr(img, det=..., rec=...)``.
        img: Image to recognise.
        det: Whether text detection is requested; selects how result items
            are interpreted (``(bbox, rec)`` pairs vs. bare rec results).
        rec: Whether recognition is requested (forwarded to the engine).

    Returns:
        Dict with keys ``text`` (concatenated box texts), ``score``
        (length-weighted mean score), ``boxes``, ``det``, ``rec`` and
        ``n_boxes``. Any exception is caught: the dict then holds the empty
        defaults plus an ``error`` key with the exception message.
    """
    try:
        response = engine.ocr(img, det=det, rec=rec)
        lines = response[0] if response and response[0] is not None else []
        parsed: List[Dict[str, Any]] = []

        for line in lines:
            if det:
                # Detection results are (bbox, rec) pairs; skip malformed ones.
                if not line or len(line) < 2:
                    continue
                text, score = _parse_rec_pair(line[1])
                bbox = line[0]
                if hasattr(bbox, "tolist"):
                    bbox = bbox.tolist()
                box: Dict[str, Any] = {"text": text, "score": round(score, 6)}
                if bbox is not None:
                    box["det_bbox"] = bbox
            else:
                text, score = _parse_rec_pair(line)
                # Some results wrap the pair in one more list level.
                if not text and isinstance(line, (list, tuple)) and len(line) >= 1:
                    text, score = _parse_rec_pair(line[0])
                box = {"text": text, "score": round(score, 6)}
            parsed.append(box)

        joined = "".join(box["text"] for box in parsed if box.get("text")).strip()
        return {
            "text": joined,
            "score": round(_aggregate_rec_score(parsed), 6),
            "boxes": parsed,
            "det": det,
            "rec": rec,
            "n_boxes": len(parsed),
        }
    except Exception as exc:
        # Best effort: one failing trial must not abort the whole sweep.
        return {
            "text": "",
            "score": 0.0,
            "boxes": [],
            "det": det,
            "rec": rec,
            "n_boxes": 0,
            "error": str(exc),
        }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _make_engine(det_thresh: float, model_dir: Path) -> Any:
    """Build a ``PytorchPaddleOCR`` engine for the given detection box threshold.

    The ``OCR_DET_MODEL_PATH`` / ``OCR_REC_MODEL_PATH`` environment variables
    take precedence over the default weight file names inside ``model_dir``.
    """
    # Imported lazily so the module can be loaded without the OCR backend.
    from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR

    def _weights(env_var: str, filename: str) -> str:
        return os.environ.get(env_var) or str(model_dir / filename)

    return PytorchPaddleOCR(
        lang="ch",
        det_model_path=_weights("OCR_DET_MODEL_PATH", "ch_PP-OCRv5_det_infer.pth"),
        rec_model_path=_weights(
            "OCR_REC_MODEL_PATH", "ch_PP-OCRv4_rec_server_doc_infer.pth"
        ),
        det_db_box_thresh=det_thresh,
    )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def resolve_input_image(path: Path, *, prefer_raw: bool) -> Path:
    """Return the ``*_raw`` sibling of ``path`` when preferred and present.

    ``path`` is returned as-is when ``prefer_raw`` is false, when its stem
    already ends with ``_raw``, or when no ``<stem>_raw<suffix>`` file exists
    next to it. A switch to the raw sibling is announced on stdout.
    """
    if prefer_raw and not path.stem.endswith("_raw"):
        sibling = path.parent / f"{path.stem}_raw{path.suffix}"
        if sibling.is_file():
            print(f" 使用原图: {sibling.name}(跳过 {path.name})")
            return sibling
    return path
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def collect_inputs(path: Path, *, prefer_raw: bool) -> List[Path]:
    """Expand ``path`` (one image or a directory of images) into files to sweep.

    * File: must have a supported image suffix; its ``*_raw`` sibling is
      substituted via ``resolve_input_image`` when ``prefer_raw`` is set.
    * Directory (non-recursive): with ``prefer_raw`` and at least one
      ``*_raw`` image present, only the raw images are returned. Otherwise
      the non-raw images are returned, falling back to every image found if
      that selection is empty.

    Raises:
        ValueError: If ``path`` is a file with an unsupported suffix.
        FileNotFoundError: If ``path`` is neither a file nor a directory, or
            the directory contains no supported image.
    """
    if path.is_file():
        if path.suffix.lower() not in _IMAGE_SUFFIXES:
            raise ValueError(f"不支持的图像格式: {path}")
        return [resolve_input_image(path, prefer_raw=prefer_raw)]

    if not path.is_dir():
        raise FileNotFoundError(path)

    candidates = [
        entry
        for entry in path.iterdir()
        if entry.is_file() and entry.suffix.lower() in _IMAGE_SUFFIXES
    ]
    candidates.sort()
    if not candidates:
        raise FileNotFoundError(f"目录内无图像: {path}")

    def _is_raw(image: Path) -> bool:
        return image.stem.endswith("_raw")

    if prefer_raw:
        raw_only = [image for image in candidates if _is_raw(image)]
        if raw_only:
            return raw_only

    selected = [
        image
        for image in candidates
        if not _is_raw(image)
        and not (
            prefer_raw
            and (image.parent / f"{image.stem}_raw{image.suffix}").is_file()
        )
    ]
    return selected or candidates
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
    """Classify how an OCR ``text`` relates to the expected ``target``.

    Returns:
        ``None`` for empty text or no match; ``"nonempty"`` when no target was
        given; ``"full"`` when the target is a substring of the text;
        ``"partial"`` when both target and text are all-digit strings of at
        least six characters (no digit overlap is required).
    """
    if not text:
        return None
    if not target:
        return "nonempty"
    if target in text:
        return "full"
    both_long_numeric = all(
        len(value) >= 6 and value.isdigit() for value in (target, text)
    )
    return "partial" if both_long_numeric else None
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def run_sweep(
    input_path: Path,
    out_dir: Path,
    *,
    prefer_raw: bool,
    target: Optional[str],
    model_dir: Path,
    methods: Sequence[str],
    thresholds: Sequence[Optional[int]],
    contrasts: Sequence[bool],
    upscales: Sequence[int],
    det_threshs: Sequence[float],
    text_black_target: int,
    save_images: bool,
    run_baseline: bool,
    baseline_upscale: int,
) -> Dict[str, Any]:
    """Sweep preprocessing and OCR parameters over one cell image.

    Every combination of method, threshold, contrast, upscale and detection
    threshold is preprocessed and recognised in two OCR modes (detection +
    recognition, and whole-image recognition). Results are written to
    ``<out_dir>/<stem>/sweep_report.json``, where ``<stem>`` is the input
    stem without a trailing ``_raw``.

    Args:
        input_path: Image to sweep; its ``*_raw`` sibling may be substituted.
        out_dir: Root output directory; a per-image sub-directory is created.
        prefer_raw: Forwarded to ``resolve_input_image``.
        target: Expected text used to flag hits (see ``_match_hit``).
        model_dir: Directory holding the OCR model weights.
        methods: Watermark-removal method names.
        thresholds: Threshold values for the ``"threshold"`` method
            (``None`` = preset default). Other methods ignore this list and
            run once per remaining combination with no threshold.
        contrasts: Contrast-enhancement on/off values to try.
        upscales: Minimum side lengths to try.
        det_threshs: ``det_db_box_thresh`` values; one engine is built per value.
        text_black_target: Forwarded to ``_preprocess``.
        save_images: Whether preprocessed images are written next to the report.
        run_baseline: Whether to also OCR the upscaled, otherwise untouched image.
        baseline_upscale: Minimum side length for the baseline image.

    Returns:
        The report dict that was also written to disk. ``total_trials`` counts
        grid trials only; baseline rows appear in ``all_results``/``hits``.

    Raises:
        RuntimeError: If the image cannot be read.
    """
    resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
    raw = cv2.imread(str(resolved))
    if raw is None:
        raise RuntimeError(f"无法读取图像: {resolved}")

    stem = resolved.stem.removesuffix("_raw")
    cell_out = out_dir / stem
    cell_out.mkdir(parents=True, exist_ok=True)

    ocr_modes: List[Tuple[str, bool, bool]] = [
        ("det_rec", True, True),
        ("whole_rec", False, True),
    ]

    results: List[Dict[str, Any]] = []
    hits: List[Dict[str, Any]] = []
    engines: Dict[float, Any] = {}
    total = 0

    def _engine_for(det_th: float) -> Any:
        """Return the cached engine for ``det_th``, building it on first use."""
        if det_th not in engines:
            print(f" [{stem}] 加载 OCR det_db_box_thresh={det_th} ...")
            engines[det_th] = _make_engine(det_th, model_dir)
        return engines[det_th]

    def _record(row: Dict[str, Any]) -> None:
        """Store one result row and flag/announce it when it matches the target."""
        results.append(row)
        match = _match_hit(row.get("text", ""), target)
        if not match:
            return
        row["match"] = match
        hits.append(row)
        print(
            f" HIT [{match}] {row['ocr_mode']} {row['tag']} "
            f"score={row.get('score')} -> {row.get('text')!r}"
        )

    for method in methods:
        # The threshold value only applies to the "threshold" method. Other
        # methods run exactly once per remaining combination, even when the
        # threshold list holds no None entry (e.g. the --quick grid), instead
        # of being skipped altogether.
        method_thresholds: Sequence[Optional[int]] = (
            thresholds if method == "threshold" else (None,)
        )
        for thresh, contrast, upscale, det_th in product(
            method_thresholds, contrasts, upscales, det_threshs
        ):
            engine = _engine_for(det_th)
            img = _preprocess(
                raw,
                method=method,
                thresh=thresh,
                contrast=contrast,
                upscale=upscale,
                text_black_target=text_black_target,
            )
            # Compare against None so an explicit threshold of 0 is not
            # mislabelled as the preset default.
            thresh_tag = "d" if thresh is None else thresh
            tag = f"{method}_t{thresh_tag}_c{int(contrast)}_u{upscale}_det{det_th}"
            if save_images:
                cv2.imwrite(str(cell_out / f"{tag}.png"), img)

            for mode_name, det, rec in ocr_modes:
                total += 1
                _record(
                    {
                        "tag": tag,
                        "method": method,
                        "threshold": thresh,
                        "contrast": contrast,
                        "upscale": upscale,
                        "det_db_box_thresh": det_th,
                        "ocr_mode": mode_name,
                        **_ocr(engine, img, det=det, rec=rec),
                    }
                )

    if run_baseline and len(det_threshs) > 0:
        # The baseline image does not depend on the detection threshold, so it
        # is built and saved once rather than once per threshold.
        baseline_tag = f"baseline_upscale{baseline_upscale}"
        base_img = _upscale(raw, baseline_upscale)
        if save_images:
            cv2.imwrite(str(cell_out / f"{baseline_tag}.png"), base_img)
        for det_th in det_threshs:
            engine = _engine_for(det_th)
            for mode_name, det, rec in ocr_modes:
                _record(
                    {
                        "tag": baseline_tag,
                        "det_db_box_thresh": det_th,
                        "ocr_mode": mode_name,
                        **_ocr(engine, base_img, det=det, rec=rec),
                    }
                )

    report = {
        "input": str(resolved),
        "input_requested": str(input_path),
        "output_dir": str(cell_out),
        "target": target,
        "total_trials": total,
        "hits": hits,
        "all_results": results,
    }
    report_path = cell_out / "sweep_report.json"
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    return report
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the cell sweep tool.

    Grid options (``--methods``, ``--thresholds``, ``--contrasts``,
    ``--upscales``, ``--det-threshs``) are kept as raw comma-separated strings
    and parsed later by the caller.
    """
    parser = argparse.ArgumentParser(
        description="单元格图预处理 + OCR 参数网格扫描(对齐 pipeline 格级二次 OCR)",
    )
    add = parser.add_argument

    # Input / output locations.
    add("input", type=Path, help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)")
    add(
        "-o",
        "--output",
        type=Path,
        default=None,
        help="输出目录,默认 <input_dir|input_parent>/sweep_out/<stem>",
    )
    add(
        "-t",
        "--target",
        default=None,
        help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
    )
    add(
        "--model-dir",
        type=Path,
        default=None,
        help="PaddleOCR torch 模型目录(含 det/rec .pth),也可用 OCR_*_MODEL_PATH",
    )

    # Mode switches.
    add("--no-prefer-raw", action="store_true", help="不自动选用同名的 *_raw.png")
    add(
        "--quick",
        action="store_true",
        help="缩小网格(threshold 170,175 × upscale 128,192 × det 0.3,0.5)",
    )

    # Sweep grid, as comma-separated strings.
    add("--methods", default="threshold,masked_adaptive", help="去水印方式,逗号分隔")
    add(
        "--thresholds",
        default="155,165,170,175,180,none",
        help="threshold 法的阈值;none=预设默认",
    )
    add("--contrasts", default="false,true", help="是否 contrast,逗号分隔 false,true")
    add("--upscales", default="64,96,128,192", help="最短边放大目标,逗号分隔整数")
    add("--det-threshs", default="0.2,0.3,0.4,0.5", help="det_db_box_thresh,逗号分隔")
    add(
        "--text-black-target",
        type=int,
        default=88,
        help="contrast text_restore 目标黑度",
    )

    # Output and baseline control.
    add("--no-save-images", action="store_true", help="不写出中间预处理 png(仅报告)")
    add("--no-baseline", action="store_true", help="跳过「仅放大、不去水印」对照组")
    add(
        "--baseline-upscale",
        type=int,
        default=128,
        help="baseline 对照组的最短边放大",
    )
    return parser
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point: sweep every input image and write an index.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    The output root is ``--output`` when given, otherwise a ``sweep_out``
    directory next to the input file (or inside the input directory).
    ``--quick`` replaces the threshold/upscale/det-threshold/contrast grids
    with a fixed reduced set. A ``sweep_index.json`` summarising each
    per-image report is written into the output root.
    """
    args = _build_arg_parser().parse_args(argv)
    prefer_raw = not args.no_prefer_raw
    inputs = collect_inputs(args.input, prefer_raw=prefer_raw)
    if not inputs:
        raise SystemExit("未找到可扫描的图像")

    out_root = args.output
    if out_root is None:
        anchor = args.input.parent if args.input.is_file() else args.input
        out_root = anchor / "sweep_out"
    out_root.mkdir(parents=True, exist_ok=True)

    model_dir = args.model_dir or _default_model_dir()
    methods = [name.strip() for name in args.methods.split(",") if name.strip()]

    if args.quick:
        # Fixed reduced grid; the corresponding CLI grid options are ignored.
        thresholds, upscales = [170, 175], [128, 192]
        det_threshs, contrasts = [0.3, 0.5], [False, True]
    else:
        thresholds = _parse_csv_ints(args.thresholds)
        upscales = [int(tok) for tok in args.upscales.split(",") if tok.strip()]
        det_threshs = _parse_csv_floats(args.det_threshs)
        contrasts = _parse_csv_bools(args.contrasts)

    print(f"扫描 {len(inputs)} 张图 -> {out_root}")
    print(f" methods={methods} thresholds={thresholds} upscales={upscales}")
    if args.target:
        print(f" target={args.target!r}")

    # Options shared by every per-image sweep.
    sweep_options: Dict[str, Any] = dict(
        prefer_raw=prefer_raw,
        target=args.target,
        model_dir=model_dir,
        methods=methods,
        thresholds=thresholds,
        contrasts=contrasts,
        upscales=upscales,
        det_threshs=det_threshs,
        text_black_target=args.text_black_target,
        save_images=not args.no_save_images,
        run_baseline=not args.no_baseline,
        baseline_upscale=args.baseline_upscale,
    )

    summary: List[Dict[str, Any]] = []
    for img_path in inputs:
        print(f"\n=== {img_path.name} ===")
        report = run_sweep(img_path, out_root, **sweep_options)
        report_file = Path(report["output_dir"]) / "sweep_report.json"
        summary.append(
            {
                "input": report["input"],
                "hits": len(report["hits"]),
                "report": str(report_file),
            }
        )

    index_path = out_root / "sweep_index.json"
    index_path.write_text(
        json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(f"\n全部完成,索引: {index_path}")
    for entry in summary:
        print(f" {entry['input']}: {entry['hits']} hits -> {entry['report']}")
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
if __name__ == "__main__":
    if len(sys.argv) == 1:
        # Developer convenience: with no CLI arguments, fall back to a built-in
        # sample configuration by synthesising the equivalent argument list.
        # NOTE(review): the sample paths are machine-specific.
        print("ℹ️ 未提供命令行参数,使用默认配置运行...")
        default_config = {
            "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty_raw.png",
            "output": "./output/彭_广东兴宁农村商业银行/cell219_sweep",
            "target": "ATM存折取款",
        }
        fallback_argv = [sys.argv[0], default_config["input"]]
        for key, value in default_config.items():
            if key == "input":
                continue
            flag = "--" + key.replace("_", "-")
            if isinstance(value, bool):
                # Boolean options are bare flags, emitted only when true.
                if value:
                    fallback_argv.append(flag)
            else:
                fallback_argv.extend([flag, str(value)])
        sys.argv = fallback_argv

    sys.exit(main())
|