demo.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import os
  16. import argparse
  17. import deploy
  18. def arg_parser():
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. "--model_dir",
  22. "-m",
  23. type=str,
  24. default=None,
  25. help="path to openvino model .xml file")
  26. parser.add_argument(
  27. "--img", "-i", type=str, default=None, help="path to an image files")
  28. parser.add_argument(
  29. "--img_list", "-l", type=str, default=None, help="Path to a imglist")
  30. parser.add_argument(
  31. "--cfg_file",
  32. "-c",
  33. type=str,
  34. default=None,
  35. help="Path to PaddelX model yml file")
  36. parser.add_argument(
  37. "--thread_num",
  38. "-t",
  39. type=int,
  40. default=1,
  41. help="Path to PaddelX model yml file")
  42. parser.add_argument(
  43. "--input_shape",
  44. "-ip",
  45. type=str,
  46. default=None,
  47. help=" image input shape of model [NCHW] like [1,3,224,244] ")
  48. return parser
  49. def main():
  50. parser = arg_parser()
  51. args = parser.parse_args()
  52. model_nb = args.model_dir
  53. model_yaml = args.cfg_file
  54. thread_num = args.thread_num
  55. input_shape = args.input_shape
  56. input_shape = input_shape[1:-1].split(",", 3)
  57. shape = list(map(int, input_shape))
  58. #model init
  59. predictor = deploy.Predictor(model_nb, model_yaml, thread_num, shape)
  60. #predict
  61. if (args.img_list != None):
  62. f = open(args.img_list)
  63. lines = f.readlines()
  64. for im_path in lines:
  65. print(im_path)
  66. predictor.predict(im_path.strip('\n'))
  67. f.close()
  68. else:
  69. im_path = args.img
  70. predictor.predict(im_path)
  71. if __name__ == "__main__":
  72. main()