reader_deploy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import os.path as osp
  17. import numpy as np
  18. import math
  19. import cv2
  20. import argparse
  21. from paddlex.seg import transforms
  22. import paddlex as pdx
  23. METER_SHAPE = 512
  24. CIRCLE_CENTER = [256, 256]
  25. CIRCLE_RADIUS = 250
  26. PI = 3.1415926536
  27. LINE_HEIGHT = 120
  28. LINE_WIDTH = 1570
  29. TYPE_THRESHOLD = 40
  30. METER_CONFIG = [{
  31. 'scale_value': 25.0 / 50.0,
  32. 'range': 25.0,
  33. 'unit': "(MPa)"
  34. }, {
  35. 'scale_value': 1.6 / 32.0,
  36. 'range': 1.6,
  37. 'unit': "(MPa)"
  38. }]
  39. def parse_args():
  40. parser = argparse.ArgumentParser(description='Meter Reader Infering')
  41. parser.add_argument(
  42. '--detector_dir',
  43. dest='detector_dir',
  44. help='The directory of models to do detection',
  45. type=str)
  46. parser.add_argument(
  47. '--segmenter_dir',
  48. dest='segmenter_dir',
  49. help='The directory of models to do segmentation',
  50. type=str)
  51. parser.add_argument(
  52. '--image_dir',
  53. dest='image_dir',
  54. help='The directory of images to be infered',
  55. type=str,
  56. default=None)
  57. parser.add_argument(
  58. '--image',
  59. dest='image',
  60. help='The image to be infered',
  61. type=str,
  62. default=None)
  63. parser.add_argument(
  64. '--use_camera',
  65. dest='use_camera',
  66. help='Whether use camera or not',
  67. action='store_true')
  68. parser.add_argument(
  69. '--camera_id',
  70. dest='camera_id',
  71. type=int,
  72. help='The camera id',
  73. default=0)
  74. parser.add_argument(
  75. '--use_erode',
  76. dest='use_erode',
  77. help='Whether erode the predicted lable map',
  78. action='store_true')
  79. parser.add_argument(
  80. '--erode_kernel',
  81. dest='erode_kernel',
  82. help='Erode kernel size',
  83. type=int,
  84. default=4)
  85. parser.add_argument(
  86. '--save_dir',
  87. dest='save_dir',
  88. help='The directory for saving the inference results',
  89. type=str,
  90. default='./output/result')
  91. parser.add_argument(
  92. '--score_threshold',
  93. dest='score_threshold',
  94. help="Detected bbox whose score is lower than this threshlod is filtered",
  95. type=float,
  96. default=0.5)
  97. parser.add_argument(
  98. '--seg_batch_size',
  99. dest='seg_batch_size',
  100. help="Segmentation batch size",
  101. type=int,
  102. default=2)
  103. parser.add_argument(
  104. '--seg_thread_num',
  105. dest='seg_thread_num',
  106. help="Thread number of segmentation preprocess",
  107. type=int,
  108. default=2)
  109. return parser.parse_args()
  110. def is_pic(img_name):
  111. valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
  112. suffix = img_name.split('.')[-1]
  113. if suffix not in valid_suffix:
  114. return False
  115. return True
  116. class MeterReader:
  117. def __init__(self, detector_dir, segmenter_dir):
  118. if not osp.exists(detector_dir):
  119. raise Exception("Model path {} does not exist".format(
  120. detector_dir))
  121. if not osp.exists(segmenter_dir):
  122. raise Exception("Model path {} does not exist".format(
  123. segmenter_dir))
  124. self.detector = pdx.deploy.Predictor(detector_dir)
  125. self.segmenter = pdx.deploy.Predictor(segmenter_dir)
  126. # Because we will resize images with (METER_SHAPE, METER_SHAPE) before fed into the segmenter,
  127. # here the transform is composed of normalization only.
  128. self.seg_transforms = transforms.Compose([transforms.Normalize()])
  129. def predict(self,
  130. im_file,
  131. save_dir='./',
  132. use_erode=True,
  133. erode_kernel=4,
  134. score_threshold=0.5,
  135. seg_batch_size=2,
  136. seg_thread_num=2):
  137. if isinstance(im_file, str):
  138. im = cv2.imread(im_file).astype('float32')
  139. else:
  140. im = im_file.copy()
  141. # Get detection results
  142. det_results = self.detector.predict(im)
  143. # Filter bbox whose score is lower than score_threshold
  144. filtered_results = list()
  145. for res in det_results:
  146. if res['score'] > score_threshold:
  147. filtered_results.append(res)
  148. resized_meters = list()
  149. for res in filtered_results:
  150. # Crop the bbox area
  151. xmin, ymin, w, h = res['bbox']
  152. xmin = max(0, int(xmin))
  153. ymin = max(0, int(ymin))
  154. xmax = min(im.shape[1], int(xmin + w - 1))
  155. ymax = min(im.shape[0], int(ymin + h - 1))
  156. sub_image = im[ymin:(ymax + 1), xmin:(xmax + 1), :]
  157. # Resize the image with shape (METER_SHAPE, METER_SHAPE)
  158. meter_shape = sub_image.shape
  159. scale_x = float(METER_SHAPE) / float(meter_shape[1])
  160. scale_y = float(METER_SHAPE) / float(meter_shape[0])
  161. meter_meter = cv2.resize(
  162. sub_image,
  163. None,
  164. None,
  165. fx=scale_x,
  166. fy=scale_y,
  167. interpolation=cv2.INTER_LINEAR)
  168. meter_meter = meter_meter.astype('float32')
  169. resized_meters.append(meter_meter)
  170. meter_num = len(resized_meters)
  171. seg_results = list()
  172. for i in range(0, meter_num, seg_batch_size):
  173. im_size = min(meter_num, i + seg_batch_size)
  174. meter_images = list()
  175. for j in range(i, im_size):
  176. meter_images.append(resized_meters[j - i])
  177. result = self.segmenter.batch_predict(
  178. transforms=self.seg_transforms,
  179. img_file_list=meter_images,
  180. thread_num=seg_thread_num)
  181. if use_erode:
  182. kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
  183. for i in range(len(result)):
  184. result[i]['label_map'] = cv2.erode(result[i]['label_map'],
  185. kernel)
  186. seg_results.extend(result)
  187. results = list()
  188. for i, seg_result in enumerate(seg_results):
  189. result = self.read_process(seg_result['label_map'])
  190. results.append(result)
  191. meter_values = list()
  192. for i, result in enumerate(results):
  193. if result['scale_num'] > TYPE_THRESHOLD:
  194. value = result['scales'] * METER_CONFIG[0]['scale_value']
  195. else:
  196. value = result['scales'] * METER_CONFIG[1]['scale_value']
  197. meter_values.append(value)
  198. print("-- Meter {} -- result: {} --\n".format(i, value))
  199. # visualize the results
  200. visual_results = list()
  201. for i, res in enumerate(filtered_results):
  202. # Use `score` to represent the meter value
  203. res['score'] = meter_values[i]
  204. visual_results.append(res)
  205. pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)
  206. def read_process(self, label_maps):
  207. # Convert the circular meter into rectangular meter
  208. line_images = self.creat_line_image(label_maps)
  209. # Convert the 2d meter into 1d meter
  210. scale_data, pointer_data = self.convert_1d_data(line_images)
  211. # Fliter scale data whose value is lower than the mean value
  212. self.scale_mean_filtration(scale_data)
  213. # Get scale_num, scales and ratio of meters
  214. result = self.get_meter_reader(scale_data, pointer_data)
  215. return result
  216. def creat_line_image(self, meter_image):
  217. line_image = np.zeros((LINE_HEIGHT, LINE_WIDTH), dtype=np.uint8)
  218. for row in range(LINE_HEIGHT):
  219. for col in range(LINE_WIDTH):
  220. theta = PI * 2 / LINE_WIDTH * (col + 1)
  221. rho = CIRCLE_RADIUS - row - 1
  222. x = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
  223. y = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
  224. line_image[row, col] = meter_image[x, y]
  225. return line_image
  226. def convert_1d_data(self, meter_image):
  227. scale_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
  228. pointer_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
  229. for col in range(LINE_WIDTH):
  230. for row in range(LINE_HEIGHT):
  231. if meter_image[row, col] == 1:
  232. pointer_data[col] += 1
  233. elif meter_image[row, col] == 2:
  234. scale_data[col] += 1
  235. return scale_data, pointer_data
  236. def scale_mean_filtration(self, scale_data):
  237. mean_data = np.mean(scale_data)
  238. for col in range(LINE_WIDTH):
  239. if scale_data[col] < mean_data:
  240. scale_data[col] = 0
  241. def get_meter_reader(self, scale_data, pointer_data):
  242. scale_flag = False
  243. pointer_flag = False
  244. one_scale_start = 0
  245. one_scale_end = 0
  246. one_pointer_start = 0
  247. one_pointer_end = 0
  248. scale_location = list()
  249. pointer_location = 0
  250. for i in range(LINE_WIDTH - 1):
  251. if scale_data[i] > 0 and scale_data[i + 1] > 0:
  252. if scale_flag == False:
  253. one_scale_start = i
  254. scale_flag = True
  255. if scale_flag:
  256. if scale_data[i] == 0 and scale_data[i + 1] == 0:
  257. one_scale_end = i - 1
  258. one_scale_location = (one_scale_start + one_scale_end) / 2
  259. scale_location.append(one_scale_location)
  260. one_scale_start = 0
  261. one_scale_end = 0
  262. scale_flag = False
  263. if pointer_data[i] > 0 and pointer_data[i + 1] > 0:
  264. if pointer_flag == False:
  265. one_pointer_start = i
  266. pointer_flag = True
  267. if pointer_flag:
  268. if pointer_data[i] == 0 and pointer_data[i + 1] == 0:
  269. one_pointer_end = i - 1
  270. pointer_location = (
  271. one_pointer_start + one_pointer_end) / 2
  272. one_pointer_start = 0
  273. one_pointer_end = 0
  274. pointer_flag = False
  275. scale_num = len(scale_location)
  276. scales = -1
  277. ratio = -1
  278. if scale_num > 0:
  279. for i in range(scale_num - 1):
  280. if scale_location[
  281. i] <= pointer_location and pointer_location < scale_location[
  282. i + 1]:
  283. scales = i + (pointer_location - scale_location[i]) / (
  284. scale_location[i + 1] - scale_location[i] + 1e-05) + 1
  285. ratio = (pointer_location - scale_location[0]) / (
  286. scale_location[scale_num - 1] - scale_location[0] + 1e-05)
  287. result = {'scale_num': scale_num, 'scales': scales, 'ratio': ratio}
  288. return result
  289. def infer(args):
  290. image_lists = list()
  291. if args.image is not None:
  292. if not osp.exists(args.image):
  293. raise Exception("Image {} does not exist.".format(args.image))
  294. if not is_pic(args.image):
  295. raise Exception("{} is not a picture.".format(args.image))
  296. image_lists.append(args.image)
  297. elif args.image_dir is not None:
  298. if not osp.exists(args.image_dir):
  299. raise Exception("Directory {} does not exist.".format(
  300. args.image_dir))
  301. for im_file in os.listdir(args.image_dir):
  302. if not is_pic(im_file):
  303. continue
  304. im_file = osp.join(args.image_dir, im_file)
  305. image_lists.append(im_file)
  306. meter_reader = MeterReader(args.detector_dir, args.segmenter_dir)
  307. if len(image_lists) > 0:
  308. for im_file in image_lists:
  309. meter_reader.predict(im_file, args.save_dir, args.use_erode,
  310. args.erode_kernel, args.score_threshold,
  311. args.seg_batch_size, args.seg_thread_num)
  312. elif args.use_camera:
  313. cap_video = cv2.VideoCapture(args.camera_id)
  314. if not cap_video.isOpened():
  315. raise Exception(
  316. "Error opening video stream, please make sure the camera is working"
  317. )
  318. while cap_video.isOpened():
  319. ret, frame = cap_video.read()
  320. if ret:
  321. meter_reader.predict(frame, args.save_dir, args.use_erode,
  322. args.erode_kernel, args.score_threshold,
  323. args.seg_batch_size, args.seg_thread_num)
  324. if cv2.waitKey(1) & 0xFF == ord('q'):
  325. break
  326. else:
  327. break
  328. cap_video.release()
  329. if __name__ == '__main__':
  330. args = parse_args()
  331. infer(args)