command.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. from six import text_type as _text_type
  2. import argparse
  3. import sys
  4. def arg_parser():
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument(
  7. "--model_dir",
  8. "-m",
  9. type=_text_type,
  10. default=None,
  11. help="define model directory path")
  12. parser.add_argument(
  13. "--save_dir",
  14. "-s",
  15. type=_text_type,
  16. default=None,
  17. help="path to save inference model")
  18. parser.add_argument(
  19. "--version",
  20. "-v",
  21. action="store_true",
  22. default=False,
  23. help="get version of PaddleX")
  24. parser.add_argument(
  25. "--export_inference",
  26. "-e",
  27. action="store_true",
  28. default=False,
  29. help="export inference model for C++/Python deployment")
  30. parser.add_argument(
  31. "--fixed_input_shape",
  32. "-fs",
  33. default=None,
  34. help="export inference model with fixed input shape(TensorRT need)")
  35. return parser
  36. def main():
  37. import os
  38. os.environ['CUDA_VISIBLE_DEVICES'] = ""
  39. import paddlex as pdx
  40. if len(sys.argv) < 2:
  41. print("Use command 'paddlex -h` to print the help information\n")
  42. return
  43. parser = arg_parser()
  44. args = parser.parse_args()
  45. if args.version:
  46. print("PaddleX-{}".format(pdx.__version__))
  47. print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
  48. print("Email: paddlex@baidu.com")
  49. return
  50. if args.export_inference:
  51. assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
  52. assert args.save_dir is not None, "--save_dir should be defined to save inference model"
  53. fixed_input_shape = eval(args.fixed_input_shape)
  54. assert len(fixed_input_shape) == 2, "len of fixed input shape must == 2"
  55. model = pdx.load_model(args.model_dir, fixed_input_shape)
  56. model.export_inference_model(args.save_dir, fixed_input_shape)
  57. if __name__ == "__main__":
  58. main()