|
@@ -24,6 +24,10 @@
|
|
|
|
|
|
|
|
# 指定目标文字,自动统计 HIT 命中率
|
|
# 指定目标文字,自动统计 HIT 命中率
|
|
|
python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only -t "交易类型"
|
|
python cell_sweep.py /path/to/tablecell_ocr/ -o ./out --best-only -t "交易类型"
|
|
|
|
|
+
|
|
|
|
|
+ # 多 case 批量扫描 + 汇总 Pass1/Pass2 建议
|
|
|
|
|
+ python cell_sweep.py --cases sweep_cases.json
|
|
|
|
|
+ python cell_sweep.py --aggregate-only --cases sweep_cases.json
|
|
|
"""
|
|
"""
|
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
|
@@ -392,6 +396,32 @@ def _match_hit(text: str, target: Optional[str]) -> Optional[str]:
|
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _collect_qualified_hits(
    results: List[Dict[str, Any]],
    target: Optional[str],
    *,
    min_score: float = 0.9,
    ocr_mode: str = "det_rec",
) -> List[Dict[str, Any]]:
    """Return rows that hit ``target`` with a score strictly above ``min_score``.

    Only rows whose ``ocr_mode`` equals the requested mode are considered.
    Rows are de-duplicated by ``tag``: the highest score wins and, on equal
    scores, the row seen first is kept.  Survivors are returned in descending
    score order.  A missing or empty ``target`` yields an empty list.
    """
    if not target:
        return []

    def _score_of(row: Dict[str, Any]) -> float:
        # A missing or falsy score is treated as 0.
        return float(row.get("score") or 0)

    best_per_tag: Dict[str, Dict[str, Any]] = {}
    for row in results:
        is_hit = row.get("ocr_mode") == ocr_mode and _match_hit(
            row.get("text", "") or "", target
        )
        if not is_hit:
            continue
        row_score = _score_of(row)
        if row_score <= min_score:
            continue
        key = str(row.get("tag") or "")
        if key not in best_per_tag or row_score > _score_of(best_per_tag[key]):
            best_per_tag[key] = row

    # Negated key (rather than reverse=True) keeps first-seen order on ties.
    return sorted(best_per_tag.values(), key=lambda row: -_score_of(row))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def run_sweep(
|
|
def run_sweep(
|
|
|
input_path: Path,
|
|
input_path: Path,
|
|
|
out_dir: Path,
|
|
out_dir: Path,
|
|
@@ -408,6 +438,7 @@ def run_sweep(
|
|
|
save_images: bool,
|
|
save_images: bool,
|
|
|
run_baseline: bool,
|
|
run_baseline: bool,
|
|
|
baseline_upscale: int,
|
|
baseline_upscale: int,
|
|
|
|
|
+ min_hit_score: float = 0.9,
|
|
|
) -> Dict[str, Any]:
|
|
) -> Dict[str, Any]:
|
|
|
resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
|
|
resolved = resolve_input_image(input_path, prefer_raw=prefer_raw)
|
|
|
raw = cv2.imread(str(resolved))
|
|
raw = cv2.imread(str(resolved))
|
|
@@ -499,13 +530,34 @@ def run_sweep(
|
|
|
row["match"] = m
|
|
row["match"] = m
|
|
|
hits.append(row)
|
|
hits.append(row)
|
|
|
|
|
|
|
|
|
|
+ qualified_hits = _collect_qualified_hits(
|
|
|
|
|
+ results, target, min_score=min_hit_score, ocr_mode="det_rec"
|
|
|
|
|
+ )
|
|
|
report = {
|
|
report = {
|
|
|
"input": str(resolved),
|
|
"input": str(resolved),
|
|
|
"input_requested": str(input_path),
|
|
"input_requested": str(input_path),
|
|
|
"output_dir": str(cell_out),
|
|
"output_dir": str(cell_out),
|
|
|
"target": target,
|
|
"target": target,
|
|
|
|
|
+ "min_hit_score": min_hit_score,
|
|
|
"total_trials": total,
|
|
"total_trials": total,
|
|
|
"hits": hits,
|
|
"hits": hits,
|
|
|
|
|
+ "hits_target_score_above": [
|
|
|
|
|
+ {
|
|
|
|
|
+ "tag": r.get("tag"),
|
|
|
|
|
+ "score": r.get("score"),
|
|
|
|
|
+ "text": r.get("text"),
|
|
|
|
|
+ "match": r.get("match"),
|
|
|
|
|
+ "method": r.get("method"),
|
|
|
|
|
+ "threshold": r.get("threshold"),
|
|
|
|
|
+ "contrast_method": r.get("contrast_method"),
|
|
|
|
|
+ "contrast_order": r.get("contrast_order"),
|
|
|
|
|
+ "contrast_cfg": r.get("contrast_cfg"),
|
|
|
|
|
+ "upscale": r.get("upscale"),
|
|
|
|
|
+ "det_db_box_thresh": r.get("det_db_box_thresh"),
|
|
|
|
|
+ "ocr_mode": r.get("ocr_mode"),
|
|
|
|
|
+ }
|
|
|
|
|
+ for r in qualified_hits
|
|
|
|
|
+ ],
|
|
|
"all_results": results,
|
|
"all_results": results,
|
|
|
}
|
|
}
|
|
|
report_path = cell_out / "sweep_report.json"
|
|
report_path = cell_out / "sweep_report.json"
|
|
@@ -513,28 +565,57 @@ def run_sweep(
|
|
|
json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
|
|
json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
- # ── 结论报告:按 OCR score 排序,分组对比 ──
|
|
|
|
|
- _print_conclusions(stem, results, target)
|
|
|
|
|
|
|
+ _print_conclusions(stem, results, target, min_hit_score=min_hit_score)
|
|
|
|
|
|
|
|
return report
|
|
return report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sorted_summary_labels(values: set) -> List[str]:
    """Turn a set that may contain ``None`` into sortable display labels.

    ``None`` (the code's own note says baseline rows often carry ``None``
    fields) is shown as ``"baseline"`` and always comes first; every other
    value is stringified and sorted lexicographically.
    """
    rendered = sorted(str(item) for item in values if item is not None)
    prefix = ["baseline"] if None in values else []
    return prefix + rendered
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _format_param_row(r: Dict[str, Any]) -> str:
    """Render one parameter combination as a single human-readable line.

    The contrast-specific settings shown depend on ``contrast_method``;
    unknown methods get no extra detail.  Any ``contrast_order`` other than
    ``"before_upscale"`` is labelled as "after upscale".
    """
    contrast_cfg = r.get("contrast_cfg") or {}
    method_name = r.get("contrast_method", "none")

    # (contrast method, ((label, key in contrast_cfg), ...))
    detail_specs = (
        ("clahe", (("cl", "clip_limit"), ("tile", "tile_grid_size"))),
        ("text_restore", (("t", "text_black_target"), ("bg", "background_threshold"))),
        ("gamma", (("g", "gamma"),)),
        ("linear", (("b", "black_percentile"), ("w", "white_percentile"))),
    )
    fields = next((pairs for name, pairs in detail_specs if method_name == name), ())
    detail = "".join(f" {label}={contrast_cfg.get(key)}" for label, key in fields)

    if r.get("contrast_order") == "before_upscale":
        order_label = "放大前"
    else:
        order_label = "放大后"

    head = f"wm={r.get('method')} thresh={r.get('threshold', 'd')} contrast={method_name}{detail} "
    tail = f"order={order_label} upscale={r.get('upscale')} det={r.get('det_db_box_thresh')}"
    return head + tail
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def _print_conclusions(
|
|
def _print_conclusions(
|
|
|
stem: str,
|
|
stem: str,
|
|
|
results: List[Dict[str, Any]],
|
|
results: List[Dict[str, Any]],
|
|
|
target: Optional[str],
|
|
target: Optional[str],
|
|
|
|
|
+ *,
|
|
|
|
|
+ min_hit_score: float = 0.9,
|
|
|
) -> None:
|
|
) -> None:
|
|
|
- """打印实验结论:按 OCR score 排序,分组展示最优组合。"""
|
|
|
|
|
|
|
+ """先列出命中 target 且 score>阈值的全部参数组合,再输出结论。"""
|
|
|
if not results:
|
|
if not results:
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
print(f"\n{'='*70}")
|
|
print(f"\n{'='*70}")
|
|
|
print(f" 实验结论: {stem}")
|
|
print(f" 实验结论: {stem}")
|
|
|
if target:
|
|
if target:
|
|
|
- print(f" 目标文字: {target}")
|
|
|
|
|
|
|
+ print(f" 目标文字: {target!r} 阈值: score > {min_hit_score} (det_rec)")
|
|
|
print(f"{'='*70}")
|
|
print(f"{'='*70}")
|
|
|
|
|
|
|
|
- # 取 det_rec 模式的结果(优先用检测+识别完整结果)
|
|
|
|
|
dr_results = [r for r in results if r.get("ocr_mode") == "det_rec" and r.get("text")]
|
|
dr_results = [r for r in results if r.get("ocr_mode") == "det_rec" and r.get("text")]
|
|
|
if not dr_results:
|
|
if not dr_results:
|
|
|
dr_results = [r for r in results if r.get("text")]
|
|
dr_results = [r for r in results if r.get("text")]
|
|
@@ -543,77 +624,88 @@ def _print_conclusions(
|
|
|
print(" (无有效 OCR 结果)")
|
|
print(" (无有效 OCR 结果)")
|
|
|
return
|
|
return
|
|
|
|
|
|
|
|
- # ── 1. 全局 Top-5 ──
|
|
|
|
|
- scored = sorted(dr_results, key=lambda r: -(r.get("score") or 0))
|
|
|
|
|
- print("\n 全局 OCR 得分 Top-5:")
|
|
|
|
|
- for i, r in enumerate(scored[:5], 1):
|
|
|
|
|
- print(f" {i}. score={r.get('score', 0):.4f} text={r.get('text', '')!r}")
|
|
|
|
|
- print(f" tag={r.get('tag', '')}")
|
|
|
|
|
-
|
|
|
|
|
- # ── 2. 按 contrast 方法分组最佳 ──
|
|
|
|
|
- print("\n 按 contrast 方法分组最优(score 最高):")
|
|
|
|
|
- groups: Dict[str, List[Dict[str, Any]]] = {}
|
|
|
|
|
- for r in scored:
|
|
|
|
|
- cm = r.get("contrast_method", "?")
|
|
|
|
|
- groups.setdefault(cm, []).append(r)
|
|
|
|
|
-
|
|
|
|
|
- for cm in sorted(groups.keys()):
|
|
|
|
|
- best = groups[cm][0]
|
|
|
|
|
- wm = best.get("method", "?")
|
|
|
|
|
- print(f" [{cm}] 最佳: score={best.get('score', 0):.4f} "
|
|
|
|
|
- f"wm={wm} upscale={best.get('upscale')} "
|
|
|
|
|
- f"text={best.get('text', '')!r}")
|
|
|
|
|
-
|
|
|
|
|
- # ── 3. 有 watermark 处理 vs 无 watermark 处理对比 ──
|
|
|
|
|
- print("\n 去水印开关对比(同 contrast 方法,最高 score):")
|
|
|
|
|
- wm_groups: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
- for r in scored:
|
|
|
|
|
- cm = r.get("contrast_method", "?")
|
|
|
|
|
- wm = r.get("method", "?") if r.get("method") != "none" else "无去水印"
|
|
|
|
|
- key = f"{cm}|{wm}"
|
|
|
|
|
- cur_score = r.get("score") or 0
|
|
|
|
|
- prev_score = (wm_groups.get(key) or {}).get("score") or 0
|
|
|
|
|
- if key not in wm_groups or cur_score > prev_score:
|
|
|
|
|
- wm_groups[key] = r
|
|
|
|
|
-
|
|
|
|
|
- for cm in sorted(set(r.get("contrast_method", "?") for r in scored)):
|
|
|
|
|
- wm_rows = [r for k, r in wm_groups.items() if k.startswith(cm + "|")]
|
|
|
|
|
- if wm_rows:
|
|
|
|
|
- best_row = max(wm_rows, key=lambda r: r.get("score") or 0)
|
|
|
|
|
- wm_label = "无去水印" if best_row.get("method") == "none" else best_row.get("method", "?")
|
|
|
|
|
- print(f" [{cm}] 最优: wm={wm_label} score={best_row.get('score', 0):.4f} "
|
|
|
|
|
- f"text={best_row.get('text', '')!r}")
|
|
|
|
|
-
|
|
|
|
|
- # ── 4. 放大顺序对比 ──
|
|
|
|
|
- print("\n 放大前/后对比(同方法,最高 score):")
|
|
|
|
|
- order_data: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
- for r in scored:
|
|
|
|
|
- cm = r.get("contrast_method", "?")
|
|
|
|
|
- co = r.get("contrast_order", "?")
|
|
|
|
|
- key = f"{cm}|{co}"
|
|
|
|
|
- cur_score = r.get("score") or 0
|
|
|
|
|
- prev_score = (order_data.get(key) or {}).get("score") or 0
|
|
|
|
|
- if key not in order_data or cur_score > prev_score:
|
|
|
|
|
- order_data[key] = r
|
|
|
|
|
-
|
|
|
|
|
- for cm in sorted(set(r.get("contrast_method", "?") for r in scored)):
|
|
|
|
|
- b_score = (order_data.get(f"{cm}|before_upscale") or {}).get("score") or 0
|
|
|
|
|
- a_score = (order_data.get(f"{cm}|after_upscale") or {}).get("score") or 0
|
|
|
|
|
- better = "放大前" if b_score > a_score else ("放大后" if a_score > b_score else "持平")
|
|
|
|
|
- if b_score or a_score:
|
|
|
|
|
- print(f" [{cm}] 放大前={b_score:.4f} 放大后={a_score:.4f} 更优: {better}")
|
|
|
|
|
-
|
|
|
|
|
- # ── 5. HIT 命中率统计 ──
|
|
|
|
|
- if target:
|
|
|
|
|
- hit_count = sum(1 for r in results if r.get("match"))
|
|
|
|
|
- hit_by_cm: Dict[str, int] = {}
|
|
|
|
|
- for r in results:
|
|
|
|
|
- if r.get("match"):
|
|
|
|
|
- cm = r.get("contrast_method", "?")
|
|
|
|
|
- hit_by_cm[cm] = hit_by_cm.get(cm, 0) + 1
|
|
|
|
|
- print(f"\n HIT 命中率 (target={target}): {hit_count}/{len(results)}")
|
|
|
|
|
- for cm in sorted(hit_by_cm.keys()):
|
|
|
|
|
- print(f" [{cm}] HIT={hit_by_cm[cm]}")
|
|
|
|
|
|
|
+ scored = sorted(dr_results, key=lambda r: -(float(r.get("score") or 0)))
|
|
|
|
|
+ qualified = _collect_qualified_hits(
|
|
|
|
|
+ results, target, min_score=min_hit_score, ocr_mode="det_rec"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # ── 1. 命中 target 且 score > 阈值 的全部参数组合 ──
|
|
|
|
|
+ print(f"\n 【命中列表】共 {len(qualified)} 组 (target 匹配 + score > {min_hit_score}):")
|
|
|
|
|
+ if not target:
|
|
|
|
|
+ print(" (未指定 -t/--target,跳过命中列表)")
|
|
|
|
|
+ elif not qualified:
|
|
|
|
|
+ print(" (无满足条件的组合)")
|
|
|
|
|
+ # 仍展示最接近的 HIT
|
|
|
|
|
+ near = [
|
|
|
|
|
+ r for r in scored
|
|
|
|
|
+ if _match_hit(r.get("text", "") or "", target)
|
|
|
|
|
+ ]
|
|
|
|
|
+ if near:
|
|
|
|
|
+ print(f" 提示: 有 {len(near)} 组命中 target 但 score <= {min_hit_score},最高分:")
|
|
|
|
|
+ for r in near[:5]:
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" score={float(r.get('score', 0)):.4f} text={r.get('text', '')!r}"
|
|
|
|
|
+ )
|
|
|
|
|
+ print(f" tag={r.get('tag', '')}")
|
|
|
|
|
+ else:
|
|
|
|
|
+ for i, r in enumerate(qualified, 1):
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" {i}. score={float(r.get('score', 0)):.4f} "
|
|
|
|
|
+ f"match={r.get('match')} text={r.get('text', '')!r}"
|
|
|
|
|
+ )
|
|
|
|
|
+ print(f" tag={r.get('tag', '')}")
|
|
|
|
|
+ print(f" {_format_param_row(r)}")
|
|
|
|
|
+
|
|
|
|
|
+ # ── 2. 结论 ──
|
|
|
|
|
+ print("\n 【结论】")
|
|
|
|
|
+ if qualified:
|
|
|
|
|
+ best = qualified[0]
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" 推荐参数组合: {best.get('tag')} "
|
|
|
|
|
+ f"(score={float(best.get('score', 0)):.4f}, text={best.get('text', '')!r})"
|
|
|
|
|
+ )
|
|
|
|
|
+ print(f" {_format_param_row(best)}")
|
|
|
|
|
+
|
|
|
|
|
+ # 在合格集合内做简要对比
|
|
|
|
|
+ cm_best: Dict[str, Dict[str, Any]] = {}
|
|
|
|
|
+ for r in qualified:
|
|
|
|
|
+ cm = r.get("contrast_method") or "baseline"
|
|
|
|
|
+ if cm not in cm_best or float(r.get("score") or 0) > float(
|
|
|
|
|
+ cm_best[cm].get("score") or 0
|
|
|
|
|
+ ):
|
|
|
|
|
+ cm_best[cm] = r
|
|
|
|
|
+ print(" 合格组合内各 contrast 最优:")
|
|
|
|
|
+ for cm in sorted(cm_best.keys(), key=str):
|
|
|
|
|
+ r = cm_best[cm]
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" [{cm}] score={float(r.get('score', 0)):.4f} tag={r.get('tag', '')}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ wm_set = {r.get("method") for r in qualified}
|
|
|
|
|
+ order_set = {r.get("contrast_order") for r in qualified}
|
|
|
|
|
+ upscale_vals = {r.get("upscale") for r in qualified}
|
|
|
|
|
+ upscale_set = _sorted_summary_labels(
|
|
|
|
|
+ {str(v) if v is not None else "baseline" for v in upscale_vals}
|
|
|
|
|
+ )
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" 合格组合涉及: 去水印={_sorted_summary_labels(wm_set)} "
|
|
|
|
|
+ f"放大顺序={_sorted_summary_labels(order_set)} upscale={upscale_set}"
|
|
|
|
|
+ )
|
|
|
|
|
+ print(f" 共 {len(qualified)} 组参数可用于生产配置参考")
|
|
|
|
|
+ else:
|
|
|
|
|
+ top = scored[0] if scored else None
|
|
|
|
|
+ if top and target and _match_hit(top.get("text", "") or "", target):
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" 无 score>{min_hit_score} 的命中组合;最高分命中: "
|
|
|
|
|
+ f"{top.get('tag')} score={float(top.get('score', 0)):.4f}"
|
|
|
|
|
+ )
|
|
|
|
|
+ elif top:
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" 无命中 target 的组合;全局最高 det_rec: "
|
|
|
|
|
+ f"{top.get('tag')} score={float(top.get('score', 0)):.4f} text={top.get('text', '')!r}"
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(" 无可用结论")
|
|
|
|
|
|
|
|
print(f"{'='*70}\n")
|
|
print(f"{'='*70}\n")
|
|
|
|
|
|
|
@@ -659,6 +751,221 @@ def _parse_best_config(tag: str) -> Dict[str, Any]:
|
|
|
return cfg
|
|
return cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tag_to_cell_preprocess_yaml(tag: str) -> Dict[str, Any]:
    """Convert a sweep tag into a ``cell_preprocess`` config fragment.

    The tag is decoded with ``_parse_best_config``.  The result holds the
    Pass1 settings (``watermark``, ``upscale_min_side``, ``contrast``) plus an
    ``enhance_retry`` section for Pass2 that reuses the same upscale and
    contrast, except that an enabled CLAHE contrast gets ``tile_grid_size``
    forced to 4.
    """
    parsed = _parse_best_config(tag)
    contrast_src = parsed["contrast_cfg"]
    contrast_method = contrast_src.get("method", "none")
    wm_method = parsed["method"]
    upscale = parsed["upscale"]

    watermark: Dict[str, Any] = {"enabled": wm_method != "none", "method": wm_method}
    # Only the threshold-based watermark removal carries a threshold value.
    if wm_method == "threshold" and parsed.get("threshold") is not None:
        watermark["threshold"] = parsed["threshold"]

    contrast: Dict[str, Any]
    if contrast_method == "none":
        contrast = {"enabled": False}
    else:
        contrast = {"enabled": True, "method": contrast_method}
        contrast.update(
            (key, value) for key, value in contrast_src.items() if key != "method"
        )

    # Pass2 works on a copy so the Pass1 contrast section is left untouched.
    retry_contrast: Dict[str, Any] = dict(contrast)
    if retry_contrast.get("enabled") and retry_contrast.get("method") == "clahe":
        retry_contrast["tile_grid_size"] = 4

    return {
        "watermark": watermark,
        "upscale_min_side": upscale,
        "contrast": contrast,
        "enhance_retry": {
            "enabled": True,
            "upscale_min_side": upscale,
            "contrast": retry_contrast,
        },
    }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _qualified_from_report(report: Dict[str, Any], min_score: float) -> List[Dict[str, Any]]:
    """Return the qualified hits of one sweep report for the given threshold.

    Args:
        report: Parsed ``sweep_report.json`` content.
        min_score: Rows must score strictly above this value.

    Returns:
        Qualified rows, best first.  When the report carries a precomputed
        ``hits_target_score_above`` list it is reused, but it is re-filtered
        against ``min_score`` because that list was cut at the threshold the
        sweep ran with (stored as ``min_hit_score``), which may differ from
        the one requested here.  If the requested threshold is looser than the
        stored one, the list is rebuilt from ``all_results`` when available,
        since the precomputed list cannot contain the rows in between.
    """
    stored = report.get("hits_target_score_above")
    all_results = report.get("all_results") or []

    if stored is None:
        # Report without a precomputed list: derive it from raw results.
        return _collect_qualified_hits(
            all_results,
            report.get("target"),
            min_score=min_score,
            ocr_mode="det_rec",
        )

    stored_threshold = report.get("min_hit_score")
    needs_rebuild = (
        isinstance(stored_threshold, (int, float))
        and min_score < stored_threshold
        and bool(all_results)
    )
    if needs_rebuild:
        return _collect_qualified_hits(
            all_results,
            report.get("target"),
            min_score=min_score,
            ocr_mode="det_rec",
        )

    # Same or stricter threshold: filtering keeps the stored order intact.
    return [row for row in stored if float(row.get("score") or 0) > min_score]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _majority_key(items: Sequence[Any]) -> Any:
    """Return the most frequent element of ``items`` (``None`` when empty).

    On a tie the element that appears first in ``items`` wins.
    """
    from collections import Counter

    if not items:
        return None
    tally = Counter(items)
    # max() returns the first maximal key, i.e. the earliest-seen element.
    return max(tally, key=tally.__getitem__)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def aggregate_sweep_reports(
    report_paths: Sequence[Path],
    *,
    min_hit_score: float = 0.9,
) -> Dict[str, Any]:
    """Aggregate several ``sweep_report.json`` files into Pass1/Pass2 suggestions.

    Args:
        report_paths: Report files to combine.  Paths that are not files are
            recorded in ``per_case`` with ``"error": "missing"`` and skipped.
        min_hit_score: Score threshold passed to ``_qualified_from_report``.

    Returns:
        A dict with the per-case summary, the tags qualified in every readable
        report (``intersection_tags``, best score first, at most 20), the
        recommended tag with the reason it was picked, and the derived
        ``pass1_cell_preprocess`` / ``pass2_enhance_retry`` config fragments.

    Selection order for the recommended tag:
        1. the best-scoring tag shared by all readable reports;
        2. otherwise a per-field majority vote over each case's top tag
           (baseline and unparsable tags are left out);
        3. otherwise the majority top tag, baseline included.
    """
    import re

    def _preprocess_or_none(tag: Optional[str]) -> Optional[Dict[str, Any]]:
        # The baseline fallback can hand over a tag that _parse_best_config
        # rejects with ValueError; keep the recommendation and omit the
        # fragment instead of aborting the whole aggregation.
        if not tag:
            return None
        try:
            return _tag_to_cell_preprocess_yaml(tag)
        except ValueError:
            return None

    per_case: List[Dict[str, Any]] = []
    qualified_tag_sets: List[set] = []
    # Best qualified score seen for each tag across all reports; collected in
    # the same pass so no report has to be read from disk a second time.
    best_score_by_tag: Dict[str, float] = {}

    for rp in report_paths:
        if not rp.is_file():
            per_case.append({"report": str(rp), "error": "missing"})
            continue
        report = json.loads(rp.read_text(encoding="utf-8"))
        qualified = _qualified_from_report(report, min_hit_score)

        case_tags: set = set()
        for row in qualified:
            if not row.get("tag"):
                continue
            tag_key = str(row["tag"])
            case_tags.add(tag_key)
            row_score = float(row.get("score") or 0)
            if row_score > best_score_by_tag.get(tag_key, float("-inf")):
                best_score_by_tag[tag_key] = row_score
        qualified_tag_sets.append(case_tags)

        top = qualified[0] if qualified else None
        per_case.append(
            {
                "report": str(rp),
                "input": report.get("input"),
                "target": report.get("target"),
                "qualified_count": len(qualified),
                "top_tag": top.get("tag") if top else None,
                "top_score": top.get("score") if top else None,
                "top_text": top.get("text") if top else None,
            }
        )

    ok_cases = [c for c in per_case if c.get("top_tag")]
    intersection: set = (
        set.intersection(*qualified_tag_sets) if qualified_tag_sets else set()
    )
    # Secondary sort on the tag itself makes equal-score ordering deterministic.
    intersection_sorted = sorted(
        intersection, key=lambda t: (-best_score_by_tag[t], t)
    )

    pick_tag: Optional[str] = None
    pick_reason = ""
    majority_fields: Dict[str, Any] = {}
    if intersection_sorted:
        pick_tag = intersection_sorted[0]
        pick_reason = f"全部 {len(report_paths)} 个 case 的合格集合交集,取最高分 tag"
    elif ok_cases:
        parsed: List[Dict[str, Any]] = []
        for c in ok_cases:
            tag = c["top_tag"]
            if tag and str(tag).startswith("baseline"):
                continue
            try:
                parsed.append(_parse_best_config(str(tag)))
            except ValueError:
                continue
        if parsed:
            pick_reason = "无全局 tag 交集,按各 case 榜首做分字段多数票(跳过 baseline)"
            method = _majority_key([p["method"] for p in parsed])
            thresh = _majority_key([p.get("threshold") for p in parsed])
            upscale = _majority_key([p["upscale"] for p in parsed])
            cm = _majority_key([p["contrast_cfg"].get("method") for p in parsed])
            c_cfgs = [p["contrast_cfg"] for p in parsed if p["contrast_cfg"].get("method") == cm]
            contrast_cfg: Dict[str, Any] = {"method": cm or "none"}
            # Only CLAHE sub-parameters are voted on; other methods keep just
            # their method name.
            if cm == "clahe" and c_cfgs:
                contrast_cfg["clip_limit"] = _majority_key([c.get("clip_limit") for c in c_cfgs])
                contrast_cfg["tile_grid_size"] = _majority_key(
                    [c.get("tile_grid_size") for c in c_cfgs]
                )
            synthetic = {
                "method": method,
                "threshold": thresh,
                "contrast_cfg": contrast_cfg,
                "contrast_order": _majority_key([p["contrast_order"] for p in parsed]),
                "upscale": upscale,
                "det_db_box_thresh": _majority_key([p["det_db_box_thresh"] for p in parsed]),
            }
            # Rebuild a tag string from the voted fields.
            c_tag = _contrast_tag(contrast_cfg)
            o_tag = "b" if synthetic["contrast_order"] == "before_upscale" else "a"
            t_s = str(synthetic["threshold"] or "d")
            pick_tag = (
                f"{method}_t{t_s}_{c_tag}_o{o_tag}_u{synthetic['upscale']}"
                f"_det{synthetic['det_db_box_thresh']}"
            )
            majority_fields = synthetic
        else:
            pick_tag = _majority_key([c["top_tag"] for c in ok_cases])
            pick_reason = "无全局交集,按各 case 榜首 tag 多数票(含 baseline)"

    pass1_cpp = _preprocess_or_none(pick_tag)

    # Pass2 prefers a shared CLAHE tag whose tile grid size is 4.
    pass2_candidates = [
        t for t in intersection_sorted if re.search(r"cl_[\d.]+_4_", t)
    ]
    pass2_tag = pass2_candidates[0] if pass2_candidates else pick_tag
    pass2_cpp = _preprocess_or_none(pass2_tag)

    return {
        "min_hit_score": min_hit_score,
        "per_case": per_case,
        "intersection_tags": intersection_sorted[:20],
        "intersection_count": len(intersection),
        "recommended_tag": pick_tag,
        "pick_reason": pick_reason,
        "majority_fields": majority_fields,
        "pass1_cell_preprocess": pass1_cpp,
        "pass2_enhance_retry": pass2_cpp.get("enhance_retry") if pass2_cpp else None,
        "pass2_tag": pass2_tag,
    }
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
# Case keys that only describe a case and are never turned into CLI flags.
_CASE_META_KEYS = frozenset({"name", "note", "description"})


def _config_to_argv(cfg: Dict[str, Any]) -> List[str]:
    """Translate one case dict into an argv list for ``main``.

    Args:
        cfg: Case settings.  ``input`` is required and becomes the positional
            argument; every other key ``some_key`` becomes ``--some-key``.

    Returns:
        The argument list.  ``True`` emits a bare flag, ``False`` emits
        nothing, and any other value is passed as ``--flag str(value)``.
        ``None`` values (JSON ``null``) are skipped so an unset option is
        not handed to the parser as the literal string ``"None"``.

    Raises:
        KeyError: If ``cfg`` has no ``input`` entry.
    """
    argv = [str(cfg["input"])]
    for key, value in cfg.items():
        if key == "input" or key in _CASE_META_KEYS:
            continue
        if value is None:
            continue
        flag = f"--{key.replace('_', '-')}"
        if isinstance(value, bool):
            if value:
                argv.append(flag)
            continue
        argv.extend([flag, str(value)])
    return argv
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _load_cases_json(path: Path) -> List[Dict[str, Any]]:
    """Load a batch-case JSON file and return one merged dict per case.

    The file holds an object with an optional ``shared`` mapping and a
    ``cases`` list.  Each case is overlaid on top of ``shared``, so a key set
    in the case itself wins.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    defaults = payload.get("shared") or {}
    case_list = payload.get("cases") or []
    return [{**defaults, **case} for case in case_list]
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def _find_sweep_report(case_output: Path, input_path: Path) -> Optional[Path]:
    """Locate the ``sweep_report.json`` belonging to one case.

    The expected location is ``<case_output>/<stem>/sweep_report.json``, where
    ``<stem>`` is the input file stem with a trailing ``_raw`` removed.  If
    that file does not exist, the first report found anywhere below
    ``case_output`` (in sorted path order) is returned, or ``None`` when
    there is none.
    """
    cell_name = Path(str(input_path)).stem.removesuffix("_raw")
    expected = case_output / cell_name / "sweep_report.json"
    if expected.is_file():
        return expected
    # Smallest path == first element of the sorted listing.
    return min(case_output.rglob("sweep_report.json"), default=None)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
def run_cases_batch(cases: Sequence[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Run every case in order through ``main`` and summarise the outcome.

    Each case is converted to an argv list and executed; afterwards its
    report is looked up under the case's own ``output`` directory.  Returns
    one ``{"name", "output", "report"}`` dict per case, with ``report`` set
    to ``None`` when no report file was found.
    """
    banner = "#" * 70
    total = len(cases)
    summaries: List[Dict[str, Any]] = []
    for index, case in enumerate(cases, 1):
        label = case.get("name") or case.get("input")
        print(f"\n{banner}\n 批量 case {index}/{total}: {label}\n{banner}")
        main(_config_to_argv(case))
        output_dir = Path(case["output"])
        report_path = _find_sweep_report(output_dir, Path(case["input"]))
        summaries.append(
            {
                "name": label,
                "output": str(output_dir),
                "report": None if not report_path else str(report_path),
            }
        )
    return summaries
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def run_best_config(
|
|
def run_best_config(
|
|
|
input_path: Path,
|
|
input_path: Path,
|
|
|
out_dir: Path,
|
|
out_dir: Path,
|
|
@@ -723,6 +1030,8 @@ def _build_arg_parser() -> argparse.ArgumentParser:
|
|
|
p.add_argument(
|
|
p.add_argument(
|
|
|
"input",
|
|
"input",
|
|
|
type=Path,
|
|
type=Path,
|
|
|
|
|
+ nargs="?",
|
|
|
|
|
+ default=None,
|
|
|
help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)",
|
|
help="单元格裁剪图路径,或 tablecell_ocr 目录(批量扫描)",
|
|
|
)
|
|
)
|
|
|
p.add_argument(
|
|
p.add_argument(
|
|
@@ -739,6 +1048,12 @@ def _build_arg_parser() -> argparse.ArgumentParser:
|
|
|
help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
|
|
help="期望 OCR 文本;用于标记 HIT(子串匹配)。省略则任意非空为 HIT",
|
|
|
)
|
|
)
|
|
|
p.add_argument(
|
|
p.add_argument(
|
|
|
|
|
+ "--min-hit-score",
|
|
|
|
|
+ type=float,
|
|
|
|
|
+ default=0.9,
|
|
|
|
|
+ help="结论中「命中列表」的最低 score 阈值(默认 0.9,仅 det_rec)",
|
|
|
|
|
+ )
|
|
|
|
|
+ p.add_argument(
|
|
|
"--model-dir",
|
|
"--model-dir",
|
|
|
type=Path,
|
|
type=Path,
|
|
|
default=None,
|
|
default=None,
|
|
@@ -806,11 +1121,80 @@ def _build_arg_parser() -> argparse.ArgumentParser:
|
|
|
default="threshold_t150_cl_1.0_8_ob_u128_det0.5",
|
|
default="threshold_t150_cl_1.0_8_ob_u128_det0.5",
|
|
|
help="最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5",
|
|
help="最优参数 tag,如 threshold_t150_cl_1.0_8_ob_u128_det0.5",
|
|
|
)
|
|
)
|
|
|
|
|
+ p.add_argument(
|
|
|
|
|
+ "--cases",
|
|
|
|
|
+ type=Path,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="批量 case JSON(见 sweep_cases.json),顺序跑网格扫描",
|
|
|
|
|
+ )
|
|
|
|
|
+ p.add_argument(
|
|
|
|
|
+ "--aggregate-only",
|
|
|
|
|
+ action="store_true",
|
|
|
|
|
+ help="不扫描,仅根据已有 sweep_report.json 汇总 Pass1/Pass2 建议",
|
|
|
|
|
+ )
|
|
|
|
|
+ p.add_argument(
|
|
|
|
|
+ "--aggregate-out",
|
|
|
|
|
+ type=Path,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="汇总输出路径(默认 <cases 父目录>/aggregate_recommendation.json)",
|
|
|
|
|
+ )
|
|
|
return p
|
|
return p
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(argv: Optional[Sequence[str]] = None) -> None:
|
|
def main(argv: Optional[Sequence[str]] = None) -> None:
|
|
|
args = _build_arg_parser().parse_args(argv)
|
|
args = _build_arg_parser().parse_args(argv)
|
|
|
|
|
+
|
|
|
|
|
+ if args.aggregate_only:
|
|
|
|
|
+ if not args.cases or not args.cases.is_file():
|
|
|
|
|
+ raise SystemExit("--aggregate-only 需要 --cases <sweep_cases.json>")
|
|
|
|
|
+ cases = _load_cases_json(args.cases)
|
|
|
|
|
+ report_paths: List[Path] = []
|
|
|
|
|
+ for case in cases:
|
|
|
|
|
+ out = Path(case["output"])
|
|
|
|
|
+ rp = _find_sweep_report(out, Path(case["input"]))
|
|
|
|
|
+ if rp:
|
|
|
|
|
+ report_paths.append(rp)
|
|
|
|
|
+ else:
|
|
|
|
|
+ print(f" 跳过(无报告): {case.get('name') or case['input']}")
|
|
|
|
|
+ agg = aggregate_sweep_reports(
|
|
|
|
|
+ report_paths, min_hit_score=args.min_hit_score
|
|
|
|
|
+ )
|
|
|
|
|
+ agg_out = args.aggregate_out or args.cases.parent / "aggregate_recommendation.json"
|
|
|
|
|
+ agg_out.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
+ agg_out.write_text(json.dumps(agg, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
|
|
|
+ print(f"\n汇总完成 -> {agg_out}")
|
|
|
|
|
+ print(f" 推荐 tag: {agg.get('recommended_tag')}")
|
|
|
|
|
+ print(f" 交集合格 tag 数: {agg.get('intersection_count')}")
|
|
|
|
|
+ if agg.get("pass1_cell_preprocess"):
|
|
|
|
|
+ print("\n Pass1 cell_preprocess 片段:")
|
|
|
|
|
+ print(json.dumps(agg["pass1_cell_preprocess"], ensure_ascii=False, indent=2))
|
|
|
|
|
+ if agg.get("pass2_enhance_retry"):
|
|
|
|
|
+ print("\n Pass2 enhance_retry 片段:")
|
|
|
|
|
+ print(json.dumps(agg["pass2_enhance_retry"], ensure_ascii=False, indent=2))
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ if args.cases and args.cases.is_file():
|
|
|
|
|
+ if args.input is not None:
|
|
|
|
|
+ print(" 提示: 已指定 --cases,忽略 positional input")
|
|
|
|
|
+ cases = _load_cases_json(args.cases)
|
|
|
|
|
+ run_cases_batch(cases)
|
|
|
|
|
+ report_paths = []
|
|
|
|
|
+ for case in cases:
|
|
|
|
|
+ rp = _find_sweep_report(Path(case["output"]), Path(case["input"]))
|
|
|
|
|
+ if rp:
|
|
|
|
|
+ report_paths.append(rp)
|
|
|
|
|
+ if report_paths:
|
|
|
|
|
+ agg = aggregate_sweep_reports(
|
|
|
|
|
+ report_paths, min_hit_score=args.min_hit_score
|
|
|
|
|
+ )
|
|
|
|
|
+ agg_out = args.aggregate_out or args.cases.parent / "aggregate_recommendation.json"
|
|
|
|
|
+ agg_out.write_text(json.dumps(agg, ensure_ascii=False, indent=2), encoding="utf-8")
|
|
|
|
|
+ print(f"\n批量汇总 -> {agg_out}")
|
|
|
|
|
+ print(f" 推荐 tag: {agg.get('recommended_tag')} ({agg.get('pick_reason')})")
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ if args.input is None:
|
|
|
|
|
+ raise SystemExit("需要 input 路径,或使用 --cases / --aggregate-only")
|
|
|
inputs = collect_inputs(args.input, prefer_raw=not args.no_prefer_raw)
|
|
inputs = collect_inputs(args.input, prefer_raw=not args.no_prefer_raw)
|
|
|
if not inputs:
|
|
if not inputs:
|
|
|
raise SystemExit("未找到可扫描的图像")
|
|
raise SystemExit("未找到可扫描的图像")
|
|
@@ -878,7 +1262,7 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
|
|
|
|
|
|
|
|
if args.quick:
|
|
if args.quick:
|
|
|
thresholds = [150, 155]
|
|
thresholds = [150, 155]
|
|
|
- upscales = [128, 192]
|
|
|
|
|
|
|
+ upscales = [96, 128, 192]
|
|
|
det_threshs = [0.5]
|
|
det_threshs = [0.5]
|
|
|
else:
|
|
else:
|
|
|
thresholds = _parse_csv_ints(args.thresholds)
|
|
thresholds = _parse_csv_ints(args.thresholds)
|
|
@@ -911,11 +1295,16 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
|
|
|
save_images=not args.no_save_images,
|
|
save_images=not args.no_save_images,
|
|
|
run_baseline=not args.no_baseline,
|
|
run_baseline=not args.no_baseline,
|
|
|
baseline_upscale=args.baseline_upscale,
|
|
baseline_upscale=args.baseline_upscale,
|
|
|
|
|
+ min_hit_score=args.min_hit_score,
|
|
|
)
|
|
)
|
|
|
|
|
+ qh = report.get("hits_target_score_above") or []
|
|
|
summary.append(
|
|
summary.append(
|
|
|
{
|
|
{
|
|
|
"input": report["input"],
|
|
"input": report["input"],
|
|
|
"hits": len(report["hits"]),
|
|
"hits": len(report["hits"]),
|
|
|
|
|
+ "hits_target_score_above": len(qh),
|
|
|
|
|
+ "top_qualified_tag": qh[0]["tag"] if qh else None,
|
|
|
|
|
+ "top_qualified_score": qh[0]["score"] if qh else None,
|
|
|
"report": str(Path(report["output_dir"]) / "sweep_report.json"),
|
|
"report": str(Path(report["output_dir"]) / "sweep_report.json"),
|
|
|
}
|
|
}
|
|
|
)
|
|
)
|
|
@@ -926,46 +1315,20 @@ def main(argv: Optional[Sequence[str]] = None) -> None:
|
|
|
)
|
|
)
|
|
|
print(f"\n全部完成,索引: {index_path}")
|
|
print(f"\n全部完成,索引: {index_path}")
|
|
|
for s in summary:
|
|
for s in summary:
|
|
|
- print(f" {s['input']}: {s['hits']} hits -> {s['report']}")
|
|
|
|
|
|
|
+ qn = s.get("hits_target_score_above", 0)
|
|
|
|
|
+ top = s.get("top_qualified_tag")
|
|
|
|
|
+ print(
|
|
|
|
|
+ f" {s['input']}: hits={s['hits']} "
|
|
|
|
|
+ f"qualified(score>{args.min_hit_score})={qn}"
|
|
|
|
|
+ + (f" top={top}" if top else "")
|
|
|
|
|
+ + f" -> {s['report']}"
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Started without any CLI arguments: run the batch described by the
    # sweep_cases.json file that sits next to this script.
    if len(sys.argv) == 1:
        cases_path = Path(__file__).resolve().parent / "sweep_cases.json"
        print(f"ℹ️ 未提供命令行参数,使用批量 cases: {cases_path.name}")
        sys.argv = [*sys.argv, "--cases", str(cases_path)]

    sys.exit(main())
|