| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334 |
- #!/usr/bin/env python3
- """
- 单元格裁剪图预处理参数扫描:去水印 / contrast(clahe/gamma/linear/text_restore)/ upscale / det 阈值 / OCR 模式。
- 支持 contrast 在放大前/后执行两种顺序对比。
- 默认从 **原图**(`*_raw.png`)出发,与 pipeline 二次 OCR 一致,避免对已预处理 debug 图二次去水印。
- 用法:
- python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out
- python cell_sweep.py cell.png --quick --no-save-images
- python cell_sweep.py cell.png --contrast-orders before_upscale,after_upscale
- OCR_DET_MODEL_PATH=... OCR_REC_MODEL_PATH=... python cell_sweep.py cell.png
- # 统计出的最优参数 tag: threshold_t150_cl_1.0_8_ob_u128_det0.5
- # 对目录下所有 *_raw.png 验证适配性
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only
- # 自定义最优参数
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only \
- --best-config threshold_t150_cl_1.0_8_ob_u128_det0.5
- # 指定目标文字,自动统计 HIT 命中率
- python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only -t "交易类型"
- # 多 case 批量扫描 + 汇总 Pass1/Pass2 建议
- python cell_sweep.py --cases sweep_cases.json
- python cell_sweep.py --aggregate-only --cases sweep_cases.json
- """
- from __future__ import annotations
- import argparse
- import json
- import os
- import sys
- from itertools import product
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Sequence, Tuple
- import cv2
- import numpy as np
# Put the ancestor directory `parents[3]` (named _repo_root, presumably the repository
# root) at the front of sys.path so the local `ocr_utils` imports below resolve when the
# script is executed directly. Nothing is added if it is already on the path.
_repo_root = Path(__file__).resolve().parents[3]
if str(_repo_root) not in sys.path:
    sys.path[0:0] = [str(_repo_root)]
- from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
- from ocr_utils.watermark.contrast import enhance_document_contrast
# Lower-case file extensions (with leading dot) accepted as cell images.
_IMAGE_SUFFIXES = {
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".tif",
    ".tiff",
    ".webp",
}
# Fallback model directory returned by `_default_model_dir()` when OCR_DET_MODEL_PATH is
# not set. NOTE(review): this is a machine-specific absolute path.
_DEFAULT_MODEL_DIR = Path(
    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
)
def _parse_csv_ints(s: str) -> List[Optional[int]]:
    """Parse a comma-separated list of ints.

    Empty items and the placeholders "none" / "d" / "default" (case-insensitive) become
    ``None``, meaning "use the default value". Other items go through ``int()``.
    """
    placeholders = ("none", "d", "default")

    def convert(token: str) -> Optional[int]:
        token = token.strip()
        if not token or token.lower() in placeholders:
            return None
        return int(token)

    return [convert(token) for token in s.split(",")]
def _parse_csv_floats(s: str) -> List[float]:
    """Parse a comma-separated list of floats, skipping empty items."""
    values: List[float] = []
    for chunk in s.split(","):
        chunk = chunk.strip()
        if chunk:
            values.append(float(chunk))
    return values
def _parse_csv_bools(s: str) -> List[bool]:
    """Parse a comma-separated list of booleans.

    Accepts 1/true/yes/on and 0/false/no/off (case-insensitive, surrounding spaces
    ignored).

    Raises:
        ValueError: for any other item, including an empty one.
    """
    truthy = ("1", "true", "yes", "on")
    falsy = ("0", "false", "no", "off")
    flags: List[bool] = []
    for raw_item in s.split(","):
        key = raw_item.strip().lower()
        if key in truthy:
            flags.append(True)
            continue
        if key in falsy:
            flags.append(False)
            continue
        raise ValueError(f"无效的 bool 值: {raw_item!r}")
    return flags
def _default_model_dir() -> Path:
    """Return the directory that holds the OCR model files.

    If OCR_DET_MODEL_PATH is set and non-empty, its parent directory is used; otherwise
    the module-level fallback ``_DEFAULT_MODEL_DIR``.
    """
    det_model = os.environ.get("OCR_DET_MODEL_PATH")
    return Path(det_model).parent if det_model else _DEFAULT_MODEL_DIR
def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
    """Upscale ``img`` uniformly so that both sides are at least ``min_side`` pixels.

    An image whose smaller side already reaches ``min_side`` is returned as-is (same
    object). Otherwise it is resized with bicubic interpolation. The scale factor is
    never below 1, so images are never shrunk.
    """
    height, width = img.shape[:2]
    if min(height, width) >= min_side:
        return img
    # max(..., 1) guards against division by zero for degenerate shapes.
    factor = max(1.0, min_side / max(height, 1), min_side / max(width, 1))
    return cv2.resize(img, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC)
# -- Contrast enhancement back-ends (clahe / gamma / linear / text_restore / none) --
def _apply_contrast(
    gray: np.ndarray,
    *,
    method: str,
    clip_limit: float = 1.0,
    tile_grid_size: int = 8,
    gamma: float = 0.85,
    black_percentile: float = 2.0,
    white_percentile: float = 98.0,
    text_black_target: int = 85,
    background_threshold: int = 248,
) -> np.ndarray:
    """Enhance the contrast of a grayscale image via ``enhance_document_contrast``.

    Only the keyword arguments that belong to ``method`` are forwarded. For
    ``method="none"`` or an unrecognised method, ``gray`` is returned unchanged (the same
    object).
    """
    if method == "none":
        return gray
    params_by_method = {
        "text_restore": {
            "text_black_target": text_black_target,
            "background_threshold": background_threshold,
        },
        "clahe": {"clip_limit": clip_limit, "tile_grid_size": tile_grid_size},
        "gamma": {"gamma": gamma},
        "linear": {
            "black_percentile": black_percentile,
            "white_percentile": white_percentile,
        },
    }
    params = params_by_method.get(method)
    if params is None:
        return gray
    return enhance_document_contrast(gray, method=method, **params)
def _contrast_tag(cfg: Dict[str, Any]) -> str:
    """Return the short label used for a contrast config in sweep tags.

    Formats:
        ``c0`` (method "none" or missing),
        ``tr_{text_black_target}_{background_threshold}``,
        ``cl_{clip_limit}_{tile_grid_size}``, ``gm_{gamma}``,
        ``ln_{black_percentile}_{white_percentile}``;
    unknown methods are returned verbatim. Missing parameters fall back to the defaults
    of `_apply_contrast`.

    The text_restore label includes background_threshold because the sweep grid varies
    it. Without it, configs differing only in that value would share one tag. They would
    then overwrite each other's saved images and be merged by per-tag de-duplication.
    """
    m = cfg.get("method", "none")
    if m == "none":
        return "c0"
    if m == "text_restore":
        return (
            f"tr_{cfg.get('text_black_target', 85)}"
            f"_{cfg.get('background_threshold', 248)}"
        )
    if m == "clahe":
        return f"cl_{cfg.get('clip_limit', 1.0)}_{cfg.get('tile_grid_size', 8)}"
    if m == "gamma":
        return f"gm_{cfg.get('gamma', 0.85)}"
    if m == "linear":
        return f"ln_{cfg.get('black_percentile', 2.0)}_{cfg.get('white_percentile', 98.0)}"
    return m
def _build_contrast_grid(
    quick: bool = False, *, include_gamma_linear: bool = False
) -> List[Dict[str, Any]]:
    """Build the contrast parameter grid (design aligned with contrast_sweep.py).

    Each element is a dict with at least a "method" key. The first entry is always the
    ``{"method": "none"}`` control group, followed by text_restore and clahe candidates.

    Args:
        quick: use a reduced grid for fast runs.
        include_gamma_linear: also append gamma and linear candidates. They are off by
            default, which reproduces the grid used when those sections were disabled.
    """
    grid: List[Dict[str, Any]] = [{"method": "none"}]  # control group: no enhancement

    # text_restore
    if quick:
        tbt, bts = [60, 85], [240, 248]
    else:
        tbt, bts = [60, 85, 100, 120], [240, 248, 252]
    grid.extend(
        {"method": "text_restore", "text_black_target": target, "background_threshold": bg_th}
        for target, bg_th in product(tbt, bts)
    )

    # clahe
    cl = [1.0, 2.0] if quick else [0.5, 1.0, 2.0, 3.0, 5.0]
    ts = [4, 8]
    grid.extend(
        {"method": "clahe", "clip_limit": clip, "tile_grid_size": tile}
        for clip, tile in product(cl, ts)
    )

    if include_gamma_linear:
        # gamma
        gvs = [0.5, 0.85] if quick else [0.4, 0.55, 0.7, 0.85]
        grid.extend({"method": "gamma", "gamma": g} for g in gvs)
        # linear
        bps = [2.0, 5.0] if quick else [2.0, 5.0, 8.0]
        wps = [95.0, 98.0]
        grid.extend(
            {"method": "linear", "black_percentile": bp, "white_percentile": wp}
            for bp, wp in product(bps, wps)
        )
    return grid
def _contrast_bgr(img: np.ndarray, contrast_cfg: Dict[str, Any]) -> np.ndarray:
    """Run `_apply_contrast` on the grayscale version of a BGR image and return BGR."""
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = _apply_contrast(gray, **contrast_cfg)
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)


def _preprocess(
    raw: np.ndarray,
    *,
    method: str,
    thresh: Optional[int],
    contrast_cfg: Dict[str, Any],
    upscale: int,
    contrast_order: str = "before_upscale",
) -> np.ndarray:
    """Cell preprocessing pipeline.

    Steps: watermark removal → [contrast] → upscale, or watermark removal → upscale →
    contrast, depending on ``contrast_order``.

    Args:
        raw: input image, converted with COLOR_BGR2GRAY when contrast is applied.
        method: watermark method; "none" skips watermark removal and starts from a copy
            of ``raw``.
        thresh: explicit threshold, used only when ``method == "threshold"``.
        contrast_cfg: kwargs for `_apply_contrast`; ``{"method": "none"}`` disables it.
        upscale: minimum side length passed to `_upscale`.
        contrast_order: "before_upscale" or "after_upscale".

    Raises:
        ValueError: if ``contrast_order`` is not one of the two supported values. Such a
            value would otherwise skip contrast enhancement without any warning.
    """
    if contrast_order not in ("before_upscale", "after_upscale"):
        raise ValueError(f"无效的 contrast_order: {contrast_order!r}")

    if method == "none":
        img = raw.copy()  # no watermark removal: work on the original pixels
    else:
        user: Dict[str, Any] = {"enabled": True, "method": method}
        if method == "threshold" and thresh is not None:
            user["threshold"] = thresh
        cfg = merge_watermark_config("cell", user)
        img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)

    use_contrast = contrast_cfg.get("method", "none") != "none"
    if use_contrast and contrast_order == "before_upscale":
        img = _contrast_bgr(img, contrast_cfg)
    img = _upscale(img, upscale)
    if use_contrast and contrast_order == "after_upscale":
        img = _contrast_bgr(img, contrast_cfg)
    return img
def _parse_rec_pair(rec_part: Any) -> Tuple[str, float]:
    """Extract ``(text, score)`` from one OCR recognition result.

    Handles:
      * ``(text, score)`` sequences: text is stripped; the score is 0.0 when the text is
        empty or the score is not convertible to float;
      * one-element sequences ``(text,)``: score is 0.0;
      * nested structures (first element is a list/tuple/dict), ``None``, non-sequences
        and empty sequences: ``("", 0.0)``.
    """
    if not isinstance(rec_part, (list, tuple)):
        return "", 0.0
    size = len(rec_part)
    if size == 0:
        return "", 0.0
    head = rec_part[0]
    if size == 1:
        return str(head or "").strip(), 0.0
    if isinstance(head, (list, tuple, dict)):
        return "", 0.0
    text = str(head or "").strip()
    try:
        score = float(rec_part[1] or 0.0)
    except (TypeError, ValueError):
        score = 0.0
    if not text:
        return text, 0.0
    return text, score
def _aggregate_rec_score(boxes: List[Dict[str, Any]]) -> float:
    """Character-count-weighted mean recognition score (same rule as pipeline aggregate_line_ocr).

    Boxes with empty/missing text contribute no weight. Returns 0.0 when there is no text.
    """
    lengths = [len(box.get("text") or "") for box in boxes]
    total_chars = sum(lengths)
    if total_chars <= 0:
        return 0.0
    weighted = sum(
        n_chars * float(box.get("score") or 0.0) for n_chars, box in zip(lengths, boxes)
    )
    return weighted / total_chars
def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
    """Run ``engine.ocr`` once and normalise its output.

    Returns a dict with keys text (concatenation of box texts), score (char-weighted mean
    via `_aggregate_rec_score`, rounded to 6 decimals), boxes, det, rec and n_boxes. With
    ``det=True`` each box may also carry "det_bbox". Any exception raised while running or
    parsing is swallowed. An empty result with an extra "error" message is returned
    instead, so a single failure does not abort a sweep.
    """

    def from_det_item(item: Any) -> Optional[Dict[str, Any]]:
        # Expected shape: (bbox, rec_part); malformed items are skipped.
        if not item or len(item) < 2:
            return None
        text, score = _parse_rec_pair(item[1])
        bbox = item[0]
        if hasattr(bbox, "tolist"):
            bbox = bbox.tolist()
        entry: Dict[str, Any] = {"text": text, "score": round(score, 6)}
        if bbox is not None:
            entry["det_bbox"] = bbox
        return entry

    def from_rec_item(item: Any) -> Dict[str, Any]:
        # Recognition-only output may be (text, score) or wrapped one level deeper.
        text, score = _parse_rec_pair(item)
        if not text and isinstance(item, (list, tuple)) and len(item) >= 1:
            text, score = _parse_rec_pair(item[0])
        return {"text": text, "score": round(score, 6)}

    try:
        raw_result = engine.ocr(img, det=det, rec=rec)
        items = raw_result[0] if raw_result and raw_result[0] is not None else []
        if det:
            boxes = [entry for entry in map(from_det_item, items) if entry is not None]
        else:
            boxes = [from_rec_item(item) for item in items]
        joined = "".join(box["text"] for box in boxes if box.get("text")).strip()
        return {
            "text": joined,
            "score": round(_aggregate_rec_score(boxes), 6),
            "boxes": boxes,
            "det": det,
            "rec": rec,
            "n_boxes": len(boxes),
        }
    except Exception as e:
        return {
            "text": "",
            "score": 0.0,
            "boxes": [],
            "det": det,
            "rec": rec,
            "n_boxes": 0,
            "error": str(e),
        }
def _make_engine(det_thresh: float, model_dir: Path) -> Any:
    """Build a Chinese ``PytorchPaddleOCR`` engine with the given det box threshold.

    Model paths come from OCR_DET_MODEL_PATH / OCR_REC_MODEL_PATH when those are set and
    non-empty. Otherwise the fixed file names inside ``model_dir`` are used.
    """
    # Imported here rather than at module top (presumably to defer loading the OCR stack).
    from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR

    env = os.environ
    det_model = env.get("OCR_DET_MODEL_PATH") or str(model_dir / "ch_PP-OCRv5_det_infer.pth")
    rec_model = env.get("OCR_REC_MODEL_PATH") or str(
        model_dir / "ch_PP-OCRv4_rec_server_doc_infer.pth"
    )
    engine_kwargs = {
        "lang": "ch",
        "det_model_path": det_model,
        "rec_model_path": rec_model,
        "det_db_box_thresh": det_thresh,
    }
    return PytorchPaddleOCR(**engine_kwargs)
def resolve_input_image(path: Path, *, prefer_raw: bool) -> Path:
    """Return the ``<stem>_raw<suffix>`` sibling of ``path`` when preferred and present.

    ``path`` is returned unchanged when ``prefer_raw`` is False, when it already is a
    ``*_raw`` file, or when no such sibling file exists.
    """
    if prefer_raw and not path.stem.endswith("_raw"):
        raw_path = path.parent / f"{path.stem}_raw{path.suffix}"
        if raw_path.is_file():
            print(f" 使用原图: {raw_path.name}(跳过 {path.name})")
            return raw_path
    return path
def collect_inputs(path: Path, *, prefer_raw: bool) -> List[Path]:
    """List the images to process for a file or directory argument.

    * File: must have an image suffix. It is resolved through `resolve_input_image`.
    * Directory, ``prefer_raw`` and some ``*_raw`` images exist: only those raw images.
    * Otherwise: non-raw images (when ``prefer_raw``, those with a raw sibling are
      dropped). If that leaves nothing, every image in the directory is returned.

    Raises:
        ValueError: for a file with an unsupported suffix.
        FileNotFoundError: if ``path`` does not exist or the directory has no images.
    """
    if path.is_file():
        if path.suffix.lower() not in _IMAGE_SUFFIXES:
            raise ValueError(f"不支持的图像格式: {path}")
        return [resolve_input_image(path, prefer_raw=prefer_raw)]
    if not path.is_dir():
        raise FileNotFoundError(path)

    images = sorted(
        entry
        for entry in path.iterdir()
        if entry.is_file() and entry.suffix.lower() in _IMAGE_SUFFIXES
    )
    if not images:
        raise FileNotFoundError(f"目录内无图像: {path}")

    def is_raw(p: Path) -> bool:
        return p.stem.endswith("_raw")

    if prefer_raw:
        raw_images = [p for p in images if is_raw(p)]
        if raw_images:
            return raw_images

    def has_raw_sibling(p: Path) -> bool:
        return (p.parent / f"{p.stem}_raw{p.suffix}").is_file()

    selected = [
        p for p in images if not is_raw(p) and not (prefer_raw and has_raw_sibling(p))
    ]
    return selected if selected else images
def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
    """Classify how an OCR ``text`` matches ``target``.

    Returns "nonempty" when there is text but no target, "full" when ``target`` is a
    substring of ``text``, "partial" when both are all-digit strings of length >= 6,
    and ``None`` otherwise (including empty text).
    """
    if not text:
        return None
    if not target:
        return "nonempty"
    if target in text:
        return "full"
    both_long_numbers = all(len(s) >= 6 and s.isdigit() for s in (target, text))
    return "partial" if both_long_numbers else None
def _collect_qualified_hits(
    results: List[Dict[str, Any]],
    target: Optional[str],
    *,
    min_score: float = 0.9,
    ocr_mode: str = "det_rec",
) -> List[Dict[str, Any]]:
    """Return rows of ``ocr_mode`` that match ``target`` with score strictly above ``min_score``.

    Rows are de-duplicated by tag, keeping the highest-scoring row per tag, and returned
    best-first. Without a target nothing qualifies.
    """
    if not target:
        return []

    def score_of(row: Dict[str, Any]) -> float:
        return float(row.get("score") or 0)

    best_per_tag: Dict[str, Dict[str, Any]] = {}
    for row in results:
        if row.get("ocr_mode") != ocr_mode:
            continue
        if not _match_hit(row.get("text", "") or "", target):
            continue
        row_score = score_of(row)
        if row_score <= min_score:
            continue
        key = str(row.get("tag") or "")
        incumbent = best_per_tag.get(key)
        if incumbent is None or row_score > score_of(incumbent):
            best_per_tag[key] = row
    return sorted(best_per_tag.values(), key=lambda row: -score_of(row))
def _sweep_tag(
    method: str,
    thresh: Optional[int],
    c_tag: str,
    c_order: str,
    upscale: int,
    det_th: float,
) -> str:
    """Build the tag ``{method}_t{thresh|d}_{c_tag}_o{b|a}_u{upscale}_det{det_th}``.

    ``thresh=None`` is rendered as ``d`` (watermark default). An explicit ``0`` is kept as
    ``t0``, so it is not confused with the default when the tag is parsed back.
    """
    t_part = "d" if thresh is None else str(thresh)
    o_part = "b" if c_order == "before_upscale" else "a"
    return f"{method}_t{t_part}_{c_tag}_o{o_part}_u{upscale}_det{det_th}"


def run_sweep(
    input_path: Path,
    out_dir: Path,
    *,
    prefer_raw: bool,
    target: Optional[str],
    model_dir: Path,
    methods: Sequence[str],
    thresholds: Sequence[Optional[int]],
    contrast_grid: List[Dict[str, Any]],
    contrast_orders: Sequence[str],
    upscales: Sequence[int],
    det_threshs: Sequence[float],
    save_images: bool,
    run_baseline: bool,
    baseline_upscale: int,
    min_hit_score: float = 0.9,
) -> Dict[str, Any]:
    """Sweep preprocessing × det-threshold × OCR-mode combinations for one cell image.

    Every combination is OCR'd in both "det_rec" and "whole_rec" mode. Explicit thresholds
    are only combined with the "threshold" watermark method. The preprocessed image does
    not depend on the det threshold, so it is built once per preprocessing combination
    and reused for every det threshold. An optional baseline (plain upscale of the raw
    image) is OCR'd as well. Baseline trials are recorded but not counted in
    ``total_trials``.

    The report is written to ``<out_dir>/<stem>/sweep_report.json`` and returned, and a
    conclusion summary is printed.

    Raises:
        RuntimeError: if the resolved input image cannot be read.
    """
    resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
    raw = cv2.imread(str(resolved))
    if raw is None:
        raise RuntimeError(f"无法读取图像: {resolved}")
    stem = resolved.stem.removesuffix("_raw")
    cell_out = out_dir / stem
    cell_out.mkdir(parents=True, exist_ok=True)

    ocr_modes: List[Tuple[str, bool, bool]] = [
        ("det_rec", True, True),
        ("whole_rec", False, True),
    ]
    results: List[Dict[str, Any]] = []
    hits: List[Dict[str, Any]] = []
    engines: Dict[float, Any] = {}
    total = 0

    def engine_for(det_th: float, *, announce: bool) -> Any:
        # One engine per det threshold, shared by the sweep and the baseline.
        if det_th not in engines:
            if announce:
                print(f" [{stem}] 加载 OCR det_db_box_thresh={det_th} ...")
            engines[det_th] = _make_engine(det_th, model_dir)
        return engines[det_th]

    def record(row: Dict[str, Any]) -> Optional[str]:
        # Store the row; mark and collect it as a hit when it matches the target.
        results.append(row)
        m = _match_hit(row.get("text", ""), target)
        if m:
            row["match"] = m
            hits.append(row)
        return m

    for method, thresh, contrast_cfg, c_order, upscale in product(
        methods, thresholds, contrast_grid, contrast_orders, upscales
    ):
        # Only the "threshold" watermark method consumes an explicit threshold.
        if method != "threshold" and thresh is not None:
            continue
        c_tag = _contrast_tag(contrast_cfg)
        img = None
        for det_th in det_threshs:
            engine = engine_for(det_th, announce=True)
            if img is None:
                img = _preprocess(
                    raw,
                    method=method,
                    thresh=thresh,
                    contrast_cfg=contrast_cfg,
                    upscale=upscale,
                    contrast_order=c_order,
                )
            tag = _sweep_tag(method, thresh, c_tag, c_order, upscale, det_th)
            if save_images:
                cv2.imwrite(str(cell_out / f"{tag}.png"), img)
            for mode_name, det, rec in ocr_modes:
                total += 1
                ocr = _ocr(engine, img, det=det, rec=rec)
                row: Dict[str, Any] = {
                    "tag": tag,
                    "method": method,
                    "threshold": thresh,
                    "contrast_method": contrast_cfg.get("method", "none"),
                    "contrast_order": c_order,
                    "contrast_cfg": contrast_cfg,
                    "upscale": upscale,
                    "det_db_box_thresh": det_th,
                    "ocr_mode": mode_name,
                    **ocr,
                }
                m = record(row)
                if m:
                    print(
                        f" HIT [{m}] {mode_name} {tag} "
                        f"score={row.get('score')} -> {row.get('text')!r}"
                    )

    if run_baseline:
        base_tag = f"baseline_upscale{baseline_upscale}"
        base_img = None
        for det_th in det_threshs:
            engine = engine_for(det_th, announce=False)
            if base_img is None:
                base_img = _upscale(raw, baseline_upscale)
                if save_images:
                    cv2.imwrite(str(cell_out / f"{base_tag}.png"), base_img)
            for mode_name, det, rec in ocr_modes:
                ocr = _ocr(engine, base_img, det=det, rec=rec)
                record(
                    {
                        "tag": base_tag,
                        "det_db_box_thresh": det_th,
                        "ocr_mode": mode_name,
                        **ocr,
                    }
                )

    qualified_hits = _collect_qualified_hits(
        results, target, min_score=min_hit_score, ocr_mode="det_rec"
    )
    summary_keys = (
        "tag",
        "score",
        "text",
        "match",
        "method",
        "threshold",
        "contrast_method",
        "contrast_order",
        "contrast_cfg",
        "upscale",
        "det_db_box_thresh",
        "ocr_mode",
    )
    report = {
        "input": str(resolved),
        "input_requested": str(input_path),
        "output_dir": str(cell_out),
        "target": target,
        "min_hit_score": min_hit_score,
        "total_trials": total,
        "hits": hits,
        "hits_target_score_above": [
            {key: r.get(key) for key in summary_keys} for r in qualified_hits
        ],
        "all_results": results,
    }
    report_path = cell_out / "sweep_report.json"
    report_path.write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    _print_conclusions(stem, results, target, min_hit_score=min_hit_score)
    return report
def _sorted_summary_labels(values: set) -> List[str]:
    """Turn a set that may contain None into sorted display labels.

    ``None`` becomes a leading "baseline" label (baseline rows leave most fields unset).
    The other values are stringified. Labels that parse as numbers are sorted
    numerically, so "64" comes before "128". Non-numeric labels follow in text order.
    """

    def sort_key(label: str) -> Tuple[int, float, str]:
        try:
            return (0, float(label), label)
        except ValueError:
            return (1, 0.0, label)

    labels = ["baseline"] if None in values else []
    labels.extend(sorted((str(v) for v in values if v is not None), key=sort_key))
    return labels
def _format_param_row(r: Dict[str, Any]) -> str:
    """Describe a result row's parameter combination on one line.

    A row without a threshold shows ``thresh=d``, matching the ``_td_`` tag segment.
    A row without a contrast order (e.g. baseline rows) shows ``order=-`` instead of
    claiming "after upscale".
    """
    cfg = r.get("contrast_cfg") or {}
    cm = r.get("contrast_method", "none")
    if cm == "clahe":
        extra = f" cl={cfg.get('clip_limit')} tile={cfg.get('tile_grid_size')}"
    elif cm == "text_restore":
        extra = f" t={cfg.get('text_black_target')} bg={cfg.get('background_threshold')}"
    elif cm == "gamma":
        extra = f" g={cfg.get('gamma')}"
    elif cm == "linear":
        extra = f" b={cfg.get('black_percentile')} w={cfg.get('white_percentile')}"
    else:
        extra = ""
    order = {"before_upscale": "放大前", "after_upscale": "放大后"}.get(
        r.get("contrast_order"), "-"
    )
    threshold = r.get("threshold")
    thresh_label = "d" if threshold is None else threshold
    return (
        f"wm={r.get('method')} thresh={thresh_label} contrast={cm}{extra} "
        f"order={order} upscale={r.get('upscale')} det={r.get('det_db_box_thresh')}"
    )
def _conclusion_score(r: Dict[str, Any]) -> float:
    """Score of a result row as float; missing or None scores count as 0."""
    return float(r.get("score") or 0)


def _print_qualified_section(
    qualified: List[Dict[str, Any]],
    scored: List[Dict[str, Any]],
    target: Optional[str],
    min_hit_score: float,
) -> None:
    """Section 1: list every combination that hit ``target`` with score > threshold.

    When none qualify, show up to five of the best target hits that fell short.
    """
    print(f"\n 【命中列表】共 {len(qualified)} 组 (target 匹配 + score > {min_hit_score}):")
    if not target:
        print(" (未指定 -t/--target,跳过命中列表)")
        return
    if qualified:
        for i, r in enumerate(qualified, 1):
            print(
                f" {i}. score={_conclusion_score(r):.4f} "
                f"match={r.get('match')} text={r.get('text', '')!r}"
            )
            print(f" tag={r.get('tag', '')}")
            print(f" {_format_param_row(r)}")
        return
    print(" (无满足条件的组合)")
    # Still show the closest hits so the user can see how far below the threshold they are.
    near = [r for r in scored if _match_hit(r.get("text", "") or "", target)]
    if near:
        print(f" 提示: 有 {len(near)} 组命中 target 但 score <= {min_hit_score},最高分:")
        for r in near[:5]:
            print(f" score={_conclusion_score(r):.4f} text={r.get('text', '')!r}")
            print(f" tag={r.get('tag', '')}")


def _print_verdict_section(
    qualified: List[Dict[str, Any]],
    scored: List[Dict[str, Any]],
    target: Optional[str],
    min_hit_score: float,
) -> None:
    """Section 2: recommended combination, or the best available fallback."""
    print("\n 【结论】")
    if qualified:
        best = qualified[0]
        print(
            f" 推荐参数组合: {best.get('tag')} "
            f"(score={_conclusion_score(best):.4f}, text={best.get('text', '')!r})"
        )
        print(f" {_format_param_row(best)}")
        # Best qualified row per contrast method (baseline rows have no contrast method).
        cm_best: Dict[str, Dict[str, Any]] = {}
        for r in qualified:
            cm = r.get("contrast_method") or "baseline"
            if cm not in cm_best or _conclusion_score(r) > _conclusion_score(cm_best[cm]):
                cm_best[cm] = r
        print(" 合格组合内各 contrast 最优:")
        for cm in sorted(cm_best.keys(), key=str):
            r = cm_best[cm]
            print(f" [{cm}] score={_conclusion_score(r):.4f} tag={r.get('tag', '')}")
        wm_set = {r.get("method") for r in qualified}
        order_set = {r.get("contrast_order") for r in qualified}
        upscale_set = _sorted_summary_labels(
            {
                str(v) if v is not None else "baseline"
                for v in (r.get("upscale") for r in qualified)
            }
        )
        print(
            f" 合格组合涉及: 去水印={_sorted_summary_labels(wm_set)} "
            f"放大顺序={_sorted_summary_labels(order_set)} upscale={upscale_set}"
        )
        print(f" 共 {len(qualified)} 组参数可用于生产配置参考")
        return

    # Look for the best *hit*, not just the top-scoring row. A lower-scoring row may
    # still match the target even when the top row does not.
    best_hit = None
    if target:
        best_hit = next(
            (r for r in scored if _match_hit(r.get("text", "") or "", target)), None
        )
    if best_hit is not None:
        print(
            f" 无 score>{min_hit_score} 的命中组合;最高分命中: "
            f"{best_hit.get('tag')} score={_conclusion_score(best_hit):.4f}"
        )
        return
    top = scored[0] if scored else None
    if top is None:
        print(" 无可用结论")
    elif target:
        print(
            f" 无命中 target 的组合;全局最高 det_rec: "
            f"{top.get('tag')} score={_conclusion_score(top):.4f} text={top.get('text', '')!r}"
        )
    else:
        print(
            f" 未指定 target;全局最高 det_rec: "
            f"{top.get('tag')} score={_conclusion_score(top):.4f} text={top.get('text', '')!r}"
        )


def _print_conclusions(
    stem: str,
    results: List[Dict[str, Any]],
    target: Optional[str],
    *,
    min_hit_score: float = 0.9,
) -> None:
    """Print the sweep summary for one cell.

    First all combinations that hit ``target`` with score > ``min_hit_score`` (det_rec
    mode) are listed, then a conclusion. Rows with text are ranked by score, preferring
    det_rec rows and falling back to rows of any mode if det_rec produced no text.
    """
    if not results:
        return
    print(f"\n{'='*70}")
    print(f" 实验结论: {stem}")
    if target:
        print(f" 目标文字: {target!r} 阈值: score > {min_hit_score} (det_rec)")
    print(f"{'='*70}")
    with_text = [r for r in results if r.get("text")]
    dr_results = [r for r in with_text if r.get("ocr_mode") == "det_rec"] or with_text
    if not dr_results:
        print(" (无有效 OCR 结果)")
        return
    scored = sorted(dr_results, key=lambda r: -_conclusion_score(r))
    qualified = _collect_qualified_hits(
        results, target, min_score=min_hit_score, ocr_mode="det_rec"
    )
    _print_qualified_section(qualified, scored, target, min_hit_score)
    _print_verdict_section(qualified, scored, target, min_hit_score)
    print(f"{'='*70}\n")
def _parse_best_config(tag: str) -> Dict[str, Any]:
    """Parse a sweep tag such as ``threshold_t150_cl_1.0_8_ob_u128_det0.5``.

    Tag layout: ``{method}_t{thresh}_{c_tag}_o{b|a}_u{upscale}_det{det_th}``, where
    ``method`` is threshold / masked_adaptive / none, ``thresh`` is digits or ``d``
    (default → ``None``). ``c_tag`` is one of ``c0``, ``cl_{clip}_{tile}``,
    ``tr_{target}`` or ``tr_{target}_{background}``, ``gm_{gamma}``, ``ln_{black}_{white}``.

    Returns:
        dict with method, threshold, contrast_order, upscale, det_db_box_thresh and
        contrast_cfg (kwargs for `_apply_contrast`).

    Raises:
        ValueError: if the tag or its contrast part cannot be parsed. This includes
            truncated contrast parts such as ``cl_1.0``, which used to raise IndexError.
    """
    import re

    tag = tag.strip()
    m = re.match(
        r"(threshold|masked_adaptive|none)_t(\w+?)_(.+?)_o([ba])_u(\d+)_det([\d.]+)$", tag
    )
    if not m:
        raise ValueError(f"无法解析 best-config tag: {tag!r}")
    method, thresh_str, c_part, order_char, upscale, det_th = m.groups()
    cfg: Dict[str, Any] = {
        "method": method,
        "threshold": int(thresh_str) if thresh_str.isdigit() else None,
        "contrast_order": "before_upscale" if order_char == "b" else "after_upscale",
        "upscale": int(upscale),
        "det_db_box_thresh": float(det_th),
    }

    parts = c_part.split("_")
    contrast_cfg: Optional[Dict[str, Any]] = None
    try:
        if c_part == "c0":
            contrast_cfg = {"method": "none"}
        elif c_part.startswith("cl_"):
            contrast_cfg = {
                "method": "clahe",
                "clip_limit": float(parts[1]),
                "tile_grid_size": int(parts[2]),
            }
        elif c_part.startswith("tr_"):
            contrast_cfg = {"method": "text_restore", "text_black_target": int(parts[1])}
            # Newer tags also carry background_threshold (tr_{target}_{bg}).
            if len(parts) > 2:
                contrast_cfg["background_threshold"] = int(parts[2])
        elif c_part.startswith("gm_"):
            contrast_cfg = {"method": "gamma", "gamma": float(parts[1])}
        elif c_part.startswith("ln_"):
            contrast_cfg = {
                "method": "linear",
                "black_percentile": float(parts[1]),
                "white_percentile": float(parts[2]),
            }
    except (IndexError, ValueError) as exc:
        raise ValueError(f"无法解析 contrast tag: {c_part!r} (in {tag})") from exc
    if contrast_cfg is None:
        raise ValueError(f"无法解析 contrast tag: {c_part!r} (in {tag})")
    cfg["contrast_cfg"] = contrast_cfg
    return cfg
def _tag_to_cell_preprocess_yaml(tag: str) -> Dict[str, Any]:
    """Turn a sweep tag into a ``second_pass_ocr.cell_preprocess`` fragment (Pass1 + Pass2).

    The result holds the watermark settings, ``upscale_min_side``, the contrast settings
    and an ``enhance_retry`` (Pass2) section. Pass2 reuses the Pass1 contrast, except that
    an enabled CLAHE is switched to ``tile_grid_size=4``. The tag's contrast order and det
    threshold are not carried into the output.

    Raises:
        ValueError: propagated from `_parse_best_config` for unparsable tags.
    """
    parsed = _parse_best_config(tag)
    wm_method = parsed["method"]
    source_contrast = parsed["contrast_cfg"]
    contrast_method = source_contrast.get("method", "none")

    watermark: Dict[str, Any] = {"enabled": wm_method != "none", "method": wm_method}
    if wm_method == "threshold" and parsed.get("threshold") is not None:
        watermark["threshold"] = parsed["threshold"]

    if contrast_method == "none":
        contrast: Dict[str, Any] = {"enabled": False}
    else:
        contrast = {"enabled": True, "method": contrast_method}
        contrast.update(
            (key, value) for key, value in source_contrast.items() if key != "method"
        )

    retry_contrast = dict(contrast)
    if retry_contrast.get("enabled") and retry_contrast.get("method") == "clahe":
        retry_contrast["tile_grid_size"] = 4

    return {
        "watermark": watermark,
        "upscale_min_side": parsed["upscale"],
        "contrast": contrast,
        "enhance_retry": {
            "enabled": True,
            "upscale_min_side": parsed["upscale"],
            "contrast": retry_contrast,
        },
    }
def _qualified_from_report(report: Dict[str, Any], min_score: float) -> List[Dict[str, Any]]:
    """Return a report's target hits whose score is strictly above ``min_score``.

    The precomputed ``hits_target_score_above`` list is used when it was built with the
    same threshold. It is also used, filtered by ``min_score``, when ``all_results`` is
    unavailable; in that case a looser threshold cannot recover hits below the stored
    one. Otherwise the list is recomputed from ``all_results``, so ``min_score`` is
    always honoured instead of silently reusing the sweep-time threshold.
    """
    precomputed = report.get("hits_target_score_above")
    all_results = report.get("all_results")
    if precomputed is not None:
        stored_min = report.get("min_hit_score")
        same_threshold = stored_min is not None and float(stored_min) == float(min_score)
        if same_threshold or not all_results:
            return [r for r in precomputed if float(r.get("score") or 0) > min_score]
    return _collect_qualified_hits(
        all_results or [],
        report.get("target"),
        min_score=min_score,
        ocr_mode="det_rec",
    )
def _majority_key(items: Sequence[Any]) -> Any:
    """Return the most frequent item (ties: first seen wins), or None for no items."""
    if not items:
        return None
    tally: Dict[Any, int] = {}
    for item in items:
        tally[item] = tally.get(item, 0) + 1
    # dicts keep first-insertion order and max() returns the first maximum.
    return max(tally, key=tally.__getitem__)
def _aggregate_majority_config(parsed: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Field-wise majority vote over parsed best-config dicts (from `_parse_best_config`).

    Contrast parameters are voted only among entries whose contrast method equals the
    winning method. Every parameter key seen there is voted, not just the CLAHE ones, and
    keys whose winning value is None are omitted.
    """
    cm = _majority_key([p["contrast_cfg"].get("method") for p in parsed])
    same_method = [p["contrast_cfg"] for p in parsed if p["contrast_cfg"].get("method") == cm]
    contrast_cfg: Dict[str, Any] = {"method": cm or "none"}
    param_keys: List[str] = []
    for c in same_method:
        for key in c:
            if key != "method" and key not in param_keys:
                param_keys.append(key)
    for key in param_keys:
        value = _majority_key([c[key] for c in same_method if key in c])
        if value is not None:
            contrast_cfg[key] = value
    return {
        "method": _majority_key([p["method"] for p in parsed]),
        "threshold": _majority_key([p.get("threshold") for p in parsed]),
        "contrast_cfg": contrast_cfg,
        "contrast_order": _majority_key([p["contrast_order"] for p in parsed]),
        "upscale": _majority_key([p["upscale"] for p in parsed]),
        "det_db_box_thresh": _majority_key([p["det_db_box_thresh"] for p in parsed]),
    }


def _cell_preprocess_for_tag(tag: Optional[str]) -> Optional[Dict[str, Any]]:
    """Convert a tag to a cell_preprocess dict; return None when that is impossible.

    ``baseline_upscale{N}`` tags are handled as well. run_sweep's baseline is a plain
    upscale with no watermark removal and no contrast, so such a tag is mapped to the
    equivalent ``none_td_c0`` tag first. Any other unparsable tag yields None instead of
    raising.
    """
    import re

    if not tag:
        return None
    tag = str(tag)
    baseline = re.fullmatch(r"baseline_upscale(\d+)", tag)
    if baseline:
        tag = f"none_td_c0_ob_u{baseline.group(1)}_det0"
    try:
        return _tag_to_cell_preprocess_yaml(tag)
    except (ValueError, IndexError):
        return None


def aggregate_sweep_reports(
    report_paths: Sequence[Path],
    *,
    min_hit_score: float = 0.9,
) -> Dict[str, Any]:
    """Aggregate several sweep_report.json files and suggest Pass1 / Pass2 configs.

    The recommended tag is chosen in this order:
      1. the highest-scoring tag present in the qualified set of every readable report
         (ties broken by tag text so the result is reproducible);
      2. otherwise a field-wise majority vote over each case's top non-baseline tag;
      3. otherwise the most common top tag (baseline included).
    Pass2 prefers an intersection tag that uses CLAHE tile 4 and otherwise reuses the
    Pass1 tag. Missing report files are listed in ``per_case`` with ``error="missing"``
    and are ignored by the intersection. Each report file is read once.
    """
    import re

    per_case: List[Dict[str, Any]] = []
    all_qualified_tags: List[set] = []
    best_score_by_tag: Dict[str, float] = {}
    for rp in report_paths:
        if not rp.is_file():
            per_case.append({"report": str(rp), "error": "missing"})
            continue
        report = json.loads(rp.read_text(encoding="utf-8"))
        qualified = _qualified_from_report(report, min_hit_score)
        tags: set = set()
        for r in qualified:
            if not r.get("tag"):
                continue
            tag = str(r.get("tag"))
            tags.add(tag)
            score = float(r.get("score") or 0)
            if tag not in best_score_by_tag or score > best_score_by_tag[tag]:
                best_score_by_tag[tag] = score
        all_qualified_tags.append(tags)
        top = qualified[0] if qualified else None
        per_case.append(
            {
                "report": str(rp),
                "input": report.get("input"),
                "target": report.get("target"),
                "qualified_count": len(qualified),
                "top_tag": top.get("tag") if top else None,
                "top_score": top.get("score") if top else None,
                "top_text": top.get("text") if top else None,
            }
        )

    ok_cases = [c for c in per_case if c.get("top_tag")]
    intersection: set = set.intersection(*all_qualified_tags) if all_qualified_tags else set()
    intersection_sorted = sorted(intersection, key=lambda t: (-best_score_by_tag[t], t))

    pick_tag: Optional[str] = None
    pick_reason = ""
    majority_fields: Dict[str, Any] = {}
    if intersection_sorted:
        pick_tag = intersection_sorted[0]
        pick_reason = f"全部 {len(report_paths)} 个 case 的合格集合交集,取最高分 tag"
    elif ok_cases:
        parsed: List[Dict[str, Any]] = []
        for c in ok_cases:
            tag = c["top_tag"]
            if tag and str(tag).startswith("baseline"):
                continue
            try:
                parsed.append(_parse_best_config(str(tag)))
            except ValueError:
                continue
        if parsed:
            pick_reason = "无全局 tag 交集,按各 case 榜首做分字段多数票(跳过 baseline)"
            synthetic = _aggregate_majority_config(parsed)
            c_tag = _contrast_tag(synthetic["contrast_cfg"])
            o_tag = "b" if synthetic["contrast_order"] == "before_upscale" else "a"
            t_s = "d" if synthetic["threshold"] is None else str(synthetic["threshold"])
            pick_tag = (
                f"{synthetic['method']}_t{t_s}_{c_tag}_o{o_tag}_u{synthetic['upscale']}"
                f"_det{synthetic['det_db_box_thresh']}"
            )
            majority_fields = synthetic
        else:
            pick_tag = _majority_key([c["top_tag"] for c in ok_cases])
            pick_reason = "无全局交集,按各 case 榜首 tag 多数票(含 baseline)"

    pass1_cpp = _cell_preprocess_for_tag(pick_tag)
    # Pass2 prefers a CLAHE tag with tile_grid_size 4 (the `_4_` segment of cl_{clip}_4).
    pass2_candidates = [t for t in intersection_sorted if re.search(r"cl_[\d.]+_4_", t)]
    pass2_tag = pass2_candidates[0] if pass2_candidates else pick_tag
    pass2_cpp = _cell_preprocess_for_tag(pass2_tag)
    return {
        "min_hit_score": min_hit_score,
        "per_case": per_case,
        "intersection_tags": intersection_sorted[:20],
        "intersection_count": len(intersection),
        "recommended_tag": pick_tag,
        "pick_reason": pick_reason,
        "majority_fields": majority_fields,
        "pass1_cell_preprocess": pass1_cpp,
        "pass2_enhance_retry": pass2_cpp.get("enhance_retry") if pass2_cpp else None,
        "pass2_tag": pass2_tag,
    }
# Case-file keys that describe a case and must not be forwarded as CLI flags.
_CASE_META_KEYS = frozenset({"name", "note", "description"})


def _config_to_argv(cfg: Dict[str, Any]) -> List[str]:
    """Convert one case dict into an argv list for ``main``.

    ``cfg["input"]`` becomes the positional argument. Every other non-meta key becomes
    ``--key-name``, with underscores turned into dashes:
      * ``True`` → bare flag; ``False`` → omitted;
      * lists / tuples → one comma-separated value (the form the CSV argument parsers
        in this script accept), instead of Python's ``"['a', 'b']"`` repr;
      * anything else → ``str(value)``.
    """
    argv = [str(cfg["input"])]
    for key, value in cfg.items():
        if key == "input" or key in _CASE_META_KEYS:
            continue
        flag = f"--{key.replace('_', '-')}"
        if isinstance(value, bool):
            if value:
                argv.append(flag)
        elif isinstance(value, (list, tuple)):
            argv.extend([flag, ",".join(str(v) for v in value)])
        else:
            argv.extend([flag, str(value)])
    return argv
def _load_cases_json(path: Path) -> List[Dict[str, Any]]:
    """Load batch cases from a UTF-8 JSON file.

    Two layouts are accepted:
      * ``{"shared": {...}, "cases": [{...}, ...]}``: each case is merged over
        ``shared``, and case keys win;
      * a bare top-level list of case dicts (no shared settings).
    Missing or null ``shared`` / ``cases`` are treated as empty.
    """
    data = json.loads(path.read_text(encoding="utf-8"))
    if isinstance(data, list):
        shared: Dict[str, Any] = {}
        cases = data
    else:
        shared = data.get("shared") or {}
        cases = data.get("cases") or []
    return [{**shared, **case} for case in cases]
def _find_sweep_report(case_output: Path, input_path: Path) -> Optional[Path]:
    """Locate the sweep_report.json produced for ``input_path`` under ``case_output``.

    First tries ``<case_output>/<stem without _raw>/sweep_report.json``. Otherwise it
    falls back to the lexicographically first sweep_report.json anywhere below
    ``case_output``, and returns None if there is none.
    """
    cell_stem = Path(str(input_path)).stem.removesuffix("_raw")
    candidate = case_output / cell_stem / "sweep_report.json"
    if candidate.is_file():
        return candidate
    return min(case_output.rglob("sweep_report.json"), default=None)
def run_cases_batch(cases: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Run the grid sweep for each case in order, each with its own output dir.

    Every case dict is converted to CLI arguments with ``_config_to_argv`` and
    handed to ``main``. Each case must carry ``input`` and ``output`` keys.

    Returns:
        One summary per case: display name (``name`` or else ``input``), the
        output directory, and the located ``sweep_report.json`` path or ``None``.
    """
    total = len(cases)
    banner = "#" * 70
    summaries: List[Dict[str, Any]] = []
    for idx, case in enumerate(cases, start=1):
        label = case.get("name") or case.get("input")
        print(f"\n{banner}\n 批量 case {idx}/{total}: {label}\n{banner}")
        main(_config_to_argv(case))
        out_dir = Path(case["output"])
        report = _find_sweep_report(out_dir, Path(case["input"]))
        summaries.append(
            {
                "name": label,
                "output": str(out_dir),
                "report": None if report is None else str(report),
            }
        )
    return summaries
def run_best_config(
    input_path: Path,
    out_dir: Path,
    *,
    prefer_raw: bool,
    best_cfg: Dict[str, Any],
    model_dir: Path,
    save_images: bool,
) -> Dict[str, Any]:
    """Run a single det+rec OCR pass on one image with a fixed parameter set.

    Args:
        input_path: Requested image path. It is resolved through
            ``resolve_input_image``; per the ``--no-prefer-raw`` option, this may
            pick a same-name ``*_raw.png`` when ``prefer_raw`` is true.
        out_dir: Root output folder. Results go to ``out_dir/<stem>``, where
            ``<stem>`` is the resolved file stem without a trailing ``_raw``.
        best_cfg: Parsed parameter set. Reads ``method``, ``threshold``
            (optional), ``contrast_cfg``, ``contrast_order``, ``upscale``,
            ``det_db_box_thresh`` and the optional ``_tag`` (default ``"best"``).
        model_dir: Model directory passed to ``_make_engine``.
        save_images: If true, also write the preprocessed image as ``<tag>.png``.

    Returns:
        The report dict (``input``, ``input_requested``, ``output_dir``,
        ``result``), which is also saved as ``best_result.json``.

    Raises:
        RuntimeError: if OpenCV cannot read the resolved image.
    """
    resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
    image = cv2.imread(str(resolved))
    if image is None:
        raise RuntimeError(f"无法读取图像: {resolved}")

    # removesuffix is a no-op when the stem does not end in "_raw".
    cell_dir = out_dir / resolved.stem.removesuffix("_raw")
    cell_dir.mkdir(parents=True, exist_ok=True)

    engine = _make_engine(best_cfg["det_db_box_thresh"], model_dir)
    method = best_cfg["method"]
    threshold = best_cfg.get("threshold")
    contrast_cfg = best_cfg["contrast_cfg"]
    upscale = best_cfg["upscale"]
    contrast_order = best_cfg["contrast_order"]
    processed = _preprocess(
        image,
        method=method,
        thresh=threshold,
        contrast_cfg=contrast_cfg,
        upscale=upscale,
        contrast_order=contrast_order,
    )

    tag = best_cfg.get("_tag", "best")
    if save_images:
        cv2.imwrite(str(cell_dir / f"{tag}.png"), processed)
    ocr = _ocr(engine, processed, det=True, rec=True)

    # Key order is kept stable because it is visible in the written JSON.
    result: Dict[str, Any] = {
        "tag": tag,
        "method": method,
        "threshold": threshold,
        "contrast_method": contrast_cfg.get("method", "none"),
        "contrast_order": contrast_order,
        "contrast_cfg": contrast_cfg,
        "upscale": upscale,
        "det_db_box_thresh": best_cfg["det_db_box_thresh"],
        "ocr_mode": "det_rec",
        **ocr,
    }
    report = {
        "input": str(resolved),
        "input_requested": str(input_path),
        "output_dir": str(cell_dir),
        "result": result,
    }
    (cell_dir / "best_result.json").write_text(
        json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    return report
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the cell preprocessing + OCR sweep.

    The comma-separated options (``--methods``, ``--thresholds``,
    ``--contrast-orders``, ``--upscales``, ``--det-threshs``) are kept as raw
    strings here; ``main`` splits and converts them.
    """
    p = argparse.ArgumentParser(
        description="单元格图预处理 + OCR 参数网格扫描(对齐 pipeline 格级二次 OCR)",
    )
    p.add_argument(
        "input",
        type=Path,
        nargs="?",
        default=None,
        help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)",
    )
    p.add_argument(
        "-o",
        "--output",
        type=Path,
        default=None,
        help="输出目录,默认 <input_dir|input_parent>/sweep_out/<stem>",
    )
    p.add_argument(
        "-t",
        "--target",
        default=None,
        help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
    )
    p.add_argument(
        "--min-hit-score",
        type=float,
        default=0.9,
        help="结论中「命中列表」的最低 score 阈值(默认 0.9,仅 det_rec)",
    )
    p.add_argument(
        "--model-dir",
        type=Path,
        default=None,
        help="PaddleOCR torch 模型目录(含 det/rec .pth),也可用 OCR_*_MODEL_PATH",
    )
    p.add_argument(
        "--no-prefer-raw",
        action="store_true",
        help="不自动选用同名的 *_raw.png",
    )
    # The values in this help text must match the hard-coded --quick grid used by
    # the grid sweep (thresholds 150,155; upscales 96,128,192; det 0.5).
    p.add_argument(
        "--quick",
        action="store_true",
        help=(
            "缩小网格(threshold 150,155 × upscale 96,128,192 × det 0.5 × contrast 精简),"
            "会覆盖 --thresholds/--upscales/--det-threshs"
        ),
    )
    p.add_argument(
        "--methods",
        default="threshold,masked_adaptive,none",
        help="去水印方式,逗号分隔;none=不去水印",
    )
    p.add_argument(
        "--thresholds",
        default="155,165,none",
        help="threshold 法的阈值;none=预设默认",
    )
    p.add_argument(
        "--contrast-orders",
        default="before_upscale,after_upscale",
        help="contrast 执行顺序: before_upscale(放大前), after_upscale(放大后), 逗号组合",
    )
    p.add_argument(
        "--upscales",
        default="128,192",
        help="最短边放大目标,逗号分隔整数",
    )
    p.add_argument(
        "--det-threshs",
        default="0.5",  # a wider grid, e.g. "0.2,0.3,0.4,0.5", can be passed explicitly
        help="det_db_box_thresh,逗号分隔",
    )
    p.add_argument(
        "--no-save-images",
        action="store_true",
        help="不写出中间预处理 png(仅报告)",
    )
    p.add_argument(
        "--no-baseline",
        action="store_true",
        help="跳过「仅放大、不去水印」对照组",
    )
    p.add_argument(
        "--baseline-upscale",
        type=int,
        default=192,
        help="baseline 对照组的最短边放大",
    )
    p.add_argument(
        "--best-only",
        action="store_true",
        help="不跑参数网格,对目录下所有图用 --best-config 指定参数跑一次,验证适配性",
    )
    p.add_argument(
        "--best-config",
        default="threshold_t150_cl_1.0_8_ob_u128_det0.5",
        help="最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5",
    )
    p.add_argument(
        "--cases",
        type=Path,
        default=None,
        help="批量 case JSON(见 sweep_cases.json),顺序跑网格扫描",
    )
    p.add_argument(
        "--aggregate-only",
        action="store_true",
        help="不扫描,仅根据已有 sweep_report.json 汇总 Pass1/Pass2 建议",
    )
    p.add_argument(
        "--aggregate-out",
        type=Path,
        default=None,
        help="汇总输出路径(默认 <cases 父目录>/aggregate_recommendation.json)",
    )
    return p
def _collect_case_reports(cases: Sequence[Dict[str, Any]]) -> List[Path]:
    """Return the located sweep_report.json of each case, printing cases that have none."""
    report_paths: List[Path] = []
    for case in cases:
        rp = _find_sweep_report(Path(case["output"]), Path(case["input"]))
        if rp:
            report_paths.append(rp)
        else:
            print(f" 跳过(无报告): {case.get('name') or case['input']}")
    return report_paths


def _write_aggregate(
    report_paths: List[Path], *, min_hit_score: float, out_path: Path
) -> Dict[str, Any]:
    """Aggregate the sweep reports, write the recommendation JSON to out_path, return it."""
    agg = aggregate_sweep_reports(report_paths, min_hit_score=min_hit_score)
    # Create the parent so a custom --aggregate-out inside a new folder does not fail.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(agg, ensure_ascii=False, indent=2), encoding="utf-8")
    return agg


def _run_best_only(
    args: argparse.Namespace, inputs: Sequence[Path], out_root: Path, model_dir: Path
) -> None:
    """Run every input once with the --best-config tag and write best_summary.json."""
    best_cfg = _parse_best_config(args.best_config)
    best_cfg["_tag"] = args.best_config
    print(f"最佳参数验证模式: {args.best_config}")
    print(f" 解析: method={best_cfg['method']} contrast={best_cfg['contrast_cfg'].get('method')} "
          f"upscale={best_cfg['upscale']} order={best_cfg['contrast_order']}")
    print(f" 共 {len(inputs)} 张图")
    all_texts: List[Dict[str, Any]] = []
    hit_count = 0
    for img_path in inputs:
        report = run_best_config(
            img_path, out_root,
            prefer_raw=not args.no_prefer_raw,
            best_cfg=best_cfg,
            model_dir=model_dir,
            save_images=not args.no_save_images,
        )
        result = report["result"]
        text = result.get("text", "")
        score = result.get("score", 0)
        all_texts.append({
            "input": img_path.name,
            "text": text,
            "score": score,
            "report": str(Path(report["output_dir"]) / "best_result.json"),
        })
        m = _match_hit(text, args.target)
        hit_info = f" [HIT: {m}]" if m else ""
        print(f" {img_path.name}: score={score:.4f} text={text!r}{hit_info}")
        if m:
            hit_count += 1
    summary_path = out_root / "best_summary.json"
    summary_data = {
        "best_config": args.best_config,
        "total": len(all_texts),
        "hits": hit_count,
        "target": args.target,
        "results": all_texts,
    }
    summary_path.write_text(json.dumps(summary_data, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"\n汇总: {hit_count}/{len(all_texts)} HIT -> {summary_path}")


def _run_grid_sweep(
    args: argparse.Namespace, inputs: Sequence[Path], out_root: Path, model_dir: Path
) -> None:
    """Sweep the full preprocessing/OCR parameter grid over every input and write sweep_index.json."""
    methods = [m.strip() for m in args.methods.split(",") if m.strip()]
    contrast_orders = [o.strip() for o in args.contrast_orders.split(",") if o.strip()]
    if args.quick:
        # --quick replaces whatever --thresholds/--upscales/--det-threshs were given.
        thresholds = [150, 155]
        upscales = [96, 128, 192]
        det_threshs = [0.5]
    else:
        thresholds = _parse_csv_ints(args.thresholds)
        upscales = [int(x) for x in args.upscales.split(",") if x.strip()]
        det_threshs = _parse_csv_floats(args.det_threshs)
    contrast_grid = _build_contrast_grid(quick=args.quick)
    print(f"扫描 {len(inputs)} 张图 -> {out_root}")
    print(f" methods={methods} thresholds={thresholds} upscales={upscales}")
    print(f" contrast_methods={len(contrast_grid)} orders={contrast_orders}")
    if args.target:
        print(f" target={args.target!r}")
    summary: List[Dict[str, Any]] = []
    for img_path in inputs:
        print(f"\n=== {img_path.name} ===")
        report = run_sweep(
            img_path,
            out_root,
            prefer_raw=not args.no_prefer_raw,
            target=args.target,
            model_dir=model_dir,
            methods=methods,
            thresholds=thresholds,
            contrast_grid=contrast_grid,
            contrast_orders=contrast_orders,
            upscales=upscales,
            det_threshs=det_threshs,
            save_images=not args.no_save_images,
            run_baseline=not args.no_baseline,
            baseline_upscale=args.baseline_upscale,
            min_hit_score=args.min_hit_score,
        )
        qh = report.get("hits_target_score_above") or []
        summary.append(
            {
                "input": report["input"],
                "hits": len(report["hits"]),
                "hits_target_score_above": len(qh),
                "top_qualified_tag": qh[0]["tag"] if qh else None,
                "top_qualified_score": qh[0]["score"] if qh else None,
                "report": str(Path(report["output_dir"]) / "sweep_report.json"),
            }
        )
    index_path = out_root / "sweep_index.json"
    index_path.write_text(
        json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(f"\n全部完成,索引: {index_path}")
    for s in summary:
        qn = s.get("hits_target_score_above", 0)
        top = s.get("top_qualified_tag")
        print(
            f" {s['input']}: hits={s['hits']} "
            f"qualified(score>{args.min_hit_score})={qn}"
            + (f" top={top}" if top else "")
            + f" -> {s['report']}"
        )


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point; dispatches to one of four modes.

    1. ``--aggregate-only``: run no OCR; aggregate the existing sweep reports of
       every case in ``--cases`` into a recommendation JSON.
    2. ``--cases``: run the grid sweep for each case, then aggregate.
    3. ``--best-only``: run every input once with the ``--best-config`` tag.
    4. Otherwise: grid-sweep every image collected from the positional ``input``.

    Args:
        argv: Argument list without the program name; ``None`` makes argparse
            read ``sys.argv``.

    Raises:
        SystemExit: on argparse errors or missing/invalid inputs.
    """
    args = _build_arg_parser().parse_args(argv)

    if args.aggregate_only:
        if not args.cases or not args.cases.is_file():
            raise SystemExit("--aggregate-only 需要 --cases <sweep_cases.json>")
        cases = _load_cases_json(args.cases)
        report_paths = _collect_case_reports(cases)
        agg_out = args.aggregate_out or args.cases.parent / "aggregate_recommendation.json"
        agg = _write_aggregate(report_paths, min_hit_score=args.min_hit_score, out_path=agg_out)
        print(f"\n汇总完成 -> {agg_out}")
        print(f" 推荐 tag: {agg.get('recommended_tag')}")
        print(f" 交集合格 tag 数: {agg.get('intersection_count')}")
        if agg.get("pass1_cell_preprocess"):
            print("\n Pass1 cell_preprocess 片段:")
            print(json.dumps(agg["pass1_cell_preprocess"], ensure_ascii=False, indent=2))
        if agg.get("pass2_enhance_retry"):
            print("\n Pass2 enhance_retry 片段:")
            print(json.dumps(agg["pass2_enhance_retry"], ensure_ascii=False, indent=2))
        return

    if args.cases is not None:
        # An explicitly requested cases file that does not exist must not be
        # silently ignored (previously this fell through to single-input mode).
        if not args.cases.is_file():
            raise SystemExit(f"--cases 文件不存在: {args.cases}")
        if args.input is not None:
            print(" 提示: 已指定 --cases,忽略 positional input")
        cases = _load_cases_json(args.cases)
        run_cases_batch(cases)
        report_paths = _collect_case_reports(cases)
        if report_paths:
            agg_out = args.aggregate_out or args.cases.parent / "aggregate_recommendation.json"
            agg = _write_aggregate(report_paths, min_hit_score=args.min_hit_score, out_path=agg_out)
            print(f"\n批量汇总 -> {agg_out}")
            print(f" 推荐 tag: {agg.get('recommended_tag')} ({agg.get('pick_reason')})")
        return

    if args.input is None:
        raise SystemExit("需要 input 路径,或使用 --cases / --aggregate-only")
    inputs = collect_inputs(args.input, prefer_raw=not args.no_prefer_raw)
    if not inputs:
        raise SystemExit("未找到可扫描的图像")
    if args.output is not None:
        out_root = args.output
    elif args.input.is_file():
        out_root = args.input.parent / "sweep_out"
    else:
        out_root = args.input / "sweep_out"
    out_root.mkdir(parents=True, exist_ok=True)
    model_dir = args.model_dir or _default_model_dir()

    if args.best_only:
        _run_best_only(args, inputs, out_root, model_dir)
    else:
        _run_grid_sweep(args, inputs, out_root, model_dir)
if __name__ == "__main__":
    # With no CLI arguments, fall back to the batch cases file next to this script.
    # The argv list is passed to main() explicitly rather than overwriting the
    # process-global sys.argv; None lets argparse read sys.argv as usual.
    cli_args = None
    if len(sys.argv) == 1:
        cases_path = Path(__file__).resolve().parent / "sweep_cases.json"
        print(f"ℹ️ 未提供命令行参数,使用批量 cases: {cases_path.name}")
        cli_args = ["--cases", str(cases_path)]
    sys.exit(main(cli_args))
|