command.py

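"""Command-line entry point for PaddleX: prints the installed version and
exports a trained model to inference format or ONNX for deployment."""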
from six import text_type as _text_type
import argparse
import sys


def arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        "-m",
        type=_text_type,
        default=None,
        help="define model directory path")
    parser.add_argument(
        "--save_dir",
        "-s",
        type=_text_type,
        default=None,
        help="path to save inference model")
    parser.add_argument(
        "--version",
        "-v",
        action="store_true",
        default=False,
        help="get version of PaddleX")
    parser.add_argument(
        "--export_inference",
        "-e",
        action="store_true",
        default=False,
        help="export inference model for C++/Python deployment")
    parser.add_argument(
        "--export_onnx",
        "-eo",
        action="store_true",
        default=False,
        help="export onnx model for deployment")
    parser.add_argument(
        "--fixed_input_shape",
        "-fs",
        default=None,
        help="export inference model with fixed input shape: [w,h]")
    return parser
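

# Example invocations (illustrative paths; `paddlex` is assumed to be the
# console-script name that dispatches to main() below, as suggested by the
# help message printed when no arguments are given):
#   paddlex --version
#   paddlex --export_inference -m ./output/model -s ./inference_model
#   paddlex --export_onnx -m ./output/model -s ./onnx_model -fs "[224,224]"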


def main():
    import os
    # Hide GPUs before importing paddlex so model export runs on CPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = ""
    import paddlex as pdx
    if len(sys.argv) < 2:
        print("Use command 'paddlex -h' to print the help information\n")
        return
    parser = arg_parser()
    args = parser.parse_args()
    if args.version:
        print("PaddleX-{}".format(pdx.__version__))
        print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
        print("Email: paddlex@baidu.com")
        return
    if args.export_inference:
        assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
        assert args.save_dir is not None, "--save_dir should be defined to save inference model"
        fixed_input_shape = None
        if args.fixed_input_shape is not None:
            fixed_input_shape = eval(args.fixed_input_shape)
            assert len(fixed_input_shape) == 2, \
                "len of fixed input shape must == 2, such as [224,224]"
        model = pdx.load_model(args.model_dir, fixed_input_shape)
        model.export_inference_model(args.save_dir)
    if args.export_onnx:
        assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
        assert args.save_dir is not None, "--save_dir should be defined to create onnx model"
        assert args.fixed_input_shape is not None, "--fixed_input_shape should be defined [w,h] to create onnx model, such as [224,224]"
        fixed_input_shape = []
        if args.fixed_input_shape is not None:
            fixed_input_shape = eval(args.fixed_input_shape)
            assert len(fixed_input_shape) == 2, \
                "len of fixed input shape must == 2, such as [224,224]"
        model = pdx.load_model(args.model_dir, fixed_input_shape)
        pdx.convertor.export_onnx_model(model, args.save_dir)


if __name__ == "__main__":
    main()