| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
#!/usr/bin/env python
"""model_doctor -- model change inspection tool ("plan B").

Driven by the hand-maintained ``model_registry.yaml``, it collects a
"fingerprint" for each of four kinds of model source (``local_file``, ``hf``,
``daemon``, ``mineru``) and compares it with the baseline in
``models.lock.json``. It reports which models changed, which are missing and
which serving endpoints are unreachable. Optionally (``--online``) it also
reports whether the Hugging Face remote has a newer commit.

Subcommands:
    list         list the model entries in the registry
    show         collect and print the current fingerprints (no comparison)
    check        collect, compare with the lock baseline and print a report;
                 exits non-zero when there is no baseline, on missing /
                 unreachable / remote-update problems or new entries, and
                 (only with --strict) on fingerprint changes
    update-lock  collect and persist the current fingerprints as the new
                 lock baseline

Usage (recommended inside the ``mineru`` conda environment):
    conda run -n mineru python model_doctor.py check
    conda run -n mineru python model_doctor.py check --online --hash
    conda run -n mineru python model_doctor.py update-lock
"""
- from __future__ import annotations
- import argparse
- import hashlib
- import json
- import os
- import sys
- import urllib.request
- from datetime import datetime
- from pathlib import Path
- try:
- import yaml
- except ImportError:
- sys.stderr.write("缺少依赖 PyYAML,请在 mineru 环境安装:conda run -n mineru pip install pyyaml\n")
- raise
- HERE = Path(__file__).resolve().parent
- DEFAULT_REGISTRY = HERE / "model_registry.yaml"
- DEFAULT_LOCK = HERE / "models.lock.json"
- # 报告状态符号
- SYM = {
- "ok": "✅",
- "changed": "⚠️ ",
- "remote_update": "🔺",
- "missing": "❌",
- "unreachable": "❌",
- "new": "🆕",
- "removed": "🗑 ",
- "skipped": "· ",
- }
- # 目录指纹遍历上限,避免误指向超大目录卡死
- _MAX_DIR_FILES = 20000
- # --------------------------------------------------------------------------- #
- # 通用指纹工具
- # --------------------------------------------------------------------------- #
- def _fast_sha256(path: Path, head_tail_mb: int = 8) -> str:
- """对大文件取「头 + 尾 + 大小」的快速 sha256,避免全量读取。"""
- size = path.stat().st_size
- chunk = head_tail_mb * 1024 * 1024
- h = hashlib.sha256()
- h.update(str(size).encode())
- with path.open("rb") as f:
- h.update(f.read(chunk))
- if size > chunk * 2:
- f.seek(-chunk, os.SEEK_END)
- h.update(f.read(chunk))
- return h.hexdigest()
- def _file_fp(path: Path, do_hash: bool) -> dict:
- if not path.exists():
- return {"exists": False}
- st = path.stat()
- fp = {
- "exists": True,
- "size": st.st_size,
- "mtime": int(st.st_mtime),
- }
- if do_hash and path.is_file():
- fp["sha256_fast"] = _fast_sha256(path)
- return fp
- def _dir_fp(path: Path, do_hash: bool) -> dict:
- """目录指纹:聚合 (相对路径, size, mtime) 排序后的 sha256。"""
- if not path.exists():
- return {"exists": False}
- files = []
- count = 0
- for p in sorted(path.rglob("*")):
- if p.is_file():
- count += 1
- if count > _MAX_DIR_FILES:
- files.append(("<truncated>", -1, -1))
- break
- st = p.stat()
- files.append((str(p.relative_to(path)), st.st_size, int(st.st_mtime)))
- h = hashlib.sha256()
- total_size = 0
- for rel, size, mtime in files:
- h.update(f"{rel}|{size}|{mtime}\n".encode())
- if size > 0:
- total_size += size
- fp = {
- "exists": True,
- "file_count": len([f for f in files if f[1] >= 0]),
- "total_size": total_size,
- "tree_sha256": h.hexdigest(),
- }
- return fp
- # --------------------------------------------------------------------------- #
- # 各 kind 的指纹采集
- # --------------------------------------------------------------------------- #
def fp_local_file(entry: dict, defaults: dict) -> dict:
    """Collector for ``kind: local_file``: fingerprint one file or directory.

    ``entry["path"]`` (with ``~`` expanded) is fingerprinted by ``_dir_fp``
    when it is a directory and by ``_file_fp`` otherwise. An entry-level
    ``hash`` flag overrides ``defaults["hash"]`` (False if neither is set).
    The status is ``"ok"`` when the path exists and ``"missing"`` otherwise.
    """
    want_hash = entry.get("hash", defaults.get("hash", False))
    target = Path(os.path.expanduser(entry["path"]))
    fingerprint_fn = _dir_fp if target.is_dir() else _file_fp
    fingerprint = fingerprint_fn(target, want_hash)
    return {
        "status": "ok" if fingerprint.get("exists") else "missing",
        "fingerprint": fingerprint,
    }
def fp_hf(entry: dict, defaults: dict) -> dict:
    """Collector for ``kind: hf``: inspect a Hugging Face hub cache entry.

    The cache root is ``entry["cache_dir"]`` if present, otherwise
    ``defaults["hf_hub_dir"]``. The repo directory is
    ``<root>/models--<org>--<name>``. The local revision is read from
    ``refs/main``; without that file, the lexicographically last directory name
    under ``snapshots/`` is used.

    When ``online`` is enabled (entry or defaults), the remote ``sha`` from
    ``huggingface_hub.HfApi().model_info`` is recorded. If it differs from the
    local revision, the status becomes ``"remote_update"``. Remote lookup
    failures are recorded under ``remote_error`` and do not change the status.

    Returns:
        ``{"status": "ok" | "missing" | "remote_update", "fingerprint": {...}}``.

    Raises:
        KeyError: ``repo_id`` is missing, or neither ``cache_dir`` nor
            ``defaults["hf_hub_dir"]`` is configured.
    """
    repo_id = entry["repo_id"]
    # Read defaults["hf_hub_dir"] only when the entry has no cache_dir of its
    # own. ``entry.get("cache_dir", defaults["hf_hub_dir"])`` evaluates the
    # default eagerly and raised KeyError even when it was not needed.
    cache_root = entry["cache_dir"] if "cache_dir" in entry else defaults["hf_hub_dir"]
    hub_dir = Path(os.path.expanduser(cache_root))
    repo_dir = hub_dir / ("models--" + repo_id.replace("/", "--"))
    fp: dict = {"repo_id": repo_id}
    local_rev = None
    if repo_dir.exists():
        refs_main = repo_dir / "refs" / "main"
        if refs_main.exists():
            local_rev = refs_main.read_text().strip()
        else:
            snaps = repo_dir / "snapshots"
            if snaps.exists():
                # No refs/main: fall back to the lexicographically last snapshot
                # name. These are presumably commit hashes, so the choice is
                # arbitrary when several snapshots exist.
                cand = sorted(d.name for d in snaps.iterdir() if d.is_dir())
                local_rev = cand[-1] if cand else None
        fp["local_revision"] = local_rev
        fp["cached"] = True
    else:
        fp["cached"] = False
    status = "ok" if fp["cached"] else "missing"
    # Optional: query the latest remote commit.
    if entry.get("online", defaults.get("online", False)):
        try:
            from huggingface_hub import HfApi
            remote_sha = HfApi().model_info(repo_id).sha
            fp["remote_revision"] = remote_sha
            if local_rev and remote_sha and local_rev != remote_sha:
                status = "remote_update"
        except Exception as e:  # best effort: network down, missing package, ...
            fp["remote_error"] = str(e)
    return {"status": status, "fingerprint": fp}
def fp_daemon(entry: dict, defaults: dict) -> dict:
    """Collector for ``kind: daemon``: probe a model server and its local assets.

    Sends ``GET <server_url>/v1/models`` (timeout from ``daemon_timeout``,
    default 3 s) and records the ``id`` of every item under ``data``. Status:

    * ``"changed"`` if ``served_model`` is configured but not in that list;
    * ``"unreachable"`` on any error while requesting or parsing;
    * ``"missing"`` (only if it would otherwise be ``"ok"``) when a listed asset
      file does not exist.

    Assets are fingerprinted even when the server is unreachable, so a swapped
    local file is still noticed.
    """
    want_hash = entry.get("hash", defaults.get("hash", False))
    timeout = entry.get("daemon_timeout", defaults.get("daemon_timeout", 3))
    base_url = entry["server_url"]
    models_url = base_url.rstrip("/") + "/v1/models"
    fingerprint: dict = {"server_url": base_url}
    try:
        request = urllib.request.Request(models_url, headers={"Accept": "application/json"})
        with urllib.request.urlopen(request, timeout=timeout) as response:
            payload = json.loads(response.read().decode())
        served_ids = [item.get("id") for item in payload.get("data", [])]
        fingerprint["reachable"] = True
        fingerprint["served_models"] = served_ids
        status = "ok"
        expected = entry.get("served_model")
        if expected and expected not in served_ids:
            fingerprint["served_mismatch"] = {"expect": expected, "actual": served_ids}
            status = "changed"
    except Exception as exc:
        fingerprint["reachable"] = False
        fingerprint["error"] = str(exc)
        status = "unreachable"

    asset_list = entry.get("assets") or []
    if asset_list:
        asset_fps = {}
        fingerprint["assets"] = asset_fps
        for raw_path in asset_list:
            asset_fp = _file_fp(Path(os.path.expanduser(raw_path)), want_hash)
            asset_fps[raw_path] = asset_fp
            if status == "ok" and not asset_fp.get("exists"):
                status = "missing"
    return {"status": status, "fingerprint": fingerprint}
def fp_mineru(entry: dict, defaults: dict) -> dict:
    """Collector for ``kind: mineru``: package version plus model-root tree.

    Records the installed distribution version of ``entry["package"]``
    (default ``"mineru"``). If the lookup fails, ``package_error`` is set and
    the status is ``"missing"``. When ``model_root`` is configured, that path
    (``~`` expanded) is fingerprinted with ``_dir_fp``; if it does not exist
    and the status is still ``"ok"``, it becomes ``"missing"``.
    """
    want_hash = entry.get("hash", defaults.get("hash", False))
    package = entry.get("package", "mineru")
    fingerprint: dict = {"package": package}
    status = "ok"
    try:
        from importlib import metadata
        fingerprint["package_version"] = metadata.version(package)
    except Exception as exc:
        fingerprint["package_error"] = str(exc)
        status = "missing"

    model_root = entry.get("model_root")
    if not model_root:
        return {"status": status, "fingerprint": fingerprint}

    root_path = Path(os.path.expanduser(model_root))
    root_fp = _dir_fp(root_path, want_hash)
    fingerprint["model_root"] = str(root_path)
    fingerprint["model_root_fp"] = root_fp
    if status == "ok" and not root_fp.get("exists"):
        status = "missing"
    return {"status": status, "fingerprint": fingerprint}
# Dispatch table: registry ``kind`` value -> fingerprint collector function.
_COLLECTORS = dict(
    local_file=fp_local_file,
    hf=fp_hf,
    daemon=fp_daemon,
    mineru=fp_mineru,
)
- # --------------------------------------------------------------------------- #
- # registry / lock 读写与比对
- # --------------------------------------------------------------------------- #
def load_registry(path: Path) -> dict:
    """Load and normalise the registry YAML file.

    Returns the parsed mapping with ``defaults`` guaranteed to be a dict and
    ``models`` a list. An empty file (``yaml.safe_load`` returns ``None``) and
    null sections such as a bare ``models:`` key are treated as empty. The
    original code crashed on these with ``AttributeError`` / ``TypeError``.

    Raises:
        OSError: the file cannot be opened.
        yaml.YAMLError: the file is not valid YAML.
        ValueError: the top level, ``defaults`` or ``models`` has the wrong type.
    """
    with path.open("r", encoding="utf-8") as f:
        reg = yaml.safe_load(f)
    if reg is None:
        reg = {}
    if not isinstance(reg, dict):
        raise ValueError(f"{path}: 顶层必须是映射(mapping),实际为 {type(reg).__name__}")
    # setdefault alone is not enough: ``models:`` with no value parses to None.
    for key, empty in (("defaults", {}), ("models", [])):
        if reg.get(key) is None:
            reg[key] = empty
        elif not isinstance(reg[key], type(empty)):
            raise ValueError(
                f"{path}: `{key}` 必须是 {type(empty).__name__},实际为 {type(reg[key]).__name__}"
            )
    return reg
def collect(reg: dict, online: bool, do_hash: bool) -> dict:
    """Run the matching collector for every registry entry.

    The CLI flags ``online`` / ``do_hash`` force ``defaults["online"]`` /
    ``defaults["hash"]`` to True. Per-entry settings are still consulted first
    inside each collector.

    Returns:
        ``{name: result}``. Each result has ``status``, ``fingerprint`` and
        ``kind``; collected entries also carry ``used_by``. Disabled entries are
        ``"skipped"``. An unknown kind, or an exception raised by a collector, is
        recorded as ``"missing"`` with an ``error`` text.
    """
    effective = dict(reg.get("defaults", {}))
    if online:
        effective["online"] = True
    if do_hash:
        effective["hash"] = True

    snapshot = {}
    for item in reg.get("models", []):
        name = item["name"]
        kind = item.get("kind")
        if not item.get("enabled", True):
            snapshot[name] = {"kind": kind, "status": "skipped", "fingerprint": {}}
            continue
        collector = _COLLECTORS.get(kind)
        if collector is None:
            snapshot[name] = {"kind": kind, "status": "missing",
                              "fingerprint": {"error": f"未知 kind: {kind}"}}
            continue
        try:
            outcome = collector(item, effective)
        except Exception as exc:
            outcome = {"status": "missing", "fingerprint": {"error": str(exc)}}
        outcome["kind"] = kind
        outcome["used_by"] = item.get("used_by", [])
        snapshot[name] = outcome
    return snapshot
def load_lock(path: Path) -> dict:
    """Return the ``models`` mapping stored in the lock file.

    Returns ``{}`` when the file does not exist or has no ``models`` key.
    """
    if path.exists():
        with path.open("r", encoding="utf-8") as fh:
            payload = json.load(fh)
        return payload.get("models", {})
    return {}
def save_lock(path: Path, snapshot: dict) -> None:
    """Persist ``snapshot`` as the new lock baseline.

    The payload is ``{"generated_at": <local ISO-8601 timestamp>, "models":
    snapshot}``. It is first written to ``<path>.tmp`` in the same directory
    and then moved over ``path`` with ``os.replace``. An interrupted run (crash,
    Ctrl-C, unserialisable value) therefore leaves the previous baseline
    intact instead of a truncated file.

    Raises:
        OSError: the temp file cannot be written or moved into place.
        TypeError: ``snapshot`` contains values ``json`` cannot serialise.
    """
    payload = {
        "generated_at": datetime.now().astimezone().isoformat(timespec="seconds"),
        "models": snapshot,
    }
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with tmp_path.open("w", encoding="utf-8") as f:
            json.dump(payload, f, ensure_ascii=False, indent=2)
            f.write("\n")
        os.replace(tmp_path, path)
    except BaseException:
        # Never leave a half-written temp file behind; re-raise the original error.
        try:
            tmp_path.unlink()
        except OSError:
            pass
        raise
def diff_fp(old: dict, new: dict, *,
            optional_keys=frozenset({"sha256_fast", "content_sha256", "remote_revision"}),
            ignore_keys=frozenset({"remote_error"})) -> list:
    """Return readable descriptions of fingerprint fields that differ.

    Nested dicts are compared recursively and reported as dotted paths, e.g.
    ``assets.<path>.size: 1 → 2``. Values that differ in type (or non-dict
    values) are reported whole as ``key: old → new``. A key absent on one side
    compares as ``None``.

    Args:
        old: Baseline fingerprint.
        new: Current fingerprint.
        optional_keys: Keys produced only by opt-in flags (``--hash``,
            ``--online``). They are compared only when present on BOTH sides;
            otherwise a baseline recorded without the flag would flag every
            entry as changed.
        ignore_keys: Keys holding free-form error text, never compared.

    Returns:
        List of ``"path: old → new"`` strings, in sorted key order per level.
    """
    changes = []

    def walk(a: dict, b: dict, prefix: str) -> None:
        for k in sorted(set(a) | set(b)):
            if k in ignore_keys:
                continue
            if k in optional_keys and not (k in a and k in b):
                continue
            av, bv = a.get(k), b.get(k)
            label = f"{prefix}{k}"
            if isinstance(av, dict) and isinstance(bv, dict):
                walk(av, bv, f"{label}.")
            elif av != bv:
                changes.append(f"{label}: {av} → {bv}")

    walk(old, new, "")
    return changes
- # --------------------------------------------------------------------------- #
- # 子命令
- # --------------------------------------------------------------------------- #
def cmd_list(reg: dict) -> int:
    """Print every registry entry with its kind, enabled flag and consumers.

    Disabled entries are prefixed with ``×``. Entries without ``kind`` are
    shown as ``None``. Previously ``f"{None:<10}"`` raised TypeError for them.
    ``used_by`` may be a list, a single string or null. A plain string used to
    be joined character by character.

    Returns:
        Always 0.
    """
    models = reg.get("models", [])
    print(f"模型清单({len(models)} 条):\n")
    for e in models:
        flag = " " if e.get("enabled", True) else "× "
        used_by = e.get("used_by") or []
        if isinstance(used_by, str):
            used_by = [used_by]
        used = ";".join(str(u) for u in used_by)
        print(f"{flag}[{str(e.get('kind')):<10}] {e['name']}")
        if used:
            print(f" 用于:{used}")
    print("\n(× 表示 enabled: false,体检时跳过)")
    return 0
def cmd_show(reg: dict, online: bool, do_hash: bool) -> int:
    """Collect the current fingerprints and print them as indented JSON.

    Nothing is compared with the baseline. Always returns 0.
    """
    current = collect(reg, online, do_hash)
    rendered = json.dumps(current, ensure_ascii=False, indent=2)
    print(rendered)
    return 0
def cmd_check(reg: dict, lock_path: Path, online: bool, do_hash: bool, strict: bool) -> int:
    """Compare current fingerprints with the lock baseline and print a report.

    Each registry entry gets the first matching verdict, in this order:

    1. disabled -> skipped;
    2. absent from the baseline -> new;
    3. status ``missing`` / ``unreachable`` -> problem;
    4. status ``remote_update`` -> problem;
    5. otherwise the fingerprint is diffed against the baseline. A daemon
       serving the wrong model (status ``changed``) always counts as a
       change, even if the baseline recorded the same mismatch.

    Baseline entries no longer present in the registry are listed as removed
    (informational only).

    Returns:
        1 when there is no baseline, when any problem or new entry was found,
        or (with ``strict``) when any fingerprint changed; otherwise 0.
    """
    new = collect(reg, online, do_hash)
    old = load_lock(lock_path)
    has_baseline = bool(old)
    problems = 0  # missing / unreachable / remote_update
    changes = 0  # fingerprint changed (or daemon serving the wrong model)
    news = 0  # new entries with no baseline record
    print(f"模型体检报告 baseline={'有' if has_baseline else '无(首次,请先 update-lock)'}\n")
    for name, cur in new.items():
        kind = cur.get("kind")
        status = cur.get("status")
        fpr = cur.get("fingerprint", {})
        if status == "skipped":
            print(f"{SYM['skipped']} {name:<28} [{kind}] 跳过(disabled)")
            continue
        prev = old.get(name)
        if prev is None:
            news += 1
            print(f"{SYM['new']} {name:<28} [{kind}] 新增条目(基线中无记录)")
            continue
        # Status problems take precedence over fingerprint comparison.
        if status in ("missing", "unreachable"):
            problems += 1
            # mineru records a failed package lookup under ``package_error``.
            detail = fpr.get("error") or fpr.get("package_error") or ""
            print(f"{SYM[status]} {name:<28} [{kind}] {status} {detail}")
            continue
        if status == "remote_update":
            problems += 1
            print(f"{SYM['remote_update']} {name:<28} [{kind}] HF 远端有更新:"
                  f"{fpr.get('local_revision')} → {fpr.get('remote_revision')}")
            continue
        fp_changes = diff_fp(prev.get("fingerprint", {}), fpr)
        if status == "changed" and not fp_changes:
            # The daemon is not serving the configured model. If the baseline
            # captured the same mismatch the diff is empty, but this must not
            # be reported as "unchanged".
            fp_changes = [f"served_mismatch: {fpr.get('served_mismatch')}"]
        if fp_changes:
            changes += 1
            print(f"{SYM['changed']} {name:<28} [{kind}] 指纹变化:")
            for c in fp_changes:
                print(f" - {c}")
        else:
            print(f"{SYM['ok']} {name:<28} [{kind}] 未变化")
    # Present in the baseline but removed from the registry.
    removed = [n for n in old if n not in new]
    for n in removed:
        print(f"{SYM['removed']} {n:<28} 基线中存在但 registry 已移除")
    print("\n" + "-" * 60)
    print(f"问题(缺失/不可达/远端更新)={problems} 指纹变化={changes} "
          f"新增={news} 移除={len(removed)}")
    if not has_baseline:
        print("提示:尚无基线,运行 `update-lock` 生成。")
        return 1
    # Problems and new entries always fail; fingerprint changes only in strict mode.
    failed = problems or news or (strict and changes)
    return 1 if failed else 0
def cmd_update_lock(reg: dict, lock_path: Path, online: bool, do_hash: bool) -> int:
    """Collect the current fingerprints and save them as the new baseline.

    Prints how many entries were recorded as active and how many were skipped
    (disabled). Always returns 0.
    """
    snap = collect(reg, online, do_hash)
    save_lock(lock_path, snap)
    skipped = sum(1 for v in snap.values() if v.get("status") == "skipped")
    active = len(snap) - skipped
    print(f"已写入基线 {lock_path}({active} 条生效,"
          f"{skipped} 条跳过)")
    return 0
- # --------------------------------------------------------------------------- #
- # 入口
- # --------------------------------------------------------------------------- #
def main(argv=None) -> int:
    """CLI entry point. Parse ``argv`` and dispatch to a subcommand.

    Args:
        argv: Argument list (defaults to ``sys.argv[1:]`` via argparse).

    Returns:
        The subcommand's exit code. Returns 2 when the registry cannot be read
        or parsed; the reason goes to stderr instead of a raw traceback.
    """
    parser = argparse.ArgumentParser(description="模型变更巡检工具 model_doctor")
    parser.add_argument("command", choices=["list", "show", "check", "update-lock"])
    parser.add_argument("--registry", type=Path, default=DEFAULT_REGISTRY,
                        help=f"清单文件,默认 {DEFAULT_REGISTRY.name}")
    parser.add_argument("--lock", type=Path, default=DEFAULT_LOCK,
                        help=f"基线文件,默认 {DEFAULT_LOCK.name}")
    parser.add_argument("--online", action="store_true",
                        help="hf 条目额外查询远端最新 commit 进行比对")
    parser.add_argument("--hash", dest="do_hash", action="store_true",
                        help="本地文件/目录额外计算快速 sha256(更敏感但更慢)")
    parser.add_argument("--strict", action="store_true",
                        help="check 时把『指纹变化』也视为失败(非 0 退出)")
    args = parser.parse_args(argv)
    try:
        reg = load_registry(args.registry)
    except (OSError, ValueError, yaml.YAMLError) as e:
        sys.stderr.write(f"无法读取 registry {args.registry}:{e}\n")
        return 2
    handlers = {
        "list": lambda: cmd_list(reg),
        "show": lambda: cmd_show(reg, args.online, args.do_hash),
        "check": lambda: cmd_check(reg, args.lock, args.online, args.do_hash, args.strict),
        "update-lock": lambda: cmd_update_lock(reg, args.lock, args.online, args.do_hash),
    }
    handler = handlers.get(args.command)
    # argparse ``choices`` makes this unreachable; kept as a defensive default.
    return handler() if handler is not None else 2
# Script entry: exit with the code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
|