command.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. from six import text_type as _text_type
  2. import argparse
  3. import sys
  4. def arg_parser():
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument(
  7. "--model_dir",
  8. "-m",
  9. type=_text_type,
  10. default=None,
  11. help="define model directory path")
  12. parser.add_argument(
  13. "--save_dir",
  14. "-s",
  15. type=_text_type,
  16. default=None,
  17. help="path to save inference model")
  18. parser.add_argument(
  19. "--version",
  20. "-v",
  21. action="store_true",
  22. default=False,
  23. help="get version of PaddleX")
  24. parser.add_argument(
  25. "--export_inference",
  26. "-e",
  27. action="store_true",
  28. default=False,
  29. help="export inference model for C++/Python deployment")
  30. parser.add_argument(
  31. "--fixed_input_shape",
  32. "-fs",
  33. default=None,
  34. help="export inference model with fixed input shape:[w,h]")
  35. return parser
  36. def main():
  37. import os
  38. os.environ['CUDA_VISIBLE_DEVICES'] = ""
  39. import paddlex as pdx
  40. if len(sys.argv) < 2:
  41. print("Use command 'paddlex -h` to print the help information\n")
  42. return
  43. parser = arg_parser()
  44. args = parser.parse_args()
  45. if args.version:
  46. print("PaddleX-{}".format(pdx.__version__))
  47. print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
  48. print("Email: paddlex@baidu.com")
  49. return
  50. if args.export_inference:
  51. assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
  52. assert args.save_dir is not None, "--save_dir should be defined to save inference model"
  53. fixed_input_shape = eval(args.fixed_input_shape)
  54. assert len(
  55. fixed_input_shape) == 2, "len of fixed input shape must == 2"
  56. model = pdx.load_model(args.model_dir, fixed_input_shape)
  57. model.export_inference_model(args.save_dir)
  58. if args.export_onnx:
  59. assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
  60. assert args.save_dir is not None, "--save_dir should be defined to save onnx model"
  61. fixed_input_shape = eval(args.fixed_input_shape)
  62. assert len(
  63. fixed_input_shape) == 2, "len of fixed input shape must == 2"
  64. model = pdx.load_model(args.model_dir, fixed_input_shape)
  65. model.export_onnx_model(args.save_dir)
  66. if __name__ == "__main__":
  67. main()