# command.py
import argparse
import ast
import sys

from six import text_type as _text_type
  4. def arg_parser():
  5. parser = argparse.ArgumentParser()
  6. parser.add_argument(
  7. "--model_dir",
  8. "-m",
  9. type=_text_type,
  10. default=None,
  11. help="define model directory path")
  12. parser.add_argument(
  13. "--save_dir",
  14. "-s",
  15. type=_text_type,
  16. default=None,
  17. help="path to save inference model")
  18. parser.add_argument(
  19. "--version",
  20. "-v",
  21. action="store_true",
  22. default=False,
  23. help="get version of PaddleX")
  24. parser.add_argument(
  25. "--export_inference",
  26. "-e",
  27. action="store_true",
  28. default=False,
  29. help="export inference model for C++/Python deployment")
  30. parser.add_argument(
  31. "--export_onnx",
  32. "-eo",
  33. action="store_true",
  34. default=False,
  35. help="export onnx model for deployment")
  36. parser.add_argument(
  37. "--fixed_input_shape",
  38. "-fs",
  39. default=None,
  40. help="export inference model with fixed input shape:[w,h]")
  41. return parser
  42. def main():
  43. import os
  44. os.environ['CUDA_VISIBLE_DEVICES'] = ""
  45. import paddlex as pdx
  46. if len(sys.argv) < 2:
  47. print("Use command 'paddlex -h` to print the help information\n")
  48. return
  49. parser = arg_parser()
  50. args = parser.parse_args()
  51. if args.version:
  52. print("PaddleX-{}".format(pdx.__version__))
  53. print("Repo: https://github.com/PaddlePaddle/PaddleX.git")
  54. print("Email: paddlex@baidu.com")
  55. return
  56. if args.export_inference:
  57. assert args.model_dir is not None, "--model_dir should be defined while exporting inference model"
  58. assert args.save_dir is not None, "--save_dir should be defined to save inference model"
  59. fixed_input_shape = None
  60. if args.fixed_input_shape is not None:
  61. fixed_input_shape = eval(args.fixed_input_shape)
  62. assert len(
  63. fixed_input_shape
  64. ) == 2, "len of fixed input shape must == 2, such as [224,224]"
  65. else:
  66. fixed_input_shape = None
  67. model = pdx.load_model(args.model_dir, fixed_input_shape)
  68. model.export_inference_model(args.save_dir)
  69. if args.export_onnx:
  70. assert args.model_dir is not None, "--model_dir should be defined while exporting onnx model"
  71. assert args.save_dir is not None, "--save_dir should be defined to create onnx model"
  72. assert args.fixed_input_shape is not None, "--fixed_input_shape should be defined [w,h] to create onnx model, such as [224,224]"
  73. fixed_input_shape = []
  74. if args.fixed_input_shape is not None:
  75. fixed_input_shape = eval(args.fixed_input_shape)
  76. assert len(
  77. fixed_input_shape
  78. ) == 2, "len of fixed input shape must == 2, such as [224,224]"
  79. model = pdx.load_model(args.model_dir, fixed_input_shape)
  80. pdx.convertor.export_onnx_model(model, args.save_dir)
# Allow running this module directly as a script (``python command.py``).
if __name__ == "__main__":
    main()