assemble.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. #!/usr/bin/env python
  2. import argparse
  3. import json
  4. import pathlib
  5. import shutil
  6. import subprocess
  7. import sys
  8. import tarfile
  9. import tempfile
  10. TARGET_NAME_PATTERN = "paddlex_hps_{pipeline_name}_sdk"
  11. ARCHIVE_SUFFIX = ".tar.gz"
  12. BASE_DIR = pathlib.Path.cwd()
  13. PIPELINES_DIR = BASE_DIR / "pipelines"
  14. COMMON_DIR = BASE_DIR / "common"
  15. CLIENT_LIB_PATH = BASE_DIR / "paddlex-hps-client"
  16. VERSIONS_PATH = BASE_DIR / "versions.json"
  17. OUTPUT_DIR = BASE_DIR / "output"
  18. if __name__ == "__main__":
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument("pipeline_names", type=str, metavar="pipeline-names", nargs="*")
  21. parser.add_argument("--all", action="store_true")
  22. parser.add_argument(
  23. "--no-server",
  24. action="store_true",
  25. )
  26. parser.add_argument(
  27. "--no-client",
  28. action="store_true",
  29. )
  30. args = parser.parse_args()
  31. if args.all and args.pipeline_names:
  32. print(
  33. "Cannot specify `--all` and `pipeline-names` at the same time",
  34. file=sys.stderr,
  35. )
  36. sys.exit(2)
  37. if args.all:
  38. pipeline_names = [p.name for p in PIPELINES_DIR.iterdir()]
  39. else:
  40. pipeline_names = args.pipeline_names
  41. if not pipeline_names:
  42. sys.exit(0)
  43. with_server = not args.no_server
  44. with_client = not args.no_client
  45. OUTPUT_DIR.mkdir(exist_ok=True)
  46. if with_client:
  47. # HACK: Make a copy to avoid creating files in the source directory
  48. with tempfile.TemporaryDirectory() as td:
  49. tmp_client_lib_path = shutil.copytree(
  50. CLIENT_LIB_PATH, str(pathlib.Path(td, CLIENT_LIB_PATH.name))
  51. )
  52. subprocess.check_call(
  53. [
  54. sys.executable,
  55. "-m",
  56. "pip",
  57. "wheel",
  58. "--no-deps",
  59. "--wheel-dir",
  60. str(OUTPUT_DIR),
  61. tmp_client_lib_path,
  62. ]
  63. )
  64. client_lib_whl_path = next(OUTPUT_DIR.glob("paddlex_hps_client*.whl"))
  65. with VERSIONS_PATH.open("r", encoding="utf-8") as f:
  66. versions = json.load(f)
  67. for pipeline_name in pipeline_names:
  68. print("=" * 30)
  69. print(f"Pipeline: {pipeline_name}")
  70. pipeline_dir = PIPELINES_DIR / pipeline_name
  71. if not pipeline_dir.exists():
  72. sys.exit(f"{pipeline_dir} not found")
  73. if pipeline_name not in versions:
  74. sys.exit(f"Version is missing for {repr(pipeline_name)}")
  75. tgt_name = TARGET_NAME_PATTERN.format(pipeline_name=pipeline_name)
  76. tgt_dir = OUTPUT_DIR / tgt_name
  77. if tgt_dir.exists():
  78. print(f"Removing existing target directory: {tgt_dir}")
  79. shutil.rmtree(tgt_dir)
  80. if with_server:
  81. shutil.copytree(pipeline_dir / "server", tgt_dir / "server")
  82. shutil.copy(COMMON_DIR / "server.sh", tgt_dir / "server")
  83. for dir_ in (tgt_dir / "server" / "model_repo").iterdir():
  84. if dir_.is_dir():
  85. if (dir_ / "config.pbtxt").exists():
  86. continue
  87. for device_type in ("cpu", "gpu"):
  88. config_path = dir_ / f"config_{device_type}.pbtxt"
  89. if not config_path.exists():
  90. shutil.copy(
  91. COMMON_DIR / f"config_{device_type}.pbtxt", config_path
  92. )
  93. if with_client:
  94. shutil.copytree(pipeline_dir / "client", tgt_dir / "client")
  95. shutil.copy(client_lib_whl_path, tgt_dir / "client")
  96. version = versions[pipeline_name]
  97. (tgt_dir / "version.txt").write_text(version + "\n", encoding="utf-8")
  98. arch_path = tgt_dir.with_suffix(ARCHIVE_SUFFIX)
  99. print(f"Creating archive: {arch_path}")
  100. with tarfile.open(arch_path, "w:gz") as tar:
  101. tar.add(tgt_dir, arcname=tgt_dir.name)
  102. print("Done" + "\n" + "=" * 30)