assemble.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
#!/usr/bin/env python
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import argparse
  16. import json
  17. import pathlib
  18. import shutil
  19. import subprocess
  20. import sys
  21. import tarfile
  22. import tempfile
  23. TARGET_NAME_PATTERN = "paddlex_hps_{pipeline_name}_sdk"
  24. ARCHIVE_SUFFIX = ".tar.gz"
  25. BASE_DIR = pathlib.Path.cwd()
  26. PIPELINES_DIR = BASE_DIR / "pipelines"
  27. COMMON_DIR = BASE_DIR / "common"
  28. CLIENT_LIB_PATH = BASE_DIR / "paddlex-hps-client"
  29. VERSIONS_PATH = BASE_DIR / "versions.json"
  30. OUTPUT_DIR = BASE_DIR / "output"
  31. if __name__ == "__main__":
  32. parser = argparse.ArgumentParser()
  33. parser.add_argument("pipeline_names", type=str, metavar="pipeline-names", nargs="*")
  34. parser.add_argument("--all", action="store_true")
  35. parser.add_argument(
  36. "--no-server",
  37. action="store_true",
  38. )
  39. parser.add_argument(
  40. "--no-client",
  41. action="store_true",
  42. )
  43. args = parser.parse_args()
  44. if args.all and args.pipeline_names:
  45. print(
  46. "Cannot specify `--all` and `pipeline-names` at the same time",
  47. file=sys.stderr,
  48. )
  49. sys.exit(2)
  50. if args.all:
  51. pipeline_names = [p.name for p in PIPELINES_DIR.iterdir()]
  52. else:
  53. pipeline_names = args.pipeline_names
  54. if not pipeline_names:
  55. sys.exit(0)
  56. with_server = not args.no_server
  57. with_client = not args.no_client
  58. OUTPUT_DIR.mkdir(exist_ok=True)
  59. if with_client:
  60. # HACK: Make a copy to avoid creating files in the source directory
  61. with tempfile.TemporaryDirectory() as td:
  62. tmp_client_lib_path = shutil.copytree(
  63. CLIENT_LIB_PATH, str(pathlib.Path(td, CLIENT_LIB_PATH.name))
  64. )
  65. subprocess.check_call(
  66. [
  67. sys.executable,
  68. "-m",
  69. "pip",
  70. "wheel",
  71. "--no-deps",
  72. "--wheel-dir",
  73. str(OUTPUT_DIR),
  74. tmp_client_lib_path,
  75. ]
  76. )
  77. client_lib_whl_path = next(OUTPUT_DIR.glob("paddlex_hps_client*.whl"))
  78. with VERSIONS_PATH.open("r", encoding="utf-8") as f:
  79. versions = json.load(f)
  80. for pipeline_name in pipeline_names:
  81. print("=" * 30)
  82. print(f"Pipeline: {pipeline_name}")
  83. pipeline_dir = PIPELINES_DIR / pipeline_name
  84. if not pipeline_dir.exists():
  85. sys.exit(f"{pipeline_dir} not found")
  86. if pipeline_name not in versions:
  87. sys.exit(f"Version is missing for {repr(pipeline_name)}")
  88. tgt_name = TARGET_NAME_PATTERN.format(pipeline_name=pipeline_name)
  89. tgt_dir = OUTPUT_DIR / tgt_name
  90. if tgt_dir.exists():
  91. print(f"Removing existing target directory: {tgt_dir}")
  92. shutil.rmtree(tgt_dir)
  93. if with_server:
  94. shutil.copytree(pipeline_dir / "server", tgt_dir / "server")
  95. shutil.copy(COMMON_DIR / "server.sh", tgt_dir / "server")
  96. for dir_ in (tgt_dir / "server" / "model_repo").iterdir():
  97. if dir_.is_dir():
  98. if (dir_ / "config.pbtxt").exists():
  99. continue
  100. for device_type in ("cpu", "gpu"):
  101. config_path = dir_ / f"config_{device_type}.pbtxt"
  102. if not config_path.exists():
  103. shutil.copy(
  104. COMMON_DIR / f"config_{device_type}.pbtxt", config_path
  105. )
  106. if with_client:
  107. shutil.copytree(pipeline_dir / "client", tgt_dir / "client")
  108. shutil.copy(client_lib_whl_path, tgt_dir / "client")
  109. version = versions[pipeline_name]
  110. (tgt_dir / "version.txt").write_text(version + "\n", encoding="utf-8")
  111. arch_path = tgt_dir.with_suffix(ARCHIVE_SUFFIX)
  112. print(f"Creating archive: {arch_path}")
  113. with tarfile.open(arch_path, "w:gz") as tar:
  114. tar.add(tgt_dir, arcname=tgt_dir.name)
  115. print("Done" + "\n" + "=" * 30)