#!/usr/bin/env python
"""model_doctor -- model change inspection tool (plan B).

Driven by a hand-maintained ``model_registry.yaml``, this tool collects a
"fingerprint" for four kinds of model sources and compares it against the
``models.lock.json`` baseline. It reports which models changed, are missing
or have an unreachable service and, optionally, whether the Hugging Face
remote has a newer revision.

Subcommands:
    list         list the model entries in the registry
    show         collect and print the current fingerprints (no comparison)
    check        collect and compare against the lock baseline, then print a
                 report (non-zero exit code on change / missing)
    update-lock  collect and persist the current fingerprints as the new
                 lock baseline

Usage (preferably inside the ``mineru`` conda environment):
    conda run -n mineru python model_doctor.py check
    conda run -n mineru python model_doctor.py check --online --hash
    conda run -n mineru python model_doctor.py update-lock
"""
from __future__ import annotations

import argparse
import hashlib
import json
import os
import sys
import urllib.request
from datetime import datetime
from pathlib import Path

_PYYAML_HINT = "缺少依赖 PyYAML,请在 mineru 环境安装:conda run -n mineru pip install pyyaml\n"

try:
    import yaml
except ImportError:
    # Nothing in the fingerprint helpers below needs PyYAML, so the failure is
    # deferred to the first real use of ``yaml`` instead of making the whole
    # module unimportable. The install hint is carried by the raised error.
    class _MissingYaml:
        """Placeholder for PyYAML that raises ImportError on first use."""

        def __getattr__(self, attr: str):
            if attr.startswith("__") and attr.endswith("__"):
                # Let introspection probes behave as for any plain object.
                raise AttributeError(attr)
            raise ImportError(_PYYAML_HINT.strip())

    yaml = _MissingYaml()

HERE = Path(__file__).resolve().parent
DEFAULT_REGISTRY = HERE / "model_registry.yaml"
DEFAULT_LOCK = HERE / "models.lock.json"

# Status symbols used in the report.
SYM = {
    "ok": "✅",
    "changed": "⚠️ ",
    "remote_update": "🔺",
    "missing": "❌",
    "unreachable": "❌",
    "new": "🆕",
    "removed": "🗑 ",
    "skipped": "· ",
}

# Maximum number of files recorded in a directory fingerprint, meant to guard
# against a registry entry that accidentally points at a huge directory.
_MAX_DIR_FILES = 20000


# --------------------------------------------------------------------------- #
# Generic fingerprint helpers
# --------------------------------------------------------------------------- #
def _fast_sha256(path: Path, head_tail_mb: int = 8) -> str:
    """Return a quick sha256 over "size + head + tail" of a file.

    Only the first and the last ``head_tail_mb`` MiB are read, so large files
    are never read in full. The two windows never overlap: when the file is
    shorter than two windows, the tail window starts where the head ended, so
    every byte of such a file is covered exactly once.

    Args:
        path: File to hash.
        head_tail_mb: Size of the head and tail windows, in MiB.

    Returns:
        Hex digest of the sha256 over the decimal size, the head and the tail.
    """
    size = path.stat().st_size
    chunk = head_tail_mb * 1024 * 1024
    h = hashlib.sha256()
    h.update(str(size).encode())
    with path.open("rb") as f:
        h.update(f.read(chunk))
        if size > chunk:
            # For size > 2 * chunk this is the last ``chunk`` bytes; for
            # chunk < size <= 2 * chunk it is whatever the head did not cover,
            # which would otherwise never be hashed.
            f.seek(max(chunk, size - chunk))
            h.update(f.read(chunk))
    return h.hexdigest()


def _file_fp(path: Path, do_hash: bool) -> dict:
    """Fingerprint a single path: existence, size, mtime and optional hash.

    Args:
        path: Path to inspect.
        do_hash: When true and ``path`` is a regular file, add ``sha256_fast``.

    Returns:
        ``{"exists": False}`` when the path does not exist, otherwise a dict
        with ``exists``, ``size``, ``mtime`` (whole seconds) and, optionally,
        ``sha256_fast``.
    """
    if not path.exists():
        return {"exists": False}
    st = path.stat()
    fp = {
        "exists": True,
        "size": st.st_size,
        "mtime": int(st.st_mtime),
    }
    if do_hash and path.is_file():
        fp["sha256_fast"] = _fast_sha256(path)
    return fp


def _dir_fp(path: Path, do_hash: bool) -> dict:
    """Fingerprint a directory tree from its files' metadata.

    The digest is a sha256 over the sorted ``relative path|size|mtime`` records
    of the files below ``path``; file contents are not read.

    Args:
        path: Directory to inspect.
        do_hash: Accepted for signature parity with ``_file_fp``; not used.

    Returns:
        ``{"exists": False}`` when the path does not exist, otherwise a dict
        with ``exists``, ``file_count``, ``total_size`` and ``tree_sha256``.
    """
    if not path.exists():
        return {"exists": False}

    records = []
    seen = 0
    # NOTE(review): ``sorted(path.rglob("*"))`` lists the complete tree before
    # the cap below applies, so the cap bounds the stat calls and the digest
    # input but not the directory walk itself. The order is kept as is because
    # it determines ``tree_sha256`` in existing baselines.
    for p in sorted(path.rglob("*")):
        if not p.is_file():
            continue
        seen += 1
        if seen > _MAX_DIR_FILES:
            # Sentinel record marking a truncated listing.
            records.append(("", -1, -1))
            break
        st = p.stat()
        records.append((str(p.relative_to(path)), st.st_size, int(st.st_mtime)))

    h = hashlib.sha256()
    total_size = 0
    for rel, size, mtime in records:
        h.update(f"{rel}|{size}|{mtime}\n".encode())
        if size > 0:
            total_size += size
    return {
        "exists": True,
        # The truncation sentinel (size -1) is not a real file.
        "file_count": sum(1 for rec in records if rec[1] >= 0),
        "total_size": total_size,
        "tree_sha256": h.hexdigest(),
    }


# --------------------------------------------------------------------------- #
# Fingerprint collection per entry kind
# --------------------------------------------------------------------------- #
def fp_local_file(entry: dict, defaults: dict) -> dict:
    """Collect the fingerprint of a ``local_file`` entry (file or directory).

    Args:
        entry: Registry entry; reads ``path`` (required) and ``hash``.
        defaults: Registry defaults; ``hash`` is the fallback for the entry.

    Returns:
        ``{"status": "ok" | "missing", "fingerprint": {...}}``.

    Raises:
        KeyError: If the entry has no ``path``.
    """
    do_hash = entry.get("hash", defaults.get("hash", False))
    path = Path(os.path.expanduser(entry["path"]))
    fp = _dir_fp(path, do_hash) if path.is_dir() else _file_fp(path, do_hash)
    status = "ok" if fp.get("exists") else "missing"
    return {"status": status, "fingerprint": fp}


def fp_hf(entry: dict, defaults: dict) -> dict:
    """Collect the fingerprint of a Hugging Face hub cache entry.

    The local revision is read from ``refs/main`` of the cached repo; when that
    file is absent, the lexicographically last directory name under
    ``snapshots`` is used instead. With ``online`` enabled the remote commit is
    queried through ``huggingface_hub`` and a differing revision yields the
    status ``remote_update``. Any failure of that query is recorded under
    ``remote_error`` and does not change the status.

    Args:
        entry: Registry entry; reads ``repo_id`` (required), ``cache_dir`` and
            ``online``.
        defaults: Registry defaults; ``hf_hub_dir`` is used when the entry has
            no ``cache_dir``, ``online`` is the fallback for the entry.

    Returns:
        ``{"status": "ok" | "missing" | "remote_update", "fingerprint": {...}}``.

    Raises:
        KeyError: If ``repo_id`` is missing, or if neither ``cache_dir`` nor
            ``defaults["hf_hub_dir"]`` is configured.
    """
    repo_id = entry["repo_id"]
    # Look the default up only when it is needed: an entry that brings its own
    # cache_dir must not fail because defaults lacks hf_hub_dir.
    cache_dir = entry.get("cache_dir")
    if cache_dir is None:
        cache_dir = defaults["hf_hub_dir"]
    hub_dir = Path(os.path.expanduser(cache_dir))
    repo_dir = hub_dir / ("models--" + repo_id.replace("/", "--"))

    fp: dict = {"repo_id": repo_id}
    local_rev = None
    if repo_dir.exists():
        refs_main = repo_dir / "refs" / "main"
        if refs_main.exists():
            local_rev = refs_main.read_text().strip()
        else:
            snaps = repo_dir / "snapshots"
            if snaps.exists():
                cand = sorted(d.name for d in snaps.iterdir() if d.is_dir())
                local_rev = cand[-1] if cand else None
        fp["local_revision"] = local_rev
        fp["cached"] = True
    else:
        fp["cached"] = False
    status = "ok" if fp.get("cached") else "missing"

    # Optional: look up the latest remote commit.
    if entry.get("online", defaults.get("online", False)):
        try:
            from huggingface_hub import HfApi

            remote_sha = HfApi().model_info(repo_id).sha
            fp["remote_revision"] = remote_sha
            if local_rev and remote_sha and local_rev != remote_sha:
                status = "remote_update"
        except Exception as e:  # network unreachable, package missing, ...
            fp["remote_error"] = str(e)
    return {"status": status, "fingerprint": fp}


def fp_daemon(entry: dict, defaults: dict) -> dict:
    """Collect the fingerprint of a model served by a local daemon.

    Queries ``<server_url>/v1/models`` and records the served model ids. The
    status becomes ``changed`` when ``served_model`` is configured but not
    among them, and ``unreachable`` when the request or its parsing fails.
    Local asset files are fingerprinted in either case.

    Args:
        entry: Registry entry; reads ``server_url`` (required),
            ``served_model``, ``assets``, ``hash`` and ``daemon_timeout``.
        defaults: Registry defaults; fallback for ``hash`` and
            ``daemon_timeout`` (3 seconds when neither sets it).

    Returns:
        ``{"status": "ok" | "changed" | "unreachable" | "missing",
        "fingerprint": {...}}``.

    Raises:
        KeyError: If the entry has no ``server_url``.
    """
    do_hash = entry.get("hash", defaults.get("hash", False))
    timeout = entry.get("daemon_timeout", defaults.get("daemon_timeout", 3))
    url = entry["server_url"].rstrip("/") + "/v1/models"
    fp: dict = {"server_url": entry["server_url"]}
    status = "ok"
    try:
        req = urllib.request.Request(url, headers={"Accept": "application/json"})
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            data = json.loads(resp.read().decode())
        served = [m.get("id") for m in data.get("data", [])]
        fp["reachable"] = True
        fp["served_models"] = served
        expect = entry.get("served_model")
        if expect and expect not in served:
            fp["served_mismatch"] = {"expect": expect, "actual": served}
            status = "changed"
    except Exception as e:
        fp["reachable"] = False
        fp["error"] = str(e)
        status = "unreachable"

    # Fingerprint the local GGUF assets even when the service is unreachable,
    # so that a swapped file is still noticed.
    assets = entry.get("assets") or []
    if assets:
        fp["assets"] = {}
        for a in assets:
            afp = _file_fp(Path(os.path.expanduser(a)), do_hash)
            fp["assets"][a] = afp
            # A missing asset only downgrades an otherwise healthy entry; an
            # earlier, more specific status is kept.
            if not afp.get("exists") and status == "ok":
                status = "missing"
    return {"status": status, "fingerprint": fp}


def fp_mineru(entry: dict, defaults: dict) -> dict:
    """Collect the fingerprint of an installed package and its model root.

    Args:
        entry: Registry entry; reads ``package`` (default ``"mineru"``),
            ``model_root`` and ``hash``.
        defaults: Registry defaults; ``hash`` is the fallback for the entry.

    Returns:
        ``{"status": "ok" | "missing", "fingerprint": {...}}``. The status is
        ``missing`` when the package version cannot be determined or when the
        configured model root does not exist.
    """
    do_hash = entry.get("hash", defaults.get("hash", False))
    pkg = entry.get("package", "mineru")
    fp: dict = {"package": pkg}
    status = "ok"
    try:
        import importlib.metadata as md

        fp["package_version"] = md.version(pkg)
    except Exception as e:
        fp["package_error"] = str(e)
        status = "missing"

    root = entry.get("model_root")
    if root:
        p = Path(os.path.expanduser(root))
        rfp = _dir_fp(p, do_hash)
        fp["model_root"] = str(p)
        fp["model_root_fp"] = rfp
        if not rfp.get("exists") and status == "ok":
            status = "missing"
    return {"status": status, "fingerprint": fp}


# Dispatch table: registry ``kind`` -> fingerprint collector.
_COLLECTORS = {
    "local_file": fp_local_file,
    "hf": fp_hf,
    "daemon": fp_daemon,
    "mineru": fp_mineru,
}
# --------------------------------------------------------------------------- #
# Registry / lock I/O and comparison
# --------------------------------------------------------------------------- #
def load_registry(path: Path) -> dict:
    """Load the YAML registry and normalise its top-level sections.

    An empty file is treated as an empty registry, and a ``defaults`` or
    ``models`` key that is absent or explicitly null is replaced by an empty
    dict / list so callers can rely on both being present.

    Args:
        path: Path of the registry YAML file.

    Returns:
        The registry mapping with ``defaults`` (dict) and ``models`` (list).
    """
    with path.open("r", encoding="utf-8") as f:
        reg = yaml.safe_load(f)
    if reg is None:  # safe_load returns None for an empty document
        reg = {}
    # setdefault() would keep an explicit ``key: null`` in place, so check the
    # value rather than the presence of the key.
    if reg.get("defaults") is None:
        reg["defaults"] = {}
    if reg.get("models") is None:
        reg["models"] = []
    return reg


def collect(reg: dict, online: bool, do_hash: bool) -> dict:
    """Collect the current fingerprint of every model in the registry.

    Disabled entries are recorded with status ``skipped``; entries with an
    unknown ``kind`` and collectors that raise are recorded with status
    ``missing`` and the error text.

    Args:
        reg: Registry mapping (see ``load_registry``).
        online: Force ``online`` on in the defaults passed to collectors.
        do_hash: Force ``hash`` on in the defaults passed to collectors.

    Returns:
        Mapping of entry name to ``{"kind", "status", "fingerprint", ...}``.

    Raises:
        KeyError: If an entry has no ``name``.
    """
    # ``or`` also covers an explicit null section in a hand-built registry.
    defaults = dict(reg.get("defaults") or {})
    if online:
        defaults["online"] = True
    if do_hash:
        defaults["hash"] = True

    snapshot = {}
    for entry in reg.get("models") or []:
        name = entry["name"]
        kind = entry.get("kind")
        if not entry.get("enabled", True):
            snapshot[name] = {"kind": kind, "status": "skipped", "fingerprint": {}}
            continue
        collector = _COLLECTORS.get(kind)
        if collector is None:
            snapshot[name] = {
                "kind": kind,
                "status": "missing",
                "fingerprint": {"error": f"未知 kind: {kind}"},
            }
            continue
        try:
            result = collector(entry, defaults)
        except Exception as e:
            # One broken entry must not abort the whole run; it is reported as
            # missing together with the error text.
            result = {"status": "missing", "fingerprint": {"error": str(e)}}
        result["kind"] = kind
        result["used_by"] = entry.get("used_by", [])
        snapshot[name] = result
    return snapshot


def load_lock(path: Path) -> dict:
    """Return the ``models`` mapping of the lock file.

    Args:
        path: Path of the lock JSON file.

    Returns:
        The stored snapshot, or ``{}`` when the file does not exist or its
        ``models`` section is absent or null.
    """
    if not path.exists():
        return {}
    with path.open("r", encoding="utf-8") as f:
        return json.load(f).get("models") or {}


def save_lock(path: Path, snapshot: dict) -> None:
    """Write ``snapshot`` to the lock file with a generation timestamp.

    Args:
        path: Path of the lock JSON file; overwritten if it exists.
        snapshot: Mapping produced by ``collect``.
    """
    payload = {
        "generated_at": datetime.now().astimezone().isoformat(timespec="seconds"),
        "models": snapshot,
    }
    with path.open("w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def diff_fp(old: dict, new: dict) -> list:
    """Return the changed field paths between two fingerprints.

    Top-level keys are compared directly. When both sides hold a dict under
    the same key, its keys are compared one level deep and reported as
    ``key.subkey``. A key missing on one side compares as ``None``.

    Args:
        old: Baseline fingerprint.
        new: Current fingerprint.

    Returns:
        Sorted list of ``"path: old → new"`` strings; empty when equal.
    """
    changes = []
    for k in sorted(set(old) | set(new)):
        ov, nv = old.get(k), new.get(k)
        if isinstance(ov, dict) and isinstance(nv, dict):
            for sk in sorted(set(ov) | set(nv)):
                if ov.get(sk) != nv.get(sk):
                    changes.append(f"{k}.{sk}: {ov.get(sk)} → {nv.get(sk)}")
        elif ov != nv:
            changes.append(f"{k}: {ov} → {nv}")
    return changes
# --------------------------------------------------------------------------- #
# Subcommands
# --------------------------------------------------------------------------- #
def cmd_list(reg: dict) -> int:
    """Print the registry entries with their kind and consumers.

    Entries without a ``kind`` are listed as ``None`` instead of crashing the
    listing, and ``used_by`` may be a list, a single string or absent/null.

    Args:
        reg: Registry mapping.

    Returns:
        Always 0.

    Raises:
        KeyError: If an entry has no ``name``.
    """
    models = reg.get("models") or []
    print(f"模型清单({len(models)} 条):\n")
    for e in models:
        flag = " " if e.get("enabled", True) else "× "
        used_by = e.get("used_by") or []
        if isinstance(used_by, str):
            # A scalar would otherwise be joined character by character.
            used_by = [used_by]
        used = ";".join(str(u) for u in used_by)
        # str() because None does not accept the "<10" format spec.
        print(f"{flag}[{str(e.get('kind')):<10}] {e['name']}")
        if used:
            print(f" 用于:{used}")
    print("\n(× 表示 enabled: false,体检时跳过)")
    return 0


def cmd_show(reg: dict, online: bool, do_hash: bool) -> int:
    """Collect the current fingerprints and print them as JSON.

    Returns:
        Always 0.
    """
    snap = collect(reg, online, do_hash)
    print(json.dumps(snap, ensure_ascii=False, indent=2))
    return 0


def cmd_check(reg: dict, lock_path: Path, online: bool, do_hash: bool, strict: bool) -> int:
    """Compare the current fingerprints with the lock baseline and report.

    Per entry, in this order: disabled entries are skipped; entries without a
    baseline record count as new; ``missing`` / ``unreachable`` /
    ``remote_update`` count as problems; everything else is diffed against the
    baseline fingerprint.

    Args:
        reg: Registry mapping.
        lock_path: Path of the lock baseline.
        online: Passed to ``collect``.
        do_hash: Passed to ``collect``.
        strict: Also treat fingerprint changes as a failure.

    Returns:
        1 when there is no baseline, any problem, any new entry, or (with
        ``strict``) any fingerprint change; otherwise 0.
    """
    new = collect(reg, online, do_hash)
    old = load_lock(lock_path)
    has_baseline = bool(old)

    problems = 0  # missing / unreachable / remote_update
    changes = 0   # fingerprint differs from the baseline
    news = 0      # entry has no record in the lock

    print(f"模型体检报告 baseline={'有' if has_baseline else '无(首次,请先 update-lock)'}\n")
    for name, cur in new.items():
        kind = cur.get("kind")
        status = cur.get("status")
        if status == "skipped":
            print(f"{SYM['skipped']} {name:<28} [{kind}] 跳过(disabled)")
            continue

        prev = old.get(name)
        if prev is None:
            news += 1
            print(f"{SYM['new']} {name:<28} [{kind}] 新增条目(基线中无记录)")
            continue

        # Status problems take precedence over the fingerprint diff.
        if status in ("missing", "unreachable"):
            problems += 1
            fpr = cur["fingerprint"]
            # A missing package stores its error under "package_error".
            detail = fpr.get("error") or fpr.get("package_error") or ""
            print(f"{SYM[status]} {name:<28} [{kind}] {status} {detail}")
            continue
        if status == "remote_update":
            problems += 1
            fpr = cur["fingerprint"]
            print(f"{SYM['remote_update']} {name:<28} [{kind}] HF 远端有更新:"
                  f"{fpr.get('local_revision')} → {fpr.get('remote_revision')}")
            continue

        # Fingerprint comparison.
        fp_changes = diff_fp(prev.get("fingerprint", {}), cur.get("fingerprint", {}))
        if fp_changes:
            changes += 1
            print(f"{SYM['changed']} {name:<28} [{kind}] 指纹变化:")
            for c in fp_changes:
                print(f" - {c}")
        else:
            print(f"{SYM['ok']} {name:<28} [{kind}] 未变化")

    # Present in the baseline but no longer in the registry.
    removed = [n for n in old if n not in new]
    for n in removed:
        print(f"{SYM['removed']} {n:<28} 基线中存在但 registry 已移除")

    print("\n" + "-" * 60)
    print(f"问题(缺失/不可达/远端更新)={problems} 指纹变化={changes} "
          f"新增={news} 移除={len(removed)}")
    if not has_baseline:
        print("提示:尚无基线,运行 `update-lock` 生成。")
        return 1

    # Problems and new entries always fail; fingerprint changes only under
    # --strict. Removed entries are reported but never fail the check.
    failed = problems or news or (strict and changes)
    return 1 if failed else 0


def cmd_update_lock(reg: dict, lock_path: Path, online: bool, do_hash: bool) -> int:
    """Collect the current fingerprints and store them as the new baseline.

    Returns:
        Always 0.
    """
    snap = collect(reg, online, do_hash)
    save_lock(lock_path, snap)
    enabled = [n for n, v in snap.items() if v.get("status") != "skipped"]
    print(f"已写入基线 {lock_path}({len(enabled)} 条生效,"
          f"{len(snap) - len(enabled)} 条跳过)")
    return 0


# --------------------------------------------------------------------------- #
# Entry point
# --------------------------------------------------------------------------- #
def main(argv=None) -> int:
    """Parse the command line and dispatch to the selected subcommand.

    Args:
        argv: Argument list; ``None`` lets argparse read ``sys.argv``.

    Returns:
        The subcommand's exit code.
    """
    parser = argparse.ArgumentParser(description="模型变更巡检工具 model_doctor")
    parser.add_argument("command", choices=["list", "show", "check", "update-lock"])
    parser.add_argument("--registry", type=Path, default=DEFAULT_REGISTRY,
                        help=f"清单文件,默认 {DEFAULT_REGISTRY.name}")
    parser.add_argument("--lock", type=Path, default=DEFAULT_LOCK,
                        help=f"基线文件,默认 {DEFAULT_LOCK.name}")
    parser.add_argument("--online", action="store_true",
                        help="hf 条目额外查询远端最新 commit 进行比对")
    parser.add_argument("--hash", dest="do_hash", action="store_true",
                        help="本地文件/目录额外计算快速 sha256(更敏感但更慢)")
    parser.add_argument("--strict", action="store_true",
                        help="check 时把『指纹变化』也视为失败(非 0 退出)")
    args = parser.parse_args(argv)

    reg = load_registry(args.registry)

    if args.command == "list":
        return cmd_list(reg)
    if args.command == "show":
        return cmd_show(reg, args.online, args.do_hash)
    if args.command == "check":
        return cmd_check(reg, args.lock, args.online, args.do_hash, args.strict)
    if args.command == "update-lock":
        return cmd_update_lock(reg, args.lock, args.online, args.do_hash)
    # Not reachable while ``choices`` above lists exactly these commands.
    return 2


if __name__ == "__main__":
    sys.exit(main())