demo.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import os
  16. import argparse
  17. import deploy
  18. def arg_parser():
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. "--model_dir",
  22. "-m",
  23. type=str,
  24. default=None,
  25. help="path to openvino model .xml file")
  26. parser.add_argument(
  27. "--img",
  28. "-i",
  29. type=str,
  30. default=None,
  31. help="path to an image files")
  32. parser.add_argument(
  33. "--img_list",
  34. "-l",
  35. type=str,
  36. default=None,
  37. help="Path to a imglist")
  38. parser.add_argument(
  39. "--cfg_dir",
  40. "-c",
  41. type=str,
  42. default=None,
  43. help="Path to PaddelX model yml file")
  44. parser.add_argument(
  45. "--thread_num",
  46. "-t",
  47. type=int,
  48. default=1,
  49. help="Path to PaddelX model yml file")
  50. parser.add_argument(
  51. "--input_shape",
  52. "-ip",
  53. type=str,
  54. default=None,
  55. help=" image input shape of model [NCHW] like [1,3,224,244] ")
  56. return parser
  57. def main():
  58. parser = arg_parser()
  59. args = parser.parse_args()
  60. model_nb = args.model_dir
  61. model_yaml = args.cfg_dir
  62. thread_num = args.thread_num
  63. input_shape = args.input_shape
  64. input_shape = input_shape[1:-1].split(",",3)
  65. shape = list(map(int,input_shape))
  66. #model init
  67. predictor = deploy.Predictor(model_nb,model_yaml,thread_num,shape)
  68. #predict
  69. if(args.img_list != None):
  70. f = open(args.img_list)
  71. lines = f.readlines()
  72. for im_path in lines:
  73. print(im_path)
  74. predictor.predict(im_path.strip('\n'))
  75. f.close()
  76. else:
  77. im_path = args.img
  78. predictor.predict(im_path)
  79. if __name__ == "__main__":
  80. main()