model_doctor.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. #!/usr/bin/env python
  2. """model_doctor —— 模型变更巡检工具(方案 B)。
  3. 依据手工维护的 model_registry.yaml,对四类模型来源采集「指纹」,
  4. 与基线 models.lock.json 比对,报告哪些模型发生了变化 / 缺失 / 服务不可达,
  5. 以及(可选)HF 远端是否有更新。
  6. 子命令:
  7. list 列出 registry 中的模型条目
  8. show 采集并打印当前指纹(不比对)
  9. check 采集并与 lock 基线比对,输出报告(缺失/不可达/远端更新/新增条目或尚无基线时退出码非 0;指纹变化仅在 --strict 下计为失败)
  10. update-lock 采集并把当前指纹固化为新的 lock 基线
  11. 用法(建议在 conda 环境 mineru 下):
  12. conda run -n mineru python model_doctor.py check
  13. conda run -n mineru python model_doctor.py check --online --hash
  14. conda run -n mineru python model_doctor.py update-lock
  15. """
  16. from __future__ import annotations
  17. import argparse
  18. import hashlib
  19. import json
  20. import os
  21. import sys
  22. import urllib.request
  23. from datetime import datetime
  24. from pathlib import Path
  25. try:
  26. import yaml
  27. except ImportError:
  28. sys.stderr.write("缺少依赖 PyYAML,请在 mineru 环境安装:conda run -n mineru pip install pyyaml\n")
  29. raise
  30. HERE = Path(__file__).resolve().parent
  31. DEFAULT_REGISTRY = HERE / "model_registry.yaml"
  32. DEFAULT_LOCK = HERE / "models.lock.json"
  33. # 报告状态符号
  34. SYM = {
  35. "ok": "✅",
  36. "changed": "⚠️ ",
  37. "remote_update": "🔺",
  38. "missing": "❌",
  39. "unreachable": "❌",
  40. "new": "🆕",
  41. "removed": "🗑 ",
  42. "skipped": "· ",
  43. }
  44. # 目录指纹遍历上限,避免误指向超大目录卡死
  45. _MAX_DIR_FILES = 20000
  46. # --------------------------------------------------------------------------- #
  47. # 通用指纹工具
  48. # --------------------------------------------------------------------------- #
  49. def _fast_sha256(path: Path, head_tail_mb: int = 8) -> str:
  50. """对大文件取「头 + 尾 + 大小」的快速 sha256,避免全量读取。"""
  51. size = path.stat().st_size
  52. chunk = head_tail_mb * 1024 * 1024
  53. h = hashlib.sha256()
  54. h.update(str(size).encode())
  55. with path.open("rb") as f:
  56. h.update(f.read(chunk))
  57. if size > chunk * 2:
  58. f.seek(-chunk, os.SEEK_END)
  59. h.update(f.read(chunk))
  60. return h.hexdigest()
  61. def _file_fp(path: Path, do_hash: bool) -> dict:
  62. if not path.exists():
  63. return {"exists": False}
  64. st = path.stat()
  65. fp = {
  66. "exists": True,
  67. "size": st.st_size,
  68. "mtime": int(st.st_mtime),
  69. }
  70. if do_hash and path.is_file():
  71. fp["sha256_fast"] = _fast_sha256(path)
  72. return fp
  73. def _dir_fp(path: Path, do_hash: bool) -> dict:
  74. """目录指纹:聚合 (相对路径, size, mtime) 排序后的 sha256。"""
  75. if not path.exists():
  76. return {"exists": False}
  77. files = []
  78. count = 0
  79. for p in sorted(path.rglob("*")):
  80. if p.is_file():
  81. count += 1
  82. if count > _MAX_DIR_FILES:
  83. files.append(("<truncated>", -1, -1))
  84. break
  85. st = p.stat()
  86. files.append((str(p.relative_to(path)), st.st_size, int(st.st_mtime)))
  87. h = hashlib.sha256()
  88. total_size = 0
  89. for rel, size, mtime in files:
  90. h.update(f"{rel}|{size}|{mtime}\n".encode())
  91. if size > 0:
  92. total_size += size
  93. fp = {
  94. "exists": True,
  95. "file_count": len([f for f in files if f[1] >= 0]),
  96. "total_size": total_size,
  97. "tree_sha256": h.hexdigest(),
  98. }
  99. return fp
  100. # --------------------------------------------------------------------------- #
  101. # 各 kind 的指纹采集
  102. # --------------------------------------------------------------------------- #
  103. def fp_local_file(entry: dict, defaults: dict) -> dict:
  104. do_hash = entry.get("hash", defaults.get("hash", False))
  105. path = Path(os.path.expanduser(entry["path"]))
  106. fp = _dir_fp(path, do_hash) if path.is_dir() else _file_fp(path, do_hash)
  107. status = "ok" if fp.get("exists") else "missing"
  108. return {"status": status, "fingerprint": fp}
  109. def fp_hf(entry: dict, defaults: dict) -> dict:
  110. repo_id = entry["repo_id"]
  111. hub_dir = Path(os.path.expanduser(entry.get("cache_dir", defaults["hf_hub_dir"])))
  112. repo_dir = hub_dir / ("models--" + repo_id.replace("/", "--"))
  113. fp: dict = {"repo_id": repo_id}
  114. local_rev = None
  115. if repo_dir.exists():
  116. refs_main = repo_dir / "refs" / "main"
  117. if refs_main.exists():
  118. local_rev = refs_main.read_text().strip()
  119. else:
  120. snaps = repo_dir / "snapshots"
  121. if snaps.exists():
  122. cand = sorted([d.name for d in snaps.iterdir() if d.is_dir()])
  123. local_rev = cand[-1] if cand else None
  124. fp["local_revision"] = local_rev
  125. fp["cached"] = True
  126. else:
  127. fp["cached"] = False
  128. status = "ok" if fp.get("cached") else "missing"
  129. # 可选:查远端最新 commit
  130. if entry.get("online", defaults.get("online", False)):
  131. try:
  132. from huggingface_hub import HfApi
  133. remote_sha = HfApi().model_info(repo_id).sha
  134. fp["remote_revision"] = remote_sha
  135. if local_rev and remote_sha and local_rev != remote_sha:
  136. status = "remote_update"
  137. except Exception as e: # 网络不可达等
  138. fp["remote_error"] = str(e)
  139. return {"status": status, "fingerprint": fp}
  140. def fp_daemon(entry: dict, defaults: dict) -> dict:
  141. do_hash = entry.get("hash", defaults.get("hash", False))
  142. timeout = entry.get("daemon_timeout", defaults.get("daemon_timeout", 3))
  143. url = entry["server_url"].rstrip("/") + "/v1/models"
  144. fp: dict = {"server_url": entry["server_url"]}
  145. status = "ok"
  146. try:
  147. req = urllib.request.Request(url, headers={"Accept": "application/json"})
  148. with urllib.request.urlopen(req, timeout=timeout) as resp:
  149. data = json.loads(resp.read().decode())
  150. served = [m.get("id") for m in data.get("data", [])]
  151. fp["reachable"] = True
  152. fp["served_models"] = served
  153. expect = entry.get("served_model")
  154. if expect and expect not in served:
  155. fp["served_mismatch"] = {"expect": expect, "actual": served}
  156. status = "changed"
  157. except Exception as e:
  158. fp["reachable"] = False
  159. fp["error"] = str(e)
  160. status = "unreachable"
  161. # 本地 GGUF 资产指纹(即使服务不可达也采集,便于发现文件被换)
  162. assets = entry.get("assets") or []
  163. if assets:
  164. fp["assets"] = {}
  165. for a in assets:
  166. p = Path(os.path.expanduser(a))
  167. afp = _file_fp(p, do_hash)
  168. fp["assets"][a] = afp
  169. if not afp.get("exists"):
  170. status = "missing" if status == "ok" else status
  171. return {"status": status, "fingerprint": fp}
  172. def fp_mineru(entry: dict, defaults: dict) -> dict:
  173. do_hash = entry.get("hash", defaults.get("hash", False))
  174. pkg = entry.get("package", "mineru")
  175. fp: dict = {"package": pkg}
  176. status = "ok"
  177. try:
  178. import importlib.metadata as md
  179. fp["package_version"] = md.version(pkg)
  180. except Exception as e:
  181. fp["package_error"] = str(e)
  182. status = "missing"
  183. root = entry.get("model_root")
  184. if root:
  185. p = Path(os.path.expanduser(root))
  186. rfp = _dir_fp(p, do_hash)
  187. fp["model_root"] = str(p)
  188. fp["model_root_fp"] = rfp
  189. if not rfp.get("exists"):
  190. status = "missing" if status == "ok" else status
  191. return {"status": status, "fingerprint": fp}
  192. _COLLECTORS = {
  193. "local_file": fp_local_file,
  194. "hf": fp_hf,
  195. "daemon": fp_daemon,
  196. "mineru": fp_mineru,
  197. }
  198. # --------------------------------------------------------------------------- #
  199. # registry / lock 读写与比对
  200. # --------------------------------------------------------------------------- #
  201. def load_registry(path: Path) -> dict:
  202. with path.open("r", encoding="utf-8") as f:
  203. reg = yaml.safe_load(f)
  204. reg.setdefault("defaults", {})
  205. reg.setdefault("models", [])
  206. return reg
  207. def collect(reg: dict, online: bool, do_hash: bool) -> dict:
  208. defaults = dict(reg.get("defaults", {}))
  209. if online:
  210. defaults["online"] = True
  211. if do_hash:
  212. defaults["hash"] = True
  213. snapshot = {}
  214. for entry in reg.get("models", []):
  215. name = entry["name"]
  216. if not entry.get("enabled", True):
  217. snapshot[name] = {"kind": entry.get("kind"), "status": "skipped", "fingerprint": {}}
  218. continue
  219. kind = entry.get("kind")
  220. collector = _COLLECTORS.get(kind)
  221. if collector is None:
  222. snapshot[name] = {"kind": kind, "status": "missing",
  223. "fingerprint": {"error": f"未知 kind: {kind}"}}
  224. continue
  225. try:
  226. result = collector(entry, defaults)
  227. except Exception as e:
  228. result = {"status": "missing", "fingerprint": {"error": str(e)}}
  229. result["kind"] = kind
  230. result["used_by"] = entry.get("used_by", [])
  231. snapshot[name] = result
  232. return snapshot
  233. def load_lock(path: Path) -> dict:
  234. if not path.exists():
  235. return {}
  236. with path.open("r", encoding="utf-8") as f:
  237. return json.load(f).get("models", {})
  238. def save_lock(path: Path, snapshot: dict) -> None:
  239. payload = {
  240. "generated_at": datetime.now().astimezone().isoformat(timespec="seconds"),
  241. "models": snapshot,
  242. }
  243. with path.open("w", encoding="utf-8") as f:
  244. json.dump(payload, f, ensure_ascii=False, indent=2)
  245. def diff_fp(old: dict, new: dict) -> list:
  246. """返回发生变化的字段路径(浅层 + 一层嵌套)。"""
  247. changes = []
  248. keys = set(old.keys()) | set(new.keys())
  249. for k in sorted(keys):
  250. ov, nv = old.get(k), new.get(k)
  251. if isinstance(ov, dict) and isinstance(nv, dict):
  252. for sk in sorted(set(ov) | set(nv)):
  253. if ov.get(sk) != nv.get(sk):
  254. changes.append(f"{k}.{sk}: {ov.get(sk)} → {nv.get(sk)}")
  255. elif ov != nv:
  256. changes.append(f"{k}: {ov} → {nv}")
  257. return changes
  258. # --------------------------------------------------------------------------- #
  259. # 子命令
  260. # --------------------------------------------------------------------------- #
  261. def cmd_list(reg: dict) -> int:
  262. print(f"模型清单({len(reg.get('models', []))} 条):\n")
  263. for e in reg.get("models", []):
  264. flag = " " if e.get("enabled", True) else "× "
  265. used = ";".join(e.get("used_by", []))
  266. print(f"{flag}[{e.get('kind'):<10}] {e['name']}")
  267. if used:
  268. print(f" 用于:{used}")
  269. print("\n(× 表示 enabled: false,体检时跳过)")
  270. return 0
  271. def cmd_show(reg: dict, online: bool, do_hash: bool) -> int:
  272. snap = collect(reg, online, do_hash)
  273. print(json.dumps(snap, ensure_ascii=False, indent=2))
  274. return 0
  275. def cmd_check(reg: dict, lock_path: Path, online: bool, do_hash: bool, strict: bool) -> int:
  276. new = collect(reg, online, do_hash)
  277. old = load_lock(lock_path)
  278. has_baseline = bool(old)
  279. problems = 0 # missing / unreachable / remote_update
  280. changes = 0 # 指纹变化
  281. news = 0 # 新增(lock 中无记录)
  282. print(f"模型体检报告 baseline={'有' if has_baseline else '无(首次,请先 update-lock)'}\n")
  283. for name, cur in new.items():
  284. kind = cur.get("kind")
  285. status = cur.get("status")
  286. if status == "skipped":
  287. print(f"{SYM['skipped']} {name:<28} [{kind}] 跳过(disabled)")
  288. continue
  289. prev = old.get(name)
  290. if prev is None:
  291. news += 1
  292. print(f"{SYM['new']} {name:<28} [{kind}] 新增条目(基线中无记录)")
  293. continue
  294. # 状态类问题优先
  295. if status in ("missing", "unreachable"):
  296. problems += 1
  297. detail = cur["fingerprint"].get("error", "")
  298. print(f"{SYM[status]} {name:<28} [{kind}] {status} {detail}")
  299. continue
  300. if status == "remote_update":
  301. problems += 1
  302. fpr = cur["fingerprint"]
  303. print(f"{SYM['remote_update']} {name:<28} [{kind}] HF 远端有更新:"
  304. f"{fpr.get('local_revision')} → {fpr.get('remote_revision')}")
  305. continue
  306. # 指纹比对
  307. fp_changes = diff_fp(prev.get("fingerprint", {}), cur.get("fingerprint", {}))
  308. if fp_changes:
  309. changes += 1
  310. print(f"{SYM['changed']} {name:<28} [{kind}] 指纹变化:")
  311. for c in fp_changes:
  312. print(f" - {c}")
  313. else:
  314. print(f"{SYM['ok']} {name:<28} [{kind}] 未变化")
  315. # 基线中存在但 registry 已删除
  316. removed = [n for n in old if n not in new]
  317. for n in removed:
  318. print(f"{SYM['removed']} {n:<28} 基线中存在但 registry 已移除")
  319. print("\n" + "-" * 60)
  320. print(f"问题(缺失/不可达/远端更新)={problems} 指纹变化={changes} "
  321. f"新增={news} 移除={len(removed)}")
  322. if not has_baseline:
  323. print("提示:尚无基线,运行 `update-lock` 生成。")
  324. return 1
  325. fail = problems + (changes if strict else 0) + news
  326. if problems and not strict:
  327. # 即使非 strict,缺失/不可达也应非 0 退出
  328. return 1
  329. return 1 if fail else 0
  330. def cmd_update_lock(reg: dict, lock_path: Path, online: bool, do_hash: bool) -> int:
  331. snap = collect(reg, online, do_hash)
  332. save_lock(lock_path, snap)
  333. enabled = [n for n, v in snap.items() if v.get("status") != "skipped"]
  334. print(f"已写入基线 {lock_path}({len(enabled)} 条生效,"
  335. f"{len(snap) - len(enabled)} 条跳过)")
  336. return 0
  337. # --------------------------------------------------------------------------- #
  338. # 入口
  339. # --------------------------------------------------------------------------- #
  340. def main(argv=None) -> int:
  341. parser = argparse.ArgumentParser(description="模型变更巡检工具 model_doctor")
  342. parser.add_argument("command", choices=["list", "show", "check", "update-lock"])
  343. parser.add_argument("--registry", type=Path, default=DEFAULT_REGISTRY,
  344. help=f"清单文件,默认 {DEFAULT_REGISTRY.name}")
  345. parser.add_argument("--lock", type=Path, default=DEFAULT_LOCK,
  346. help=f"基线文件,默认 {DEFAULT_LOCK.name}")
  347. parser.add_argument("--online", action="store_true",
  348. help="hf 条目额外查询远端最新 commit 进行比对")
  349. parser.add_argument("--hash", dest="do_hash", action="store_true",
  350. help="本地文件/目录额外计算快速 sha256(更敏感但更慢)")
  351. parser.add_argument("--strict", action="store_true",
  352. help="check 时把『指纹变化』也视为失败(非 0 退出)")
  353. args = parser.parse_args(argv)
  354. reg = load_registry(args.registry)
  355. if args.command == "list":
  356. return cmd_list(reg)
  357. if args.command == "show":
  358. return cmd_show(reg, args.online, args.do_hash)
  359. if args.command == "check":
  360. return cmd_check(reg, args.lock, args.online, args.do_hash, args.strict)
  361. if args.command == "update-lock":
  362. return cmd_update_lock(reg, args.lock, args.online, args.do_hash)
  363. return 2
  364. if __name__ == "__main__":
  365. sys.exit(main())