cell_sweep.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334
  1. #!/usr/bin/env python3
  2. """
  3. 单元格裁剪图预处理参数扫描:去水印 / contrast(clahe/gamma/linear/text_restore)/ upscale / det 阈值 / OCR 模式。
  4. 支持 contrast 在放大前/后执行两种顺序对比。
  5. 默认从 **原图**(`*_raw.png`)出发,与 pipeline 二次 OCR 一致,避免对已预处理 debug 图二次去水印。
  6. 用法:
  7. python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
  8. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out
  9. python cell_sweep.py cell.png --quick --no-save-images
  10. python cell_sweep.py cell.png --contrast-orders before_upscale,after_upscale
  11. OCR_DET_MODEL_PATH=... OCR_REC_MODEL_PATH=... python cell_sweep.py cell.png
  12. # 统计出的最优参数 tag: threshold_t150_cl_1.0_8_ob_u128_det0.5
  13. # 对目录下所有 *_raw.png 验证适配性
  14. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only
  15. # 自定义最优参数
  16. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only \
  17. --best-config threshold_t150_cl_1.0_8_ob_u128_det0.5
  18. # 指定目标文字,自动统计 HIT 命中率
  19. python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only -t "交易类型"
  20. # 多 case 批量扫描 + 汇总 Pass1/Pass2 建议
  21. python cell_sweep.py --cases sweep_cases.json
  22. python cell_sweep.py --aggregate-only --cases sweep_cases.json
  23. """
  24. from __future__ import annotations
  25. import argparse
  26. import json
  27. import os
  28. import sys
  29. from itertools import product
  30. from pathlib import Path
  31. from typing import Any, Dict, List, Optional, Sequence, Tuple
  32. import cv2
  33. import numpy as np
# Make the repository root importable so the project packages imported below
# (`ocr_utils`, and `ocr_tools` inside `_make_engine`) resolve when this script
# is executed directly. `parents[3]` assumes the script's own directory sits
# three levels below the repository root.
# NOTE(review): confirm this depth if the script is ever moved.
_repo_root = Path(__file__).resolve().parents[3]
if str(_repo_root) not in sys.path:
    # Prepend so the in-repo packages win over any installed copies.
    sys.path.insert(0, str(_repo_root))
  37. from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
  38. from ocr_utils.watermark.contrast import enhance_document_contrast
  39. _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
  40. _DEFAULT_MODEL_DIR = Path(
  41. "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
  42. "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
  43. )
  44. def _parse_csv_ints(s: str) -> List[Optional[int]]:
  45. out: List[Optional[int]] = []
  46. for part in s.split(","):
  47. part = part.strip()
  48. if not part or part.lower() in ("none", "d", "default"):
  49. out.append(None)
  50. else:
  51. out.append(int(part))
  52. return out
  53. def _parse_csv_floats(s: str) -> List[float]:
  54. return [float(x.strip()) for x in s.split(",") if x.strip()]
  55. def _parse_csv_bools(s: str) -> List[bool]:
  56. out: List[bool] = []
  57. for part in s.split(","):
  58. p = part.strip().lower()
  59. if p in ("1", "true", "yes", "on"):
  60. out.append(True)
  61. elif p in ("0", "false", "no", "off"):
  62. out.append(False)
  63. else:
  64. raise ValueError(f"无效的 bool 值: {part!r}")
  65. return out
  66. def _default_model_dir() -> Path:
  67. det = os.environ.get("OCR_DET_MODEL_PATH")
  68. if det:
  69. return Path(det).parent
  70. return _DEFAULT_MODEL_DIR
  71. def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
  72. h, w = img.shape[:2]
  73. if h >= min_side and w >= min_side:
  74. return img
  75. s = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
  76. return cv2.resize(img, None, fx=s, fy=s, interpolation=cv2.INTER_CUBIC)
  77. # ── 对比度增强方法(clahe / gamma / linear / text_restore / none)──
  78. def _apply_contrast(
  79. gray: np.ndarray,
  80. *,
  81. method: str,
  82. clip_limit: float = 1.0,
  83. tile_grid_size: int = 8,
  84. gamma: float = 0.85,
  85. black_percentile: float = 2.0,
  86. white_percentile: float = 98.0,
  87. text_black_target: int = 85,
  88. background_threshold: int = 248,
  89. ) -> np.ndarray:
  90. """对灰度图应用对比度增强;method="none" 时原样返回。"""
  91. if method == "none":
  92. return gray
  93. if method == "text_restore":
  94. return enhance_document_contrast(
  95. gray, method="text_restore",
  96. text_black_target=text_black_target,
  97. background_threshold=background_threshold,
  98. )
  99. if method == "clahe":
  100. return enhance_document_contrast(
  101. gray, method="clahe",
  102. clip_limit=clip_limit, tile_grid_size=tile_grid_size,
  103. )
  104. if method == "gamma":
  105. return enhance_document_contrast(gray, method="gamma", gamma=gamma)
  106. if method == "linear":
  107. return enhance_document_contrast(
  108. gray, method="linear",
  109. black_percentile=black_percentile,
  110. white_percentile=white_percentile,
  111. )
  112. return gray
  113. def _contrast_tag(cfg: Dict[str, Any]) -> str:
  114. """生成 contrast 配置的短标签。"""
  115. m = cfg.get("method", "none")
  116. if m == "none":
  117. return "c0"
  118. if m == "text_restore":
  119. return f"tr_{cfg.get('text_black_target', 85)}"
  120. if m == "clahe":
  121. return f"cl_{cfg.get('clip_limit', 1.0)}_{cfg.get('tile_grid_size', 8)}"
  122. if m == "gamma":
  123. return f"gm_{cfg.get('gamma', 0.85)}"
  124. if m == "linear":
  125. return f"ln_{cfg.get('black_percentile', 2.0)}_{cfg.get('white_percentile', 98.0)}"
  126. return m
  127. def _build_contrast_grid(quick: bool = False) -> List[Dict[str, Any]]:
  128. """构建 contrast 参数网格(对齐 contrast_sweep.py 的设计)。
  129. 返回列表,每个元素是一个 Dict,至少包含 "method" 字段。
  130. """
  131. grid: List[Dict[str, Any]] = [{"method": "none"}] # 对照组:不增强
  132. # text_restore
  133. if quick:
  134. tbt = [60, 85]
  135. bts = [240, 248]
  136. else:
  137. tbt = [60, 85, 100, 120]
  138. bts = [240, 248, 252]
  139. for target, bg_th in product(tbt, bts):
  140. grid.append({"method": "text_restore", "text_black_target": target, "background_threshold": bg_th})
  141. # clahe
  142. if quick:
  143. cl = [1.0, 2.0]
  144. ts = [4, 8]
  145. else:
  146. cl = [0.5, 1.0, 2.0, 3.0, 5.0]
  147. ts = [4, 8]
  148. for clip, tile in product(cl, ts):
  149. grid.append({"method": "clahe", "clip_limit": clip, "tile_grid_size": tile})
  150. # # gamma
  151. # if quick:
  152. # gvs = [0.5, 0.85]
  153. # else:
  154. # gvs = [0.4, 0.55, 0.7, 0.85]
  155. # for g in gvs:
  156. # grid.append({"method": "gamma", "gamma": g})
  157. # # linear
  158. # if quick:
  159. # bps = [2.0, 5.0]
  160. # wps = [95.0, 98.0]
  161. # else:
  162. # bps = [2.0, 5.0, 8.0]
  163. # wps = [95.0, 98.0]
  164. # for bp, wp in product(bps, wps):
  165. # grid.append({"method": "linear", "black_percentile": bp, "white_percentile": wp})
  166. return grid
  167. def _preprocess(
  168. raw: np.ndarray,
  169. *,
  170. method: str,
  171. thresh: Optional[int],
  172. contrast_cfg: Dict[str, Any],
  173. upscale: int,
  174. contrast_order: str = "before_upscale",
  175. ) -> np.ndarray:
  176. """预处理管线:去水印 → [contrast] → 放大(或去水印 → 放大 → contrast)。
  177. method="none" 时跳过去水印,直接从原图开始处理。
  178. """
  179. if method == "none":
  180. img = raw.copy() # 不处理水印,直接使用原图
  181. else:
  182. user: Dict[str, Any] = {"enabled": True, "method": method}
  183. if method == "threshold" and thresh is not None:
  184. user["threshold"] = thresh
  185. cfg = merge_watermark_config("cell", user)
  186. img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)
  187. contrast_method = contrast_cfg.get("method", "none")
  188. if contrast_method != "none" and contrast_order == "before_upscale":
  189. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  190. gray = _apply_contrast(gray, **contrast_cfg)
  191. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  192. img = _upscale(img, upscale)
  193. if contrast_method != "none" and contrast_order == "after_upscale":
  194. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  195. gray = _apply_contrast(gray, **contrast_cfg)
  196. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  197. return img
  198. def _parse_rec_pair(rec_part: Any) -> Tuple[str, float]:
  199. """从 OCR 返回的 (text, score) 或嵌套结构中解析识别结果。"""
  200. if rec_part is None:
  201. return "", 0.0
  202. if isinstance(rec_part, (list, tuple)) and len(rec_part) >= 2:
  203. if isinstance(rec_part[0], (list, tuple, dict)):
  204. return "", 0.0
  205. txt = str(rec_part[0] or "").strip()
  206. try:
  207. sc = float(rec_part[1] or 0.0)
  208. except (TypeError, ValueError):
  209. sc = 0.0
  210. return txt, sc if txt else 0.0
  211. if isinstance(rec_part, (list, tuple)) and len(rec_part) == 1:
  212. txt = str(rec_part[0] or "").strip()
  213. return txt, 0.0
  214. return "", 0.0
  215. def _aggregate_rec_score(boxes: List[Dict[str, Any]]) -> float:
  216. """按字符数加权平均识别分(与 pipeline aggregate_line_ocr 一致)。"""
  217. total_len = sum(len(b.get("text") or "") for b in boxes)
  218. if total_len <= 0:
  219. return 0.0
  220. weighted = sum(
  221. len(b.get("text") or "") * float(b.get("score") or 0.0) for b in boxes
  222. )
  223. return weighted / total_len
  224. def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
  225. empty: Dict[str, Any] = {
  226. "text": "",
  227. "score": 0.0,
  228. "boxes": [],
  229. "det": det,
  230. "rec": rec,
  231. "n_boxes": 0,
  232. }
  233. try:
  234. res = engine.ocr(img, det=det, rec=rec)
  235. items = res[0] if res and res[0] is not None else []
  236. boxes_out: List[Dict[str, Any]] = []
  237. if det:
  238. for item in items:
  239. if not item or len(item) < 2:
  240. continue
  241. text, score = _parse_rec_pair(item[1])
  242. bbox = item[0]
  243. if hasattr(bbox, "tolist"):
  244. bbox = bbox.tolist()
  245. entry: Dict[str, Any] = {
  246. "text": text,
  247. "score": round(score, 6),
  248. }
  249. if bbox is not None:
  250. entry["det_bbox"] = bbox
  251. boxes_out.append(entry)
  252. else:
  253. for item in items:
  254. text, score = _parse_rec_pair(item)
  255. if not text and isinstance(item, (list, tuple)) and len(item) >= 1:
  256. text, score = _parse_rec_pair(item[0])
  257. boxes_out.append({"text": text, "score": round(score, 6)})
  258. text = "".join(b["text"] for b in boxes_out if b.get("text")).strip()
  259. agg_score = _aggregate_rec_score(boxes_out)
  260. return {
  261. "text": text,
  262. "score": round(agg_score, 6),
  263. "boxes": boxes_out,
  264. "det": det,
  265. "rec": rec,
  266. "n_boxes": len(boxes_out),
  267. }
  268. except Exception as e:
  269. out = dict(empty)
  270. out["error"] = str(e)
  271. return out
  272. def _make_engine(det_thresh: float, model_dir: Path) -> Any:
  273. from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
  274. det_path = os.environ.get("OCR_DET_MODEL_PATH") or str(
  275. model_dir / "ch_PP-OCRv5_det_infer.pth"
  276. )
  277. rec_path = os.environ.get("OCR_REC_MODEL_PATH") or str(
  278. model_dir / "ch_PP-OCRv4_rec_server_doc_infer.pth"
  279. )
  280. return PytorchPaddleOCR(
  281. lang="ch",
  282. det_model_path=det_path,
  283. rec_model_path=rec_path,
  284. det_db_box_thresh=det_thresh,
  285. )
  286. def resolve_input_image(path: Path, *, prefer_raw: bool) -> Path:
  287. """优先使用与 pipeline debug 配套的 *_raw.png。"""
  288. if not prefer_raw or path.stem.endswith("_raw"):
  289. return path
  290. raw_path = path.parent / f"{path.stem}_raw{path.suffix}"
  291. if raw_path.is_file():
  292. print(f" 使用原图: {raw_path.name}(跳过 {path.name})")
  293. return raw_path
  294. return path
  295. def collect_inputs(path: Path, *, prefer_raw: bool) -> List[Path]:
  296. if path.is_file():
  297. if path.suffix.lower() not in _IMAGE_SUFFIXES:
  298. raise ValueError(f"不支持的图像格式: {path}")
  299. return [resolve_input_image(path, prefer_raw=prefer_raw)]
  300. if not path.is_dir():
  301. raise FileNotFoundError(path)
  302. all_images = sorted(
  303. p
  304. for p in path.iterdir()
  305. if p.is_file() and p.suffix.lower() in _IMAGE_SUFFIXES
  306. )
  307. if not all_images:
  308. raise FileNotFoundError(f"目录内无图像: {path}")
  309. if prefer_raw:
  310. raws = [p for p in all_images if p.stem.endswith("_raw")]
  311. if raws:
  312. return raws
  313. chosen: List[Path] = []
  314. for p in all_images:
  315. if p.stem.endswith("_raw"):
  316. continue
  317. raw_sibling = p.parent / f"{p.stem}_raw{p.suffix}"
  318. if prefer_raw and raw_sibling.is_file():
  319. continue
  320. chosen.append(p)
  321. return chosen or all_images
  322. def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
  323. if not text:
  324. return None
  325. if not target:
  326. return "nonempty"
  327. if target in text:
  328. return "full"
  329. if len(target) >= 6 and target.isdigit() and len(text) >= 6 and text.isdigit():
  330. return "partial"
  331. return None
  332. def _collect_qualified_hits(
  333. results: List[Dict[str, Any]],
  334. target: Optional[str],
  335. *,
  336. min_score: float = 0.9,
  337. ocr_mode: str = "det_rec",
  338. ) -> List[Dict[str, Any]]:
  339. """命中 target 且 score > min_score 的组合(按 tag 去重,保留最高分)。"""
  340. if not target:
  341. return []
  342. by_tag: Dict[str, Dict[str, Any]] = {}
  343. for r in results:
  344. if r.get("ocr_mode") != ocr_mode:
  345. continue
  346. if not _match_hit(r.get("text", "") or "", target):
  347. continue
  348. score = float(r.get("score") or 0)
  349. if score <= min_score:
  350. continue
  351. tag = str(r.get("tag") or "")
  352. prev = by_tag.get(tag)
  353. if prev is None or score > float(prev.get("score") or 0):
  354. by_tag[tag] = r
  355. return sorted(by_tag.values(), key=lambda x: -(float(x.get("score") or 0)))
  356. def run_sweep(
  357. input_path: Path,
  358. out_dir: Path,
  359. *,
  360. prefer_raw: bool,
  361. target: Optional[str],
  362. model_dir: Path,
  363. methods: Sequence[str],
  364. thresholds: Sequence[Optional[int]],
  365. contrast_grid: List[Dict[str, Any]],
  366. contrast_orders: Sequence[str],
  367. upscales: Sequence[int],
  368. det_threshs: Sequence[float],
  369. save_images: bool,
  370. run_baseline: bool,
  371. baseline_upscale: int,
  372. min_hit_score: float = 0.9,
  373. ) -> Dict[str, Any]:
  374. resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
  375. raw = cv2.imread(str(resolved))
  376. if raw is None:
  377. raise RuntimeError(f"无法读取图像: {resolved}")
  378. stem = resolved.stem.removesuffix("_raw") if resolved.stem.endswith("_raw") else resolved.stem
  379. cell_out = out_dir / stem
  380. cell_out.mkdir(parents=True, exist_ok=True)
  381. ocr_modes: List[Tuple[str, bool, bool]] = [
  382. ("det_rec", True, True),
  383. ("whole_rec", False, True),
  384. ]
  385. results: List[Dict[str, Any]] = []
  386. hits: List[Dict[str, Any]] = []
  387. engines: Dict[float, Any] = {}
  388. total = 0
  389. for method, thresh, contrast_cfg, c_order, upscale, det_th in product(
  390. methods, thresholds, contrast_grid, contrast_orders, upscales, det_threshs
  391. ):
  392. # 过滤无效组合:非 threshold 方法不需要阈值
  393. if method not in ("threshold",):
  394. if thresh is not None:
  395. continue
  396. if det_th not in engines:
  397. print(f" [{stem}] 加载 OCR det_db_box_thresh={det_th} ...")
  398. engines[det_th] = _make_engine(det_th, model_dir)
  399. img = _preprocess(
  400. raw,
  401. method=method,
  402. thresh=thresh,
  403. contrast_cfg=contrast_cfg,
  404. upscale=upscale,
  405. contrast_order=c_order,
  406. )
  407. c_tag = _contrast_tag(contrast_cfg)
  408. o_tag = "b" if c_order == "before_upscale" else "a"
  409. tag = f"{method}_t{thresh or 'd'}_{c_tag}_o{o_tag}_u{upscale}_det{det_th}"
  410. if save_images:
  411. cv2.imwrite(str(cell_out / f"{tag}.png"), img)
  412. for mode_name, det, rec in ocr_modes:
  413. total += 1
  414. ocr = _ocr(engines[det_th], img, det=det, rec=rec)
  415. row: Dict[str, Any] = {
  416. "tag": tag,
  417. "method": method,
  418. "threshold": thresh,
  419. "contrast_method": contrast_cfg.get("method", "none"),
  420. "contrast_order": c_order,
  421. "contrast_cfg": contrast_cfg,
  422. "upscale": upscale,
  423. "det_db_box_thresh": det_th,
  424. "ocr_mode": mode_name,
  425. **ocr,
  426. }
  427. results.append(row)
  428. m = _match_hit(row.get("text", ""), target)
  429. if m:
  430. row["match"] = m
  431. hits.append(row)
  432. print(
  433. f" HIT [{m}] {mode_name} {tag} "
  434. f"score={row.get('score')} -> {row.get('text')!r}"
  435. )
  436. if run_baseline:
  437. for det_th in det_threshs:
  438. if det_th not in engines:
  439. engines[det_th] = _make_engine(det_th, model_dir)
  440. base_img = _upscale(raw, baseline_upscale)
  441. if save_images:
  442. cv2.imwrite(str(cell_out / f"baseline_upscale{baseline_upscale}.png"), base_img)
  443. for mode_name, det, rec in ocr_modes:
  444. ocr = _ocr(engines[det_th], base_img, det=det, rec=rec)
  445. row = {
  446. "tag": f"baseline_upscale{baseline_upscale}",
  447. "det_db_box_thresh": det_th,
  448. "ocr_mode": mode_name,
  449. **ocr,
  450. }
  451. results.append(row)
  452. m = _match_hit(row.get("text", ""), target)
  453. if m:
  454. row["match"] = m
  455. hits.append(row)
  456. qualified_hits = _collect_qualified_hits(
  457. results, target, min_score=min_hit_score, ocr_mode="det_rec"
  458. )
  459. report = {
  460. "input": str(resolved),
  461. "input_requested": str(input_path),
  462. "output_dir": str(cell_out),
  463. "target": target,
  464. "min_hit_score": min_hit_score,
  465. "total_trials": total,
  466. "hits": hits,
  467. "hits_target_score_above": [
  468. {
  469. "tag": r.get("tag"),
  470. "score": r.get("score"),
  471. "text": r.get("text"),
  472. "match": r.get("match"),
  473. "method": r.get("method"),
  474. "threshold": r.get("threshold"),
  475. "contrast_method": r.get("contrast_method"),
  476. "contrast_order": r.get("contrast_order"),
  477. "contrast_cfg": r.get("contrast_cfg"),
  478. "upscale": r.get("upscale"),
  479. "det_db_box_thresh": r.get("det_db_box_thresh"),
  480. "ocr_mode": r.get("ocr_mode"),
  481. }
  482. for r in qualified_hits
  483. ],
  484. "all_results": results,
  485. }
  486. report_path = cell_out / "sweep_report.json"
  487. report_path.write_text(
  488. json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
  489. )
  490. _print_conclusions(stem, results, target, min_hit_score=min_hit_score)
  491. return report
  492. def _sorted_summary_labels(values: set) -> List[str]:
  493. """将可能含 None 的集合转为可排序的展示标签(baseline 行字段常为 None)。"""
  494. labels: List[str] = []
  495. if None in values:
  496. labels.append("baseline")
  497. labels.extend(sorted(str(v) for v in values if v is not None))
  498. return labels
  499. def _format_param_row(r: Dict[str, Any]) -> str:
  500. """单行参数组合描述。"""
  501. cfg = r.get("contrast_cfg") or {}
  502. cm = r.get("contrast_method", "none")
  503. extra = ""
  504. if cm == "clahe":
  505. extra = f" cl={cfg.get('clip_limit')} tile={cfg.get('tile_grid_size')}"
  506. elif cm == "text_restore":
  507. extra = f" t={cfg.get('text_black_target')} bg={cfg.get('background_threshold')}"
  508. elif cm == "gamma":
  509. extra = f" g={cfg.get('gamma')}"
  510. elif cm == "linear":
  511. extra = f" b={cfg.get('black_percentile')} w={cfg.get('white_percentile')}"
  512. order = "放大前" if r.get("contrast_order") == "before_upscale" else "放大后"
  513. return (
  514. f"wm={r.get('method')} thresh={r.get('threshold', 'd')} contrast={cm}{extra} "
  515. f"order={order} upscale={r.get('upscale')} det={r.get('det_db_box_thresh')}"
  516. )
  517. def _print_conclusions(
  518. stem: str,
  519. results: List[Dict[str, Any]],
  520. target: Optional[str],
  521. *,
  522. min_hit_score: float = 0.9,
  523. ) -> None:
  524. """先列出命中 target 且 score>阈值的全部参数组合,再输出结论。"""
  525. if not results:
  526. return
  527. print(f"\n{'='*70}")
  528. print(f" 实验结论: {stem}")
  529. if target:
  530. print(f" 目标文字: {target!r} 阈值: score > {min_hit_score} (det_rec)")
  531. print(f"{'='*70}")
  532. dr_results = [r for r in results if r.get("ocr_mode") == "det_rec" and r.get("text")]
  533. if not dr_results:
  534. dr_results = [r for r in results if r.get("text")]
  535. if not dr_results:
  536. print(" (无有效 OCR 结果)")
  537. return
  538. scored = sorted(dr_results, key=lambda r: -(float(r.get("score") or 0)))
  539. qualified = _collect_qualified_hits(
  540. results, target, min_score=min_hit_score, ocr_mode="det_rec"
  541. )
  542. # ── 1. 命中 target 且 score > 阈值 的全部参数组合 ──
  543. print(f"\n 【命中列表】共 {len(qualified)} 组 (target 匹配 + score > {min_hit_score}):")
  544. if not target:
  545. print(" (未指定 -t/--target,跳过命中列表)")
  546. elif not qualified:
  547. print(" (无满足条件的组合)")
  548. # 仍展示最接近的 HIT
  549. near = [
  550. r for r in scored
  551. if _match_hit(r.get("text", "") or "", target)
  552. ]
  553. if near:
  554. print(f" 提示: 有 {len(near)} 组命中 target 但 score <= {min_hit_score},最高分:")
  555. for r in near[:5]:
  556. print(
  557. f" score={float(r.get('score', 0)):.4f} text={r.get('text', '')!r}"
  558. )
  559. print(f" tag={r.get('tag', '')}")
  560. else:
  561. for i, r in enumerate(qualified, 1):
  562. print(
  563. f" {i}. score={float(r.get('score', 0)):.4f} "
  564. f"match={r.get('match')} text={r.get('text', '')!r}"
  565. )
  566. print(f" tag={r.get('tag', '')}")
  567. print(f" {_format_param_row(r)}")
  568. # ── 2. 结论 ──
  569. print("\n 【结论】")
  570. if qualified:
  571. best = qualified[0]
  572. print(
  573. f" 推荐参数组合: {best.get('tag')} "
  574. f"(score={float(best.get('score', 0)):.4f}, text={best.get('text', '')!r})"
  575. )
  576. print(f" {_format_param_row(best)}")
  577. # 在合格集合内做简要对比
  578. cm_best: Dict[str, Dict[str, Any]] = {}
  579. for r in qualified:
  580. cm = r.get("contrast_method") or "baseline"
  581. if cm not in cm_best or float(r.get("score") or 0) > float(
  582. cm_best[cm].get("score") or 0
  583. ):
  584. cm_best[cm] = r
  585. print(" 合格组合内各 contrast 最优:")
  586. for cm in sorted(cm_best.keys(), key=str):
  587. r = cm_best[cm]
  588. print(
  589. f" [{cm}] score={float(r.get('score', 0)):.4f} tag={r.get('tag', '')}"
  590. )
  591. wm_set = {r.get("method") for r in qualified}
  592. order_set = {r.get("contrast_order") for r in qualified}
  593. upscale_vals = {r.get("upscale") for r in qualified}
  594. upscale_set = _sorted_summary_labels(
  595. {str(v) if v is not None else "baseline" for v in upscale_vals}
  596. )
  597. print(
  598. f" 合格组合涉及: 去水印={_sorted_summary_labels(wm_set)} "
  599. f"放大顺序={_sorted_summary_labels(order_set)} upscale={upscale_set}"
  600. )
  601. print(f" 共 {len(qualified)} 组参数可用于生产配置参考")
  602. else:
  603. top = scored[0] if scored else None
  604. if top and target and _match_hit(top.get("text", "") or "", target):
  605. print(
  606. f" 无 score>{min_hit_score} 的命中组合;最高分命中: "
  607. f"{top.get('tag')} score={float(top.get('score', 0)):.4f}"
  608. )
  609. elif top:
  610. print(
  611. f" 无命中 target 的组合;全局最高 det_rec: "
  612. f"{top.get('tag')} score={float(top.get('score', 0)):.4f} text={top.get('text', '')!r}"
  613. )
  614. else:
  615. print(" 无可用结论")
  616. print(f"{'='*70}\n")
  617. def _parse_best_config(tag: str) -> Dict[str, Any]:
  618. """解析最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5。
  619. tag 格式: {method}_t{thresh}_{c_tag}_o{b|a}_u{upscale}_det{det_th}
  620. """
  621. import re
  622. cfg: Dict[str, Any] = {}
  623. tag = tag.strip()
  624. # 解析 method: threshold | masked_adaptive | none
  625. m = re.match(r"(threshold|masked_adaptive|none)_t(\w+?)_(.+?)_o([ba])_u(\d+)_det([\d.]+)$", tag)
  626. if not m:
  627. raise ValueError(f"无法解析 best-config tag: {tag!r}")
  628. method, thresh_str, c_part, order_char, upscale, det_th = m.groups()
  629. cfg["method"] = method
  630. cfg["threshold"] = int(thresh_str) if thresh_str.isdigit() else None
  631. cfg["contrast_order"] = "before_upscale" if order_char == "b" else "after_upscale"
  632. cfg["upscale"] = int(upscale)
  633. cfg["det_db_box_thresh"] = float(det_th)
  634. # 解析 contrast 部分: cl_1.0_8 | tr_85 | gm_0.85 | ln_2.0_98.0 | c0
  635. if c_part == "c0":
  636. cfg["contrast_cfg"] = {"method": "none"}
  637. elif c_part.startswith("cl_"):
  638. parts = c_part.split("_")
  639. cfg["contrast_cfg"] = {"method": "clahe", "clip_limit": float(parts[1]), "tile_grid_size": int(parts[2])}
  640. elif c_part.startswith("tr_"):
  641. parts = c_part.split("_")
  642. cfg["contrast_cfg"] = {"method": "text_restore", "text_black_target": int(parts[1])}
  643. elif c_part.startswith("gm_"):
  644. parts = c_part.split("_")
  645. cfg["contrast_cfg"] = {"method": "gamma", "gamma": float(parts[1])}
  646. elif c_part.startswith("ln_"):
  647. parts = c_part.split("_")
  648. cfg["contrast_cfg"] = {"method": "linear", "black_percentile": float(parts[1]), "white_percentile": float(parts[2])}
  649. else:
  650. raise ValueError(f"无法解析 contrast tag: {c_part!r} (in {tag})")
  651. return cfg
  652. def _tag_to_cell_preprocess_yaml(tag: str) -> Dict[str, Any]:
  653. """将 sweep tag 转为 second_pass_ocr.cell_preprocess 片段(Pass1 + Pass2)。"""
  654. cfg = _parse_best_config(tag)
  655. cm = cfg["contrast_cfg"].get("method", "none")
  656. cpp: Dict[str, Any] = {
  657. "watermark": {
  658. "enabled": cfg["method"] != "none",
  659. "method": cfg["method"],
  660. },
  661. "upscale_min_side": cfg["upscale"],
  662. }
  663. if cfg["method"] == "threshold" and cfg.get("threshold") is not None:
  664. cpp["watermark"]["threshold"] = cfg["threshold"]
  665. if cm != "none":
  666. contrast = {"enabled": True, "method": cm, **{k: v for k, v in cfg["contrast_cfg"].items() if k != "method"}}
  667. cpp["contrast"] = contrast
  668. else:
  669. cpp["contrast"] = {"enabled": False}
  670. pass2_contrast: Dict[str, Any] = dict(cpp.get("contrast") or {"enabled": False})
  671. if pass2_contrast.get("enabled") and pass2_contrast.get("method") == "clahe":
  672. pass2_contrast = dict(pass2_contrast)
  673. pass2_contrast["tile_grid_size"] = int(4)
  674. cpp["enhance_retry"] = {
  675. "enabled": True,
  676. "upscale_min_side": cfg["upscale"],
  677. "contrast": pass2_contrast,
  678. }
  679. return cpp
  680. def _qualified_from_report(report: Dict[str, Any], min_score: float) -> List[Dict[str, Any]]:
  681. qh = report.get("hits_target_score_above")
  682. if qh is not None:
  683. return list(qh)
  684. return _collect_qualified_hits(
  685. report.get("all_results") or [],
  686. report.get("target"),
  687. min_score=min_score,
  688. ocr_mode="det_rec",
  689. )
  690. def _majority_key(items: Sequence[Any]) -> Any:
  691. from collections import Counter
  692. if not items:
  693. return None
  694. return Counter(items).most_common(1)[0][0]
  695. def aggregate_sweep_reports(
  696. report_paths: Sequence[Path],
  697. *,
  698. min_hit_score: float = 0.9,
  699. ) -> Dict[str, Any]:
  700. """跨多个 sweep_report.json 汇总,输出 Pass1/Pass2 配置建议。"""
  701. per_case: List[Dict[str, Any]] = []
  702. all_qualified_tags: List[set] = []
  703. for rp in report_paths:
  704. if not rp.is_file():
  705. per_case.append({"report": str(rp), "error": "missing"})
  706. continue
  707. report = json.loads(rp.read_text(encoding="utf-8"))
  708. qualified = _qualified_from_report(report, min_hit_score)
  709. tags = [str(r.get("tag") or "") for r in qualified if r.get("tag")]
  710. all_qualified_tags.append(set(tags))
  711. top = qualified[0] if qualified else None
  712. per_case.append(
  713. {
  714. "report": str(rp),
  715. "input": report.get("input"),
  716. "target": report.get("target"),
  717. "qualified_count": len(qualified),
  718. "top_tag": top.get("tag") if top else None,
  719. "top_score": top.get("score") if top else None,
  720. "top_text": top.get("text") if top else None,
  721. }
  722. )
  723. ok_cases = [c for c in per_case if c.get("top_tag")]
  724. intersection: set = set.intersection(*all_qualified_tags) if all_qualified_tags else set()
  725. intersection_sorted = sorted(
  726. intersection,
  727. key=lambda t: -max(
  728. float(r.get("score") or 0)
  729. for rp in report_paths
  730. if rp.is_file()
  731. for r in _qualified_from_report(json.loads(rp.read_text(encoding="utf-8")), min_hit_score)
  732. if r.get("tag") == t
  733. ),
  734. )
  735. pick_tag: Optional[str] = None
  736. pick_reason = ""
  737. majority_fields: Dict[str, Any] = {}
  738. if intersection_sorted:
  739. pick_tag = intersection_sorted[0]
  740. pick_reason = f"全部 {len(report_paths)} 个 case 的合格集合交集,取最高分 tag"
  741. elif ok_cases:
  742. parsed: List[Dict[str, Any]] = []
  743. for c in ok_cases:
  744. tag = c["top_tag"]
  745. if tag and str(tag).startswith("baseline"):
  746. continue
  747. try:
  748. parsed.append(_parse_best_config(str(tag)))
  749. except ValueError:
  750. continue
  751. if parsed:
  752. pick_reason = "无全局 tag 交集,按各 case 榜首做分字段多数票(跳过 baseline)"
  753. method = _majority_key([p["method"] for p in parsed])
  754. thresh = _majority_key([p.get("threshold") for p in parsed])
  755. upscale = _majority_key([p["upscale"] for p in parsed])
  756. cm = _majority_key([p["contrast_cfg"].get("method") for p in parsed])
  757. c_cfgs = [p["contrast_cfg"] for p in parsed if p["contrast_cfg"].get("method") == cm]
  758. contrast_cfg: Dict[str, Any] = {"method": cm or "none"}
  759. if cm == "clahe" and c_cfgs:
  760. contrast_cfg["clip_limit"] = _majority_key([c.get("clip_limit") for c in c_cfgs])
  761. contrast_cfg["tile_grid_size"] = _majority_key(
  762. [c.get("tile_grid_size") for c in c_cfgs]
  763. )
  764. synthetic = {
  765. "method": method,
  766. "threshold": thresh,
  767. "contrast_cfg": contrast_cfg,
  768. "contrast_order": _majority_key([p["contrast_order"] for p in parsed]),
  769. "upscale": upscale,
  770. "det_db_box_thresh": _majority_key([p["det_db_box_thresh"] for p in parsed]),
  771. }
  772. c_tag = _contrast_tag(contrast_cfg)
  773. o_tag = "b" if synthetic["contrast_order"] == "before_upscale" else "a"
  774. t_s = str(synthetic["threshold"] or "d")
  775. pick_tag = (
  776. f"{method}_t{t_s}_{c_tag}_o{o_tag}_u{synthetic['upscale']}"
  777. f"_det{synthetic['det_db_box_thresh']}"
  778. )
  779. majority_fields = synthetic
  780. else:
  781. pick_tag = _majority_key([c["top_tag"] for c in ok_cases])
  782. pick_reason = "无全局交集,按各 case 榜首 tag 多数票(含 baseline)"
  783. pass1_cpp = _tag_to_cell_preprocess_yaml(pick_tag) if pick_tag else None
  784. def _is_pass2_tile_tag(t: str) -> bool:
  785. import re
  786. return bool(re.search(r"cl_[\d.]+_4_", t))
  787. pass2_candidates = [t for t in intersection_sorted if _is_pass2_tile_tag(t)]
  788. pass2_tag = pass2_candidates[0] if pass2_candidates else pick_tag
  789. pass2_cpp = _tag_to_cell_preprocess_yaml(pass2_tag) if pass2_tag else None
  790. return {
  791. "min_hit_score": min_hit_score,
  792. "per_case": per_case,
  793. "intersection_tags": intersection_sorted[:20],
  794. "intersection_count": len(intersection),
  795. "recommended_tag": pick_tag,
  796. "pick_reason": pick_reason,
  797. "majority_fields": majority_fields,
  798. "pass1_cell_preprocess": pass1_cpp,
  799. "pass2_enhance_retry": (pass2_cpp or {}).get("enhance_retry") if pass2_cpp else None,
  800. "pass2_tag": pass2_tag,
  801. }
  802. _CASE_META_KEYS = frozenset({"name", "note", "description"})
  803. def _config_to_argv(cfg: Dict[str, Any]) -> List[str]:
  804. argv = [str(cfg["input"])]
  805. for key, value in cfg.items():
  806. if key == "input" or key in _CASE_META_KEYS:
  807. continue
  808. flag = f"--{key.replace('_', '-')}"
  809. if isinstance(value, bool) and value:
  810. argv.append(flag)
  811. elif not isinstance(value, bool):
  812. argv.extend([flag, str(value)])
  813. return argv
  814. def _load_cases_json(path: Path) -> List[Dict[str, Any]]:
  815. data = json.loads(path.read_text(encoding="utf-8"))
  816. shared = data.get("shared") or {}
  817. cases = data.get("cases") or []
  818. out: List[Dict[str, Any]] = []
  819. for c in cases:
  820. merged = {**shared, **c}
  821. out.append(merged)
  822. return out
  823. def _find_sweep_report(case_output: Path, input_path: Path) -> Optional[Path]:
  824. stem = Path(str(input_path)).stem.removesuffix("_raw")
  825. direct = case_output / stem / "sweep_report.json"
  826. if direct.is_file():
  827. return direct
  828. hits = sorted(case_output.rglob("sweep_report.json"))
  829. return hits[0] if hits else None
  830. def run_cases_batch(cases: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
  831. """顺序执行多组 case(每组独立 output)。"""
  832. summaries: List[Dict[str, Any]] = []
  833. for i, case in enumerate(cases, 1):
  834. name = case.get("name") or case.get("input")
  835. print(f"\n{'#'*70}\n 批量 case {i}/{len(cases)}: {name}\n{'#'*70}")
  836. argv = _config_to_argv(case)
  837. main(argv)
  838. out = Path(case["output"])
  839. rp = _find_sweep_report(out, Path(case["input"]))
  840. summaries.append({"name": name, "output": str(out), "report": str(rp) if rp else None})
  841. return summaries
  842. def run_best_config(
  843. input_path: Path,
  844. out_dir: Path,
  845. *,
  846. prefer_raw: bool,
  847. best_cfg: Dict[str, Any],
  848. model_dir: Path,
  849. save_images: bool,
  850. ) -> Dict[str, Any]:
  851. """对单图用指定最优参数跑一次 OCR。"""
  852. resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
  853. raw = cv2.imread(str(resolved))
  854. if raw is None:
  855. raise RuntimeError(f"无法读取图像: {resolved}")
  856. stem = resolved.stem.removesuffix("_raw") if resolved.stem.endswith("_raw") else resolved.stem
  857. cell_out = out_dir / stem
  858. cell_out.mkdir(parents=True, exist_ok=True)
  859. engine = _make_engine(best_cfg["det_db_box_thresh"], model_dir)
  860. img = _preprocess(
  861. raw,
  862. method=best_cfg["method"],
  863. thresh=best_cfg.get("threshold"),
  864. contrast_cfg=best_cfg["contrast_cfg"],
  865. upscale=best_cfg["upscale"],
  866. contrast_order=best_cfg["contrast_order"],
  867. )
  868. tag = best_cfg.get("_tag", "best")
  869. if save_images:
  870. cv2.imwrite(str(cell_out / f"{tag}.png"), img)
  871. ocr = _ocr(engine, img, det=True, rec=True)
  872. row: Dict[str, Any] = {
  873. "tag": tag,
  874. "method": best_cfg["method"],
  875. "threshold": best_cfg.get("threshold"),
  876. "contrast_method": best_cfg["contrast_cfg"].get("method", "none"),
  877. "contrast_order": best_cfg["contrast_order"],
  878. "contrast_cfg": best_cfg["contrast_cfg"],
  879. "upscale": best_cfg["upscale"],
  880. "det_db_box_thresh": best_cfg["det_db_box_thresh"],
  881. "ocr_mode": "det_rec",
  882. **ocr,
  883. }
  884. report = {
  885. "input": str(resolved),
  886. "input_requested": str(input_path),
  887. "output_dir": str(cell_out),
  888. "result": row,
  889. }
  890. report_path = cell_out / "best_result.json"
  891. report_path.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
  892. return report
  893. def _build_arg_parser() -> argparse.ArgumentParser:
  894. p = argparse.ArgumentParser(
  895. description="单元格图预处理 + OCR 参数网格扫描(对齐 pipeline 格级二次 OCR)",
  896. )
  897. p.add_argument(
  898. "input",
  899. type=Path,
  900. nargs="?",
  901. default=None,
  902. help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)",
  903. )
  904. p.add_argument(
  905. "-o",
  906. "--output",
  907. type=Path,
  908. default=None,
  909. help="输出目录,默认 <input_dir|input_parent>/sweep_out/<stem>",
  910. )
  911. p.add_argument(
  912. "-t",
  913. "--target",
  914. default=None,
  915. help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
  916. )
  917. p.add_argument(
  918. "--min-hit-score",
  919. type=float,
  920. default=0.9,
  921. help="结论中「命中列表」的最低 score 阈值(默认 0.9,仅 det_rec)",
  922. )
  923. p.add_argument(
  924. "--model-dir",
  925. type=Path,
  926. default=None,
  927. help="PaddleOCR torch 模型目录(含 det/rec .pth),也可用 OCR_*_MODEL_PATH",
  928. )
  929. p.add_argument(
  930. "--no-prefer-raw",
  931. action="store_true",
  932. help="不自动选用同名的 *_raw.png",
  933. )
  934. p.add_argument(
  935. "--quick",
  936. action="store_true",
  937. help="缩小网格(threshold 155,165 × upscale 128,192 × det 0.5 × contrast 精简)",
  938. )
  939. p.add_argument(
  940. "--methods",
  941. default="threshold,masked_adaptive,none",
  942. help="去水印方式,逗号分隔;none=不去水印",
  943. )
  944. p.add_argument(
  945. "--thresholds",
  946. default="155,165,none",
  947. help="threshold 法的阈值;none=预设默认",
  948. )
  949. p.add_argument(
  950. "--contrast-orders",
  951. default="before_upscale,after_upscale",
  952. help="contrast 执行顺序: before_upscale(放大前), after_upscale(放大后), 逗号组合",
  953. )
  954. p.add_argument(
  955. "--upscales",
  956. default="128,192",
  957. help="最短边放大目标,逗号分隔整数",
  958. )
  959. p.add_argument(
  960. "--det-threshs",
  961. # default="0.2,0.3,0.4,0.5",
  962. default="0.5",
  963. help="det_db_box_thresh,逗号分隔",
  964. )
  965. p.add_argument(
  966. "--no-save-images",
  967. action="store_true",
  968. help="不写出中间预处理 png(仅报告)",
  969. )
  970. p.add_argument(
  971. "--no-baseline",
  972. action="store_true",
  973. help="跳过「仅放大、不去水印」对照组",
  974. )
  975. p.add_argument(
  976. "--baseline-upscale",
  977. type=int,
  978. default=192,
  979. help="baseline 对照组的最短边放大",
  980. )
  981. p.add_argument(
  982. "--best-only",
  983. action="store_true",
  984. help="不跑参数网格,对目录下所有图用 --best-config 指定参数跑一次,验证适配性",
  985. )
  986. p.add_argument(
  987. "--best-config",
  988. default="threshold_t150_cl_1.0_8_ob_u128_det0.5",
  989. help="最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5",
  990. )
  991. p.add_argument(
  992. "--cases",
  993. type=Path,
  994. default=None,
  995. help="批量 case JSON(见 sweep_cases.json),顺序跑网格扫描",
  996. )
  997. p.add_argument(
  998. "--aggregate-only",
  999. action="store_true",
  1000. help="不扫描,仅根据已有 sweep_report.json 汇总 Pass1/Pass2 建议",
  1001. )
  1002. p.add_argument(
  1003. "--aggregate-out",
  1004. type=Path,
  1005. default=None,
  1006. help="汇总输出路径(默认 <cases 父目录>/aggregate_recommendation.json)",
  1007. )
  1008. return p
  1009. def main(argv: Optional[Sequence[str]] = None) -> None:
  1010. args = _build_arg_parser().parse_args(argv)
  1011. if args.aggregate_only:
  1012. if not args.cases or not args.cases.is_file():
  1013. raise SystemExit("--aggregate-only 需要 --cases <sweep_cases.json>")
  1014. cases = _load_cases_json(args.cases)
  1015. report_paths: List[Path] = []
  1016. for case in cases:
  1017. out = Path(case["output"])
  1018. rp = _find_sweep_report(out, Path(case["input"]))
  1019. if rp:
  1020. report_paths.append(rp)
  1021. else:
  1022. print(f" 跳过(无报告): {case.get('name') or case['input']}")
  1023. agg = aggregate_sweep_reports(
  1024. report_paths, min_hit_score=args.min_hit_score
  1025. )
  1026. agg_out = args.aggregate_out or args.cases.parent / "aggregate_recommendation.json"
  1027. agg_out.parent.mkdir(parents=True, exist_ok=True)
  1028. agg_out.write_text(json.dumps(agg, ensure_ascii=False, indent=2), encoding="utf-8")
  1029. print(f"\n汇总完成 -> {agg_out}")
  1030. print(f" 推荐 tag: {agg.get('recommended_tag')}")
  1031. print(f" 交集合格 tag 数: {agg.get('intersection_count')}")
  1032. if agg.get("pass1_cell_preprocess"):
  1033. print("\n Pass1 cell_preprocess 片段:")
  1034. print(json.dumps(agg["pass1_cell_preprocess"], ensure_ascii=False, indent=2))
  1035. if agg.get("pass2_enhance_retry"):
  1036. print("\n Pass2 enhance_retry 片段:")
  1037. print(json.dumps(agg["pass2_enhance_retry"], ensure_ascii=False, indent=2))
  1038. return
  1039. if args.cases and args.cases.is_file():
  1040. if args.input is not None:
  1041. print(" 提示: 已指定 --cases,忽略 positional input")
  1042. cases = _load_cases_json(args.cases)
  1043. run_cases_batch(cases)
  1044. report_paths = []
  1045. for case in cases:
  1046. rp = _find_sweep_report(Path(case["output"]), Path(case["input"]))
  1047. if rp:
  1048. report_paths.append(rp)
  1049. if report_paths:
  1050. agg = aggregate_sweep_reports(
  1051. report_paths, min_hit_score=args.min_hit_score
  1052. )
  1053. agg_out = args.aggregate_out or args.cases.parent / "aggregate_recommendation.json"
  1054. agg_out.write_text(json.dumps(agg, ensure_ascii=False, indent=2), encoding="utf-8")
  1055. print(f"\n批量汇总 -> {agg_out}")
  1056. print(f" 推荐 tag: {agg.get('recommended_tag')} ({agg.get('pick_reason')})")
  1057. return
  1058. if args.input is None:
  1059. raise SystemExit("需要 input 路径,或使用 --cases / --aggregate-only")
  1060. inputs = collect_inputs(args.input, prefer_raw=not args.no_prefer_raw)
  1061. if not inputs:
  1062. raise SystemExit("未找到可扫描的图像")
  1063. if args.output is not None:
  1064. out_root = args.output
  1065. elif args.input.is_file():
  1066. out_root = args.input.parent / "sweep_out"
  1067. else:
  1068. out_root = args.input / "sweep_out"
  1069. out_root.mkdir(parents=True, exist_ok=True)
  1070. model_dir = args.model_dir or _default_model_dir()
  1071. if args.best_only:
  1072. # 验证适配性模式:对目录下所有图用最优参数跑一次
  1073. best_cfg = _parse_best_config(args.best_config)
  1074. best_cfg["_tag"] = args.best_config
  1075. print(f"最佳参数验证模式: {args.best_config}")
  1076. print(f" 解析: method={best_cfg['method']} contrast={best_cfg['contrast_cfg'].get('method')} "
  1077. f"upscale={best_cfg['upscale']} order={best_cfg['contrast_order']}")
  1078. print(f" 共 {len(inputs)} 张图")
  1079. all_texts: List[Dict[str, Any]] = []
  1080. hit_count = 0
  1081. for img_path in inputs:
  1082. report = run_best_config(
  1083. img_path, out_root,
  1084. prefer_raw=not args.no_prefer_raw,
  1085. best_cfg=best_cfg,
  1086. model_dir=model_dir,
  1087. save_images=not args.no_save_images,
  1088. )
  1089. result = report["result"]
  1090. text = result.get("text", "")
  1091. score = result.get("score", 0)
  1092. all_texts.append({
  1093. "input": img_path.name,
  1094. "text": text,
  1095. "score": score,
  1096. "report": str(Path(report["output_dir"]) / "best_result.json"),
  1097. })
  1098. m = _match_hit(text, args.target)
  1099. hit_info = f" [HIT: {m}]" if m else ""
  1100. print(f" {img_path.name}: score={score:.4f} text={text!r}{hit_info}")
  1101. if m:
  1102. hit_count += 1
  1103. # 汇总
  1104. summary_path = out_root / "best_summary.json"
  1105. summary_data = {
  1106. "best_config": args.best_config,
  1107. "total": len(all_texts),
  1108. "hits": hit_count,
  1109. "target": args.target,
  1110. "results": all_texts,
  1111. }
  1112. summary_path.write_text(json.dumps(summary_data, ensure_ascii=False, indent=2), encoding="utf-8")
  1113. print(f"\n汇总: {hit_count}/{len(all_texts)} HIT -> {summary_path}")
  1114. return
  1115. # 正常参数网格扫描模式
  1116. methods = [m.strip() for m in args.methods.split(",") if m.strip()]
  1117. contrast_orders = [o.strip() for o in args.contrast_orders.split(",") if o.strip()]
  1118. if args.quick:
  1119. thresholds = [150, 155]
  1120. upscales = [96, 128, 192]
  1121. det_threshs = [0.5]
  1122. else:
  1123. thresholds = _parse_csv_ints(args.thresholds)
  1124. upscales = [int(x) for x in args.upscales.split(",") if x.strip()]
  1125. det_threshs = _parse_csv_floats(args.det_threshs)
  1126. contrast_grid = _build_contrast_grid(quick=args.quick)
  1127. print(f"扫描 {len(inputs)} 张图 -> {out_root}")
  1128. print(f" methods={methods} thresholds={thresholds} upscales={upscales}")
  1129. print(f" contrast_methods={len(contrast_grid)} orders={contrast_orders}")
  1130. if args.target:
  1131. print(f" target={args.target!r}")
  1132. summary: List[Dict[str, Any]] = []
  1133. for img_path in inputs:
  1134. print(f"\n=== {img_path.name} ===")
  1135. report = run_sweep(
  1136. img_path,
  1137. out_root,
  1138. prefer_raw=not args.no_prefer_raw,
  1139. target=args.target,
  1140. model_dir=model_dir,
  1141. methods=methods,
  1142. thresholds=thresholds,
  1143. contrast_grid=contrast_grid,
  1144. contrast_orders=contrast_orders,
  1145. upscales=upscales,
  1146. det_threshs=det_threshs,
  1147. save_images=not args.no_save_images,
  1148. run_baseline=not args.no_baseline,
  1149. baseline_upscale=args.baseline_upscale,
  1150. min_hit_score=args.min_hit_score,
  1151. )
  1152. qh = report.get("hits_target_score_above") or []
  1153. summary.append(
  1154. {
  1155. "input": report["input"],
  1156. "hits": len(report["hits"]),
  1157. "hits_target_score_above": len(qh),
  1158. "top_qualified_tag": qh[0]["tag"] if qh else None,
  1159. "top_qualified_score": qh[0]["score"] if qh else None,
  1160. "report": str(Path(report["output_dir"]) / "sweep_report.json"),
  1161. }
  1162. )
  1163. index_path = out_root / "sweep_index.json"
  1164. index_path.write_text(
  1165. json.dumps(summary, ensure_ascii=False, indent=2), encoding="utf-8"
  1166. )
  1167. print(f"\n全部完成,索引: {index_path}")
  1168. for s in summary:
  1169. qn = s.get("hits_target_score_above", 0)
  1170. top = s.get("top_qualified_tag")
  1171. print(
  1172. f" {s['input']}: hits={s['hits']} "
  1173. f"qualified(score>{args.min_hit_score})={qn}"
  1174. + (f" top={top}" if top else "")
  1175. + f" -> {s['report']}"
  1176. )
  1177. if __name__ == "__main__":
  1178. if len(sys.argv) == 1:
  1179. cases_path = Path(__file__).resolve().parent / "sweep_cases.json"
  1180. print(f"ℹ️ 未提供命令行参数,使用批量 cases: {cases_path.name}")
  1181. sys.argv = [sys.argv[0], "--cases", str(cases_path)]
  1182. sys.exit(main())