| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134 |
- #!/usr/bin/env python
- # Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import argparse
- import pprint
- import sys
- from paddlex_hps_client import triton_request, utils
- from tritonclient import grpc as triton_grpc
# Local file path where the `infer` subcommand writes the result image
# returned by the server.
OUTPUT_IMAGE_PATH: str = "out.jpg"
def parse_image_label_pairs(image_label_pairs):
    """Turn a flat ``[image, label, image, label, ...]`` list into request items.

    Args:
        image_label_pairs: Flat sequence of command-line values in which every
            image argument is immediately followed by its label.

    Returns:
        A list of ``{"image": ..., "label": ...}`` dicts, one per pair, where
        each image has been passed through ``utils.prepare_input_file``.
        An empty input yields an empty list.

    Raises:
        ValueError: If the sequence holds an odd number of values, i.e. the
            last image has no matching label.
    """
    num_values = len(image_label_pairs)
    # Every image must be followed by exactly one label, so the count of
    # individual values (not of pairs) has to be even.
    if num_values % 2 != 0:
        raise ValueError(
            "Expected an even number of values (IMAGE LABEL [IMAGE LABEL ...]), "
            f"got {num_values}."
        )
    images = image_label_pairs[0::2]
    labels = image_label_pairs[1::2]
    return [
        {"image": utils.prepare_input_file(image), "label": label}
        for image, label in zip(images, labels)
    ]
def create_triton_client(url):
    """Return a Triton gRPC inference client targeting ``url``.

    ``url`` is handed unchanged to ``tritonclient.grpc.InferenceServerClient``
    (for example ``"localhost:8001"``, the CLI's default ``--url``).
    """
    client = triton_grpc.InferenceServerClient(url=url)
    return client
def ensure_no_error(output):
    """Terminate the process if a pipeline response reports a failure.

    Args:
        output: Response dict as returned by ``triton_request``. It must have
            an ``"errorCode"`` key, and an ``"errorMsg"`` key whenever the code
            is non-zero.

    A zero ``errorCode`` returns ``None`` silently. Any other code prints the
    code and message to stderr and exits with status 1.
    """
    error_code = output["errorCode"]
    if error_code == 0:
        return
    print(f"Error code: {error_code}", file=sys.stderr)
    print(f"Error message: {output['errorMsg']}", file=sys.stderr)
    sys.exit(1)
def do_index_build(args):
    """Handle the ``index-build`` subcommand.

    Reads ``args.url`` and the optional flat ``args.image_label_pairs`` list.
    When that list is omitted or empty, an empty ``imageLabelPairs`` list is
    sent. Calls ``face-recognition-index-build``, exits on a reported error,
    and pretty-prints the server's ``result`` payload.
    """
    client = create_triton_client(args.url)
    raw_pairs = args.image_label_pairs
    pairs = parse_image_label_pairs(raw_pairs) if raw_pairs else []
    output = triton_request(
        client, "face-recognition-index-build", {"imageLabelPairs": pairs}
    )
    ensure_no_error(output)
    pprint.pp(output["result"])
def do_index_add(args):
    """Handle the ``index-add`` subcommand.

    Parses the flat ``args.image_label_pairs`` list and sends it, together
    with ``args.index_key`` when that is set, to
    ``face-recognition-index-add``. Exits on a reported error; otherwise
    pretty-prints the server's ``result`` payload.
    """
    client = create_triton_client(args.url)
    pairs = parse_image_label_pairs(args.image_label_pairs)
    request = {"imageLabelPairs": pairs}
    index_key = args.index_key
    if index_key is not None:
        request["indexKey"] = index_key
    output = triton_request(client, "face-recognition-index-add", request)
    ensure_no_error(output)
    pprint.pp(output["result"])
def do_index_remove(args):
    """Handle the ``index-remove`` subcommand.

    Sends the integer ``args.ids``, plus ``args.index_key`` when it is set,
    to ``face-recognition-index-remove``. Exits on a reported error;
    otherwise pretty-prints the server's ``result`` payload.
    """
    client = create_triton_client(args.url)
    request = {"ids": args.ids}
    index_key = args.index_key
    if index_key is not None:
        request["indexKey"] = index_key
    output = triton_request(client, "face-recognition-index-remove", request)
    ensure_no_error(output)
    pprint.pp(output["result"])
def do_infer(args):
    """Handle the ``infer`` subcommand: run face recognition on one image.

    Sends ``args.image`` (prepared with ``utils.prepare_input_file``) and, when
    set, ``args.index_key`` to ``face-recognition-infer``. If
    ``args.no_visualization`` is true, the request carries
    ``"visualize": False``. Exits on a reported error.

    When the response contains a result image, it is saved to
    ``OUTPUT_IMAGE_PATH``. The detected faces are always pretty-printed.
    """
    client = create_triton_client(args.url)
    input_ = {"image": utils.prepare_input_file(args.image)}
    if args.index_key is not None:
        input_["indexKey"] = args.index_key
    if args.no_visualization:
        input_["visualize"] = False
    output = triton_request(client, "face-recognition-infer", input_)
    ensure_no_error(output)
    result = output["result"]
    # We may have just asked the server not to visualize, in which case it
    # presumably returns no image (missing or null).
    # NOTE(review): confirm against the serving schema.
    # Only save when an image is actually present, rather than handing a
    # missing/None value to save_output_file.
    image = result.get("image")
    if image is not None:
        utils.save_output_file(image, OUTPUT_IMAGE_PATH)
        print(f"Output image saved at {OUTPUT_IMAGE_PATH}")
    print("\nDetected faces:")
    pprint.pp(result["faces"])
def main():
    """Parse command-line arguments and dispatch to the chosen subcommand.

    Subcommands are ``index-build``, ``index-add``, ``index-remove`` and
    ``infer``. Each one stores its handler in ``func``, which is called with
    the parsed namespace. ``--url`` (default ``localhost:8001``) belongs to
    the top-level parser, so it must appear before the subcommand name.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--url", type=str, default="localhost:8001")
    # A subcommand is mandatory. Without one, ``args.func`` would be unset and
    # the dispatch below would crash with AttributeError instead of argparse
    # printing a usage error.
    subparsers = parser.add_subparsers(dest="cmd", required=True)

    parser_index_build = subparsers.add_parser("index-build")
    parser_index_build.add_argument("--image-label-pairs", type=str, nargs="+")
    parser_index_build.set_defaults(func=do_index_build)

    parser_index_add = subparsers.add_parser("index-add")
    parser_index_add.add_argument(
        "--image-label-pairs", type=str, nargs="+", required=True
    )
    parser_index_add.add_argument("--index-key", type=str, required=True)
    parser_index_add.set_defaults(func=do_index_add)

    parser_index_remove = subparsers.add_parser("index-remove")
    parser_index_remove.add_argument("--ids", type=int, nargs="+", required=True)
    parser_index_remove.add_argument("--index-key", type=str, required=True)
    parser_index_remove.set_defaults(func=do_index_remove)

    parser_infer = subparsers.add_parser("infer")
    parser_infer.add_argument("--image", type=str, required=True)
    parser_infer.add_argument("--index-key", type=str)
    # Only the infer handler reads this flag, so register it on the infer
    # subparser. When it lived on the top-level parser,
    # `infer ... --no-visualization` was rejected as an unrecognized argument.
    parser_infer.add_argument("--no-visualization", action="store_true")
    parser_infer.set_defaults(func=do_infer)

    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
|