client.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. #!/usr/bin/env python
  2. # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import pprint
  17. import sys
  18. from paddlex_hps_client import triton_request, utils
  19. from tritonclient import grpc as triton_grpc
# Path where `do_infer` saves the image returned in the inference result.
OUTPUT_IMAGE_PATH = "out.jpg"
  21. def parse_image_label_pairs(image_label_pairs):
  22. if len(image_label_pairs) % 2 != 0:
  23. raise ValueError("The number of image-label pairs must be even.")
  24. return [
  25. {"image": utils.prepare_input_file(img), "label": lab}
  26. for img, lab in zip(image_label_pairs[0::2], image_label_pairs[1::2])
  27. ]
  28. def create_triton_client(url):
  29. return triton_grpc.InferenceServerClient(url)
  30. def ensure_no_error(output):
  31. if output["errorCode"] != 0:
  32. print(f"Error code: {output['errorCode']}", file=sys.stderr)
  33. print(f"Error message: {output['errorMsg']}", file=sys.stderr)
  34. sys.exit(1)
  35. def do_index_build(args):
  36. client = create_triton_client(args.url)
  37. if args.image_label_pairs:
  38. image_label_pairs = parse_image_label_pairs(args.image_label_pairs)
  39. else:
  40. image_label_pairs = []
  41. input_ = {"imageLabelPairs": image_label_pairs}
  42. output = triton_request(client, "face-recognition-index-build", input_)
  43. ensure_no_error(output)
  44. result = output["result"]
  45. pprint.pp(result)
  46. def do_index_add(args):
  47. client = create_triton_client(args.url)
  48. image_label_pairs = parse_image_label_pairs(args.image_label_pairs)
  49. input_ = {"imageLabelPairs": image_label_pairs}
  50. if args.index_key is not None:
  51. input_["indexKey"] = args.index_key
  52. output = triton_request(client, "face-recognition-index-add", input_)
  53. ensure_no_error(output)
  54. result = output["result"]
  55. pprint.pp(result)
  56. def do_index_remove(args):
  57. client = create_triton_client(args.url)
  58. input_ = {"ids": args.ids}
  59. if args.index_key is not None:
  60. input_["indexKey"] = args.index_key
  61. output = triton_request(client, "face-recognition-index-remove", input_)
  62. ensure_no_error(output)
  63. result = output["result"]
  64. pprint.pp(result)
  65. def do_infer(args):
  66. client = create_triton_client(args.url)
  67. input_ = {"image": utils.prepare_input_file(args.image)}
  68. if args.index_key is not None:
  69. input_["indexKey"] = args.index_key
  70. if args.no_visualization:
  71. input_["visualize"] = False
  72. output = triton_request(client, "face-recognition-infer", input_)
  73. ensure_no_error(output)
  74. result = output["result"]
  75. utils.save_output_file(result["image"], OUTPUT_IMAGE_PATH)
  76. print(f"Output image saved at {OUTPUT_IMAGE_PATH}")
  77. print("\nDetected faces:")
  78. pprint.pp(result["faces"])
  79. def main():
  80. parser = argparse.ArgumentParser()
  81. parser.add_argument("--url", type=str, default="localhost:8001")
  82. subparsers = parser.add_subparsers(dest="cmd")
  83. parser_index_build = subparsers.add_parser("index-build")
  84. parser_index_build.add_argument("--image-label-pairs", type=str, nargs="+")
  85. parser_index_build.set_defaults(func=do_index_build)
  86. parser_index_add = subparsers.add_parser("index-add")
  87. parser_index_add.add_argument(
  88. "--image-label-pairs", type=str, nargs="+", required=True
  89. )
  90. parser_index_add.add_argument("--index-key", type=str, required=True)
  91. parser_index_add.set_defaults(func=do_index_add)
  92. parser_index_remove = subparsers.add_parser("index-remove")
  93. parser_index_remove.add_argument("--ids", type=int, nargs="+", required=True)
  94. parser_index_remove.add_argument("--index-key", type=str, required=True)
  95. parser_index_remove.set_defaults(func=do_index_remove)
  96. parser_infer = subparsers.add_parser("infer")
  97. parser_infer.add_argument("--image", type=str, required=True)
  98. parser_infer.add_argument("--index-key", type=str)
  99. parser.add_argument("--no-visualization", action="store_true")
  100. parser_infer.set_defaults(func=do_infer)
  101. args = parser.parse_args()
  102. args.func(args)
# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()