client.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #!/usr/bin/env python
  2. import argparse
  3. import pprint
  4. import sys
  5. from paddlex_hps_client import triton_request, utils
  6. from tritonclient import grpc as triton_grpc
  7. OUTPUT_IMAGE_PATH = "out.jpg"
  8. def parse_image_label_pairs(image_label_pairs):
  9. if len(image_label_pairs) % 2 != 0:
  10. raise ValueError("The number of image-label pairs must be even.")
  11. return [
  12. {"image": utils.prepare_input_file(img), "label": lab}
  13. for img, lab in zip(image_label_pairs[0::2], image_label_pairs[1::2])
  14. ]
  15. def create_triton_client(url):
  16. return triton_grpc.InferenceServerClient(url)
  17. def ensure_no_error(output):
  18. if output["errorCode"] != 0:
  19. print(f"Error code: {output['errorCode']}", file=sys.stderr)
  20. print(f"Error message: {output['errorMsg']}", file=sys.stderr)
  21. sys.exit(1)
  22. def do_index_build(args):
  23. client = create_triton_client(args.url)
  24. if args.image_label_pairs:
  25. image_label_pairs = parse_image_label_pairs(args.image_label_pairs)
  26. else:
  27. image_label_pairs = []
  28. input_ = {"imageLabelPairs": image_label_pairs}
  29. output = triton_request(client, "shitu-index-build", input_)
  30. ensure_no_error(output)
  31. result = output["result"]
  32. pprint.pp(result)
  33. def do_index_add(args):
  34. client = create_triton_client(args.url)
  35. image_label_pairs = parse_image_label_pairs(args.image_label_pairs)
  36. input_ = {"imageLabelPairs": image_label_pairs}
  37. if args.index_key is not None:
  38. input_["indexKey"] = args.index_key
  39. output = triton_request(client, "shitu-index-add", input_)
  40. ensure_no_error(output)
  41. result = output["result"]
  42. pprint.pp(result)
  43. def do_index_remove(args):
  44. client = create_triton_client(args.url)
  45. input_ = {"ids": args.ids}
  46. if args.index_key is not None:
  47. input_["indexKey"] = args.index_key
  48. output = triton_request(client, "shitu-index-remove", input_)
  49. ensure_no_error(output)
  50. result = output["result"]
  51. pprint.pp(result)
  52. def do_infer(args):
  53. client = create_triton_client(args.url)
  54. input_ = {"image": utils.prepare_input_file(args.image)}
  55. if args.index_key is not None:
  56. input_["indexKey"] = args.index_key
  57. if args.no_visualization:
  58. input_["visualize"] = False
  59. output = triton_request(client, "shitu-infer", input_)
  60. ensure_no_error(output)
  61. result = output["result"]
  62. utils.save_output_file(result["image"], OUTPUT_IMAGE_PATH)
  63. print(f"Output image saved at {OUTPUT_IMAGE_PATH}")
  64. print("\nDetected objects:")
  65. pprint.pp(result["detectedObjects"])
  66. def main():
  67. parser = argparse.ArgumentParser()
  68. parser.add_argument("--url", type=str, default="localhost:8001")
  69. subparsers = parser.add_subparsers(dest="cmd")
  70. parser_index_build = subparsers.add_parser("index-build")
  71. parser_index_build.add_argument("--image-label-pairs", type=str, nargs="+")
  72. parser_index_build.set_defaults(func=do_index_build)
  73. parser_index_add = subparsers.add_parser("index-add")
  74. parser_index_add.add_argument(
  75. "--image-label-pairs", type=str, nargs="+", required=True
  76. )
  77. parser_index_add.add_argument("--index-key", type=str, required=True)
  78. parser_index_add.set_defaults(func=do_index_add)
  79. parser_index_remove = subparsers.add_parser("index-remove")
  80. parser_index_remove.add_argument("--ids", type=int, nargs="+", required=True)
  81. parser_index_remove.add_argument("--index-key", type=str, required=True)
  82. parser_index_remove.set_defaults(func=do_index_remove)
  83. parser_infer = subparsers.add_parser("infer")
  84. parser_infer.add_argument("--image", type=str, required=True)
  85. parser_infer.add_argument("--index-key", type=str)
  86. parser.add_argument("--no-visualization", action="store_true")
  87. parser_infer.set_defaults(func=do_infer)
  88. args = parser.parse_args()
  89. args.func(args)
  90. if __name__ == "__main__":
  91. main()