reader_deploy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import os.path as osp
  17. import numpy as np
  18. import math
  19. import cv2
  20. import argparse
  21. from paddlex.seg import transforms
  22. import paddlex as pdx
  # Side length (pixels) every detected meter crop is resized to before it is
  # fed to the segmenter.
  METER_SHAPE = 512
  # Dial center and outer radius, in pixels of the METER_SHAPE x METER_SHAPE
  # label map. CIRCLE_CENTER[0] offsets the first index and CIRCLE_CENTER[1]
  # the second index when the dial is unwrapped.
  CIRCLE_CENTER = [256, 256]
  CIRCLE_RADIUS = 250
  # Use the standard-library constant instead of a hand-typed approximation.
  PI = math.pi
  # Size of the unwrapped ("rectangular") dial: LINE_HEIGHT radial samples,
  # starting at radius CIRCLE_RADIUS - 1 and moving inward, by LINE_WIDTH
  # angular samples over one full turn. 1570 is roughly 2 * pi * CIRCLE_RADIUS
  # (~1570.8), i.e. about one sample per pixel along the outer rim.
  LINE_HEIGHT = 120
  LINE_WIDTH = 1570
  # Meters with more than this many detected scale marks are read with
  # METER_CONFIG[0]; all others with METER_CONFIG[1].
  TYPE_THRESHOLD = 40
  # Per-meter-type calibration. A reading is the interpolated number of scale
  # divisions times 'scale_value'; 'scale_value' is apparently range / number
  # of divisions (25.0 / 50, 1.6 / 32). Only 'scale_value' is read in this file.
  METER_CONFIG = [{
      'scale_value': 25.0 / 50.0,
      'range': 25.0,
      'unit': "(MPa)"
  }, {
      'scale_value': 1.6 / 32.0,
      'range': 1.6,
      'unit': "(MPa)"
  }]
  def parse_args(argv=None):
      """Parse the meter-reader command-line options.

      Args:
          argv: list of argument strings to parse. ``None`` (the default)
              parses ``sys.argv[1:]``, exactly as before this parameter existed.

      Returns:
          argparse.Namespace holding the options defined below.
      """
      parser = argparse.ArgumentParser(description='Meter Reader Inferring')
      # Both model directories are mandatory: MeterReader cannot be built
      # without them, and a missing value would otherwise only fail later
      # inside osp.exists(None).
      parser.add_argument(
          '--detector_dir',
          dest='detector_dir',
          help='The directory of models to do detection',
          type=str,
          required=True)
      parser.add_argument(
          '--segmenter_dir',
          dest='segmenter_dir',
          help='The directory of models to do segmentation',
          type=str,
          required=True)
      parser.add_argument(
          '--image_dir',
          dest='image_dir',
          help='The directory of images to be inferred',
          type=str,
          default=None)
      parser.add_argument(
          '--image',
          dest='image',
          help='The image to be inferred',
          type=str,
          default=None)
      parser.add_argument(
          '--use_camera',
          dest='use_camera',
          help='Whether use camera or not',
          action='store_true')
      parser.add_argument(
          '--camera_id',
          dest='camera_id',
          type=int,
          help='The camera id',
          default=0)
      parser.add_argument(
          '--use_erode',
          dest='use_erode',
          help='Whether to erode the predicted label map',
          action='store_true')
      parser.add_argument(
          '--erode_kernel',
          dest='erode_kernel',
          help='Erode kernel size',
          type=int,
          default=4)
      parser.add_argument(
          '--save_dir',
          dest='save_dir',
          help='The directory for saving the inference results',
          type=str,
          default='./output/result')
      parser.add_argument(
          '--score_threshold',
          dest='score_threshold',
          help="Detected bbox whose score is lower than this threshold is filtered",
          type=float,
          default=0.5)
      parser.add_argument(
          '--seg_batch_size',
          dest='seg_batch_size',
          help="Segmentation batch size",
          type=int,
          default=2)
      return parser.parse_args(argv)
  def is_pic(img_name):
      """Return True if *img_name* has a JPEG, BMP or PNG file extension.

      The comparison is case-insensitive (``a.Jpg`` counts). Only a real
      extension is considered, so a bare name such as ``png`` (no dot) is not
      mistaken for a picture.
      """
      valid_suffix = {'jpeg', 'jpg', 'bmp', 'png'}
      # splitext returns ('name', '.ext'); drop the leading dot.
      suffix = osp.splitext(img_name)[1][1:].lower()
      return suffix in valid_suffix
  class MeterReader:
      """Read pointer meters in an image: detect each meter, segment it, and
      turn the segmentation into a numeric reading.

      In the segmentation label map, label 1 is counted as pointer and label 2
      as scale mark (see `convert_1d_data`).
      """

      # Lazily built (xs, ys) index arrays for the polar unwrap in
      # `creat_line_image`. They depend only on module constants, so they are
      # computed once (on first use) and shared by all instances.
      _polar_index = None

      def __init__(self, detector_dir, segmenter_dir):
          """Build `pdx.deploy.Predictor`s for the two model directories.

          Raises:
              Exception: if either directory does not exist.
          """
          if not osp.exists(detector_dir):
              raise Exception("Model path {} does not exist".format(
                  detector_dir))
          if not osp.exists(segmenter_dir):
              raise Exception("Model path {} does not exist".format(
                  segmenter_dir))
          self.detector = pdx.deploy.Predictor(detector_dir)
          self.segmenter = pdx.deploy.Predictor(segmenter_dir)
          # Because we will resize images with (METER_SHAPE, METER_SHAPE) before
          # they are fed into the segmenter, the transform is normalization only.
          self.seg_transforms = transforms.Compose([transforms.Normalize()])

      def predict(self,
                  im_file,
                  save_dir='./',
                  use_erode=True,
                  erode_kernel=4,
                  score_threshold=0.5,
                  seg_batch_size=2):
          """Detect, segment and read every meter in one image, print each
          reading, and save a visualization in which each kept box's `score`
          is replaced by its reading.

          Args:
              im_file: image path (read with cv2.imread) or an image array,
                  which is copied before use.
              save_dir: directory passed to `pdx.det.visualize`.
              use_erode: erode each predicted label map before reading it.
              erode_kernel: side length of the square erosion kernel.
              score_threshold: detections whose score is <= this are dropped.
              seg_batch_size: number of meter crops per segmenter call (>= 1).

          Returns:
              list of readings, one per kept detection, in detection order.

          Raises:
              ValueError: if `seg_batch_size` < 1.
              Exception: if `im_file` is a path that cv2 cannot read.
          """
          if seg_batch_size < 1:
              raise ValueError(
                  "seg_batch_size must be >= 1, got {}".format(seg_batch_size))
          if isinstance(im_file, str):
              im = cv2.imread(im_file)
              # cv2.imread signals failure by returning None.
              if im is None:
                  raise Exception("Image {} could not be read.".format(im_file))
              im = im.astype('float32')
          else:
              im = im_file.copy()

          # Keep only confident detections.
          det_results = self.detector.predict(im)
          filtered_results = [
              res for res in det_results if res['score'] > score_threshold
          ]

          # Crop each box and resize it to the fixed segmenter input size.
          # `meter_results` stays index-aligned with `resized_meters`.
          meter_results = list()
          resized_meters = list()
          for res in filtered_results:
              sub_image = self._crop_bbox(im, res['bbox'])
              if sub_image.shape[0] == 0 or sub_image.shape[1] == 0:
                  # Degenerate box (e.g. fully outside the image): resizing it
                  # would divide by zero, so it is skipped.
                  continue
              meter = cv2.resize(
                  sub_image, (METER_SHAPE, METER_SHAPE),
                  interpolation=cv2.INTER_LINEAR)
              resized_meters.append(meter.astype('float32'))
              meter_results.append(res)

          # Segment the crops in batches of `seg_batch_size`.
          kernel = (np.ones((erode_kernel, erode_kernel), np.uint8)
                    if use_erode else None)
          seg_results = list()
          for start in range(0, len(resized_meters), seg_batch_size):
              batch = resized_meters[start:start + seg_batch_size]
              batch_results = self.segmenter.batch_predict(
                  transforms=self.seg_transforms, img_file_list=batch)
              if use_erode:
                  for seg in batch_results:
                      seg['label_map'] = cv2.erode(seg['label_map'], kernel)
              seg_results.extend(batch_results)

          # Convert every label map into a value. Meters with more than
          # TYPE_THRESHOLD scale marks use METER_CONFIG[0], others [1].
          # When the pointer is not between two detected marks, 'scales' is -1,
          # so the printed value is negative.
          meter_values = list()
          for i, seg in enumerate(seg_results):
              reading = self.read_process(seg['label_map'])
              if reading['scale_num'] > TYPE_THRESHOLD:
                  config = METER_CONFIG[0]
              else:
                  config = METER_CONFIG[1]
              value = reading['scales'] * config['scale_value']
              meter_values.append(value)
              print("-- Meter {} -- result: {} --\n".format(i, value))

          # Visualize the results, using `score` to carry the meter value.
          visual_results = list()
          for res, value in zip(meter_results, meter_values):
              res['score'] = value
              visual_results.append(res)
          pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)
          return meter_values

      @staticmethod
      def _crop_bbox(im, bbox):
          """Return the part of `im` inside an [xmin, ymin, w, h] box, clipped
          to the image; the result may be empty."""
          xmin, ymin, w, h = bbox
          x0, y0 = int(xmin), int(ymin)
          # Far edges come from the unclipped origin, so a box that starts
          # left of / above the image does not grow when the origin is clipped.
          xmax = min(im.shape[1], int(x0 + w - 1))
          ymax = min(im.shape[0], int(y0 + h - 1))
          x0 = max(0, x0)
          y0 = max(0, y0)
          # max(...) keeps a negative end from wrapping around in the slice.
          return im[y0:max(y0, ymax + 1), x0:max(x0, xmax + 1), :]

      def read_process(self, label_maps):
          """Turn one segmentation label map into a reading dict (see
          `get_meter_reader` for its keys)."""
          # Convert the circular meter into rectangular meter
          line_images = self.creat_line_image(label_maps)
          # Convert the 2d meter into 1d meter
          scale_data, pointer_data = self.convert_1d_data(line_images)
          # Filter scale data whose value is lower than the mean value
          self.scale_mean_filtration(scale_data)
          # Get scale_num, scales and ratio of meters
          result = self.get_meter_reader(scale_data, pointer_data)
          return result

      @classmethod
      def _get_polar_index(cls):
          """Return cached (xs, ys) int arrays of shape (LINE_HEIGHT, LINE_WIDTH)
          giving, for each unwrapped pixel, the label-map index it samples."""
          if cls._polar_index is None:
              xs = np.empty((LINE_HEIGHT, LINE_WIDTH), dtype=np.int64)
              ys = np.empty((LINE_HEIGHT, LINE_WIDTH), dtype=np.int64)
              for row in range(LINE_HEIGHT):
                  # Row 0 samples the outermost radius, moving inward.
                  rho = CIRCLE_RADIUS - row - 1
                  for col in range(LINE_WIDTH):
                      theta = PI * 2 / LINE_WIDTH * (col + 1)
                      xs[row, col] = int(
                          CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
                      ys[row, col] = int(
                          CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
              cls._polar_index = (xs, ys)
          return cls._polar_index

      def creat_line_image(self, meter_image):
          """Unwrap the circular dial into a (LINE_HEIGHT, LINE_WIDTH) uint8
          image (name kept as-is for compatibility).

          `x` (center[0] + rho*cos) is used as the FIRST (row) index, `y` as the
          second. With row 0 at the top, the unwrap therefore starts just left
          of the bottom of the dial and runs bottom -> left -> top -> right as
          the column grows.
          """
          xs, ys = self._get_polar_index()
          return meter_image[xs, ys].astype(np.uint8)

      def convert_1d_data(self, meter_image):
          """Collapse the unwrapped label image column-wise.

          Returns:
              (scale_data, pointer_data): uint8 arrays with, per column, the
              number of pixels labelled 2 (scale) and 1 (pointer).
          """
          scale_data = (meter_image == 2).sum(axis=0).astype(np.uint8)
          pointer_data = (meter_image == 1).sum(axis=0).astype(np.uint8)
          return scale_data, pointer_data

      def scale_mean_filtration(self, scale_data):
          """Zero, in place, every entry of `scale_data` below its mean."""
          scale_data[scale_data < np.mean(scale_data)] = 0

      @staticmethod
      def _run_centers(data):
          """Return the center index of each run of positive values in `data`.

          A run opens at the first i where data[i] and data[i + 1] are both
          positive, and closes at the first later j where data[j] and
          data[j + 1] are both zero; its center is (start + j - 1) / 2. A single
          zero therefore does not split a run, and a run still open at the end
          of `data` is dropped.
          """
          centers = list()
          in_run = False
          start = 0
          for i in range(len(data) - 1):
              if not in_run and data[i] > 0 and data[i + 1] > 0:
                  start = i
                  in_run = True
              if in_run and data[i] == 0 and data[i + 1] == 0:
                  centers.append((start + i - 1) / 2)
                  in_run = False
          return centers

      def get_meter_reader(self, scale_data, pointer_data):
          """Turn 1-D scale/pointer profiles into a reading.

          Returns:
              dict with 'scale_num' (number of scale marks found), 'scales'
              (pointer position in scale divisions, interpolated between the
              surrounding marks; -1 if the pointer is not between two marks)
              and 'ratio' (pointer offset from the first mark as a fraction of
              the first-to-last mark span; -1 if no mark was found).
          """
          scale_location = self._run_centers(scale_data)
          pointer_runs = self._run_centers(pointer_data)
          # Only the last pointer run is used; 0 when no run was closed.
          pointer_location = pointer_runs[-1] if pointer_runs else 0
          scale_num = len(scale_location)
          scales = -1
          ratio = -1
          if scale_num > 0:
              for i in range(scale_num - 1):
                  left = scale_location[i]
                  right = scale_location[i + 1]
                  if left <= pointer_location < right:
                      # NOTE(review): the trailing `+ 1` makes a pointer on the
                      # first mark read as 1 division, not 0; kept unchanged as
                      # METER_CONFIG may be calibrated around it -- confirm.
                      scales = i + (pointer_location - left) / (
                          right - left + 1e-05) + 1
              ratio = (pointer_location - scale_location[0]) / (
                  scale_location[-1] - scale_location[0] + 1e-05)
          return {'scale_num': scale_num, 'scales': scales, 'ratio': ratio}
  def infer(args):
      """Run the meter reader on the inputs selected by *args*.

      Input precedence: ``args.image`` if given; otherwise every picture
      directly inside ``args.image_dir`` (in sorted order); otherwise the camera
      ``args.camera_id`` when ``args.use_camera`` is set. If none of these
      yields any input, a message is printed and the function returns without
      loading the models.

      Raises:
          Exception: if the image or directory does not exist, the image is not
              a picture, or the camera cannot be opened.
      """
      image_lists = list()
      if args.image is not None:
          if not osp.exists(args.image):
              raise Exception("Image {} does not exist.".format(args.image))
          if not is_pic(args.image):
              raise Exception("{} is not a picture.".format(args.image))
          image_lists.append(args.image)
      elif args.image_dir is not None:
          if not osp.exists(args.image_dir):
              raise Exception("Directory {} does not exist.".format(
                  args.image_dir))
          # Sorted so images are processed in a reproducible order.
          image_lists.extend(
              osp.join(args.image_dir, name)
              for name in sorted(os.listdir(args.image_dir)) if is_pic(name))

      if not image_lists and not args.use_camera:
          # Nothing to process: avoid loading both models for no work.
          print("No image to infer: pass --image, an --image_dir containing "
                "pictures, or --use_camera.")
          return

      meter_reader = MeterReader(args.detector_dir, args.segmenter_dir)
      if image_lists:
          for im_file in image_lists:
              meter_reader.predict(im_file, args.save_dir, args.use_erode,
                                   args.erode_kernel, args.score_threshold,
                                   args.seg_batch_size)
          return

      cap_video = cv2.VideoCapture(args.camera_id)
      if not cap_video.isOpened():
          raise Exception(
              "Error opening video stream, please make sure the camera is working"
          )
      try:
          while cap_video.isOpened():
              ret, frame = cap_video.read()
              if not ret:
                  break
              meter_reader.predict(frame, args.save_dir, args.use_erode,
                                   args.erode_kernel, args.score_threshold,
                                   args.seg_batch_size)
              # NOTE(review): no OpenCV window is created here, and waitKey
              # presumably only sees key presses while one has focus -- confirm
              # that 'q' can actually stop the loop.
              if cv2.waitKey(1) & 0xFF == ord('q'):
                  break
      finally:
          # Release the device even if predict raises or on KeyboardInterrupt.
          cap_video.release()
  if __name__ == '__main__':
      # Script entry point: parse the CLI options and run inference with them.
      infer(parse_args())