reader_deploy.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # coding: utf8
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import os.path as osp
  17. import numpy as np
  18. import math
  19. import cv2
  20. import argparse
  21. from paddlex.seg import transforms
  22. import paddlex as pdx
  # Side length (pixels) every cropped meter is resized to before segmentation.
  METER_SHAPE = 512
  # Centre and outer radius (pixels, within the resized meter) of the ring that
  # is unrolled into a rectangular strip by MeterReader.creat_line_image.
  CIRCLE_CENTER = [256, 256]
  CIRCLE_RADIUS = 250
  PI = 3.1415926536
  # Size of the unrolled strip: LINE_HEIGHT radial steps by LINE_WIDTH angular
  # steps.
  LINE_HEIGHT = 120
  LINE_WIDTH = 1570
  # A meter with more detected scale marks than this is read with
  # METER_CONFIG[0]; otherwise METER_CONFIG[1] is used.
  TYPE_THRESHOLD = 40
  # Per-meter-type settings. 'scale_value' is the reading represented by one
  # scale division; 'range' and 'unit' are not read by the code in this file.
  METER_CONFIG = [
      {
          'scale_value': 25.0 / 50.0,
          'range': 25.0,
          'unit': "(MPa)",
      },
      {
          'scale_value': 1.6 / 32.0,
          'range': 1.6,
          'unit': "(MPa)",
      },
  ]
  def parse_args(argv=None):
      """Parse the command-line options of the meter reader.

      Args:
          argv: Optional list of argument strings. When None (the default),
              argparse reads the arguments from ``sys.argv[1:]``.

      Returns:
          argparse.Namespace with the attributes detector_dir, segmenter_dir,
          image_dir, image, use_camera, use_erode, erode_kernel, save_dir,
          score_threshold, seg_batch_size and seg_thread_num.
      """
      parser = argparse.ArgumentParser(description='Meter Reader Infering')
      parser.add_argument(
          '--detector_dir',
          dest='detector_dir',
          help='The directory of models to do detection',
          type=str)
      parser.add_argument(
          '--segmenter_dir',
          dest='segmenter_dir',
          help='The directory of models to do segmentation',
          type=str)
      parser.add_argument(
          '--image_dir',
          dest='image_dir',
          help='The directory of images to be infered',
          type=str,
          default=None)
      parser.add_argument(
          '--image',
          dest='image',
          help='The image to be infered',
          type=str,
          default=None)
      parser.add_argument(
          '--use_camera',
          dest='use_camera',
          help='Whether use camera or not',
          action='store_true')
      parser.add_argument(
          '--use_erode',
          dest='use_erode',
          help='Whether erode the predicted lable map',
          action='store_true')
      parser.add_argument(
          '--erode_kernel',
          dest='erode_kernel',
          help='Erode kernel size',
          type=int,
          default=4)
      parser.add_argument(
          '--save_dir',
          dest='save_dir',
          help='The directory for saving the inference results',
          type=str,
          default='./output/result')
      parser.add_argument(
          '--score_threshold',
          dest='score_threshold',
          help="Detected bbox whose score is lower than this threshlod is filtered",
          type=float,
          default=0.5)
      parser.add_argument(
          '--seg_batch_size',
          dest='seg_batch_size',
          help="Segmentation batch size",
          type=int,
          default=2)
      parser.add_argument(
          '--seg_thread_num',
          dest='seg_thread_num',
          help="Thread number of segmentation preprocess",
          type=int,
          default=2)
      return parser.parse_args(argv)
  def is_pic(img_name):
      """Return True when ``img_name`` ends with a supported image extension.

      Supported extensions are jpeg, jpg, bmp and png, compared without regard
      to case. A name that contains no '.' at all is never treated as a
      picture, so a file literally called ``jpg`` is rejected.
      """
      valid_suffix = ('jpeg', 'jpg', 'bmp', 'png')
      _, dot, suffix = img_name.rpartition('.')
      # rpartition yields an empty separator when the name has no dot.
      return bool(dot) and suffix.lower() in valid_suffix
  class MeterReader:
      """Detect meters in an image, segment each one and compute its reading.

      The pipeline is: detect meter boxes, crop and resize each box, segment
      the crops into a label map, unroll the circular dial into a rectangular
      strip, collapse the strip to 1-D scale/pointer profiles and locate the
      pointer among the scale marks.
      """

      def __init__(self, detector_dir, segmenter_dir):
          """Load the detection and segmentation models.

          Args:
              detector_dir: Directory of the exported detection model.
              segmenter_dir: Directory of the exported segmentation model.

          Raises:
              Exception: If either directory is None or does not exist.
          """
          for model_dir in (detector_dir, segmenter_dir):
              # A missing command-line option arrives here as None; report it
              # with the same message instead of a TypeError from osp.exists.
              if model_dir is None or not osp.exists(model_dir):
                  raise Exception("Model path {} does not exist".format(
                      model_dir))
          self.detector = pdx.deploy.Predictor(detector_dir)
          self.segmenter = pdx.deploy.Predictor(segmenter_dir)
          # Because we will resize images with (METER_SHAPE, METER_SHAPE)
          # before fed into the segmenter, here the transform is composed of
          # normalization only.
          self.seg_transforms = transforms.Compose([transforms.Normalize()])

      def predict(self,
                  im_file,
                  save_dir='./',
                  use_erode=True,
                  erode_kernel=4,
                  score_threshold=0.5,
                  seg_batch_size=2,
                  seg_thread_num=2):
          """Read every meter found in one image and save a visualization.

          Args:
              im_file: Path of an image, or an already decoded image array.
              save_dir: Directory passed to ``pdx.det.visualize``.
              use_erode: Whether to erode each predicted label map.
              erode_kernel: Side length of the square erosion kernel.
              score_threshold: Detections scoring at or below this are dropped.
              seg_batch_size: Number of meters segmented per batch.
              seg_thread_num: Thread number of segmentation preprocess.

          Raises:
              Exception: If ``im_file`` is a path that cannot be read.
          """
          if isinstance(im_file, str):
              im = cv2.imread(im_file)
              if im is None:
                  # cv2.imread signals failure by returning None.
                  raise Exception(
                      "Image {} can not be read.".format(im_file))
              im = im.astype('float32')
          else:
              im = im_file.copy()

          # Get detection results
          det_results = self.detector.predict(im)

          # Filter bbox whose score is lower than score_threshold
          filtered_results = [
              res for res in det_results if res['score'] > score_threshold
          ]

          resized_meters = list()
          kept_results = list()
          for res in filtered_results:
              # Crop the bbox area; 'bbox' is unpacked as (xmin, ymin, w, h).
              xmin, ymin, w, h = res['bbox']
              xmin = max(0, int(xmin))
              ymin = max(0, int(ymin))
              xmax = min(im.shape[1], int(xmin + w - 1))
              ymax = min(im.shape[0], int(ymin + h - 1))
              sub_image = im[ymin:(ymax + 1), xmin:(xmax + 1), :]
              if sub_image.size == 0:
                  # A zero-area crop cannot be resized (division by zero in
                  # the scale factors below), so the detection is skipped.
                  continue

              # Resize the image with shape (METER_SHAPE, METER_SHAPE)
              meter_shape = sub_image.shape
              scale_x = float(METER_SHAPE) / float(meter_shape[1])
              scale_y = float(METER_SHAPE) / float(meter_shape[0])
              meter_meter = cv2.resize(
                  sub_image,
                  None,
                  None,
                  fx=scale_x,
                  fy=scale_y,
                  interpolation=cv2.INTER_LINEAR)
              resized_meters.append(meter_meter.astype('float32'))
              kept_results.append(res)

          kernel = None
          if use_erode:
              kernel = np.ones((erode_kernel, erode_kernel), np.uint8)

          seg_results = list()
          for start in range(0, len(resized_meters), seg_batch_size):
              # Slice by absolute position so each batch gets its own meters.
              meter_images = resized_meters[start:start + seg_batch_size]
              batch_result = self.segmenter.batch_predict(
                  transforms=self.seg_transforms,
                  img_file_list=meter_images,
                  thread_num=seg_thread_num)
              if kernel is not None:
                  # Erode the label maps of the batch that was just predicted.
                  for seg in batch_result:
                      seg['label_map'] = cv2.erode(seg['label_map'], kernel)
              seg_results.extend(batch_result)

          meter_values = list()
          for i, seg_result in enumerate(seg_results):
              result = self.read_process(seg_result['label_map'])
              # The number of scale marks selects the meter configuration.
              if result['scale_num'] > TYPE_THRESHOLD:
                  value = result['scales'] * METER_CONFIG[0]['scale_value']
              else:
                  value = result['scales'] * METER_CONFIG[1]['scale_value']
              meter_values.append(value)
              print("-- Meter {} -- result: {} --\n".format(i, value))

          # visualize the results
          visual_results = list()
          for res, value in zip(kept_results, meter_values):
              # Use `score` to represent the meter value
              res['score'] = value
              visual_results.append(res)
          pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)

      def read_process(self, label_maps):
          """Turn one segmentation label map into a meter reading.

          Returns:
              The dict produced by ``get_meter_reader`` with the keys
              'scale_num', 'scales' and 'ratio'.
          """
          # Convert the circular meter into rectangular meter
          line_images = self.creat_line_image(label_maps)
          # Convert the 2d meter into 1d meter
          scale_data, pointer_data = self.convert_1d_data(line_images)
          # Filter scale data whose value is lower than the mean value
          self.scale_mean_filtration(scale_data)
          # Get scale_num, scales and ratio of meters
          result = self.get_meter_reader(scale_data, pointer_data)
          return result

      def creat_line_image(self, meter_image):
          """Unroll the ring of the dial into a rectangular uint8 image.

          Column ``col`` corresponds to the angle
          ``2 * PI / LINE_WIDTH * (col + 1)`` and row ``row`` to the radius
          ``CIRCLE_RADIUS - row - 1`` around CIRCLE_CENTER. This is the
          vectorized form of sampling one pixel per (row, col) pair.

          Args:
              meter_image: 2-D label map large enough to contain the ring
                  (the pipeline supplies METER_SHAPE x METER_SHAPE maps).

          Returns:
              Array of shape (LINE_HEIGHT, LINE_WIDTH) and dtype uint8.
          """
          meter_image = np.asarray(meter_image)
          theta = PI * 2 / LINE_WIDTH * np.arange(1, LINE_WIDTH + 1)
          # Column vector so that rho broadcasts against theta to a full grid.
          rho = (CIRCLE_RADIUS - np.arange(LINE_HEIGHT) - 1).reshape(-1, 1)
          # Adding 0.5 before truncation rounds to the nearest pixel.
          x = (CIRCLE_CENTER[0] + rho * np.cos(theta) + 0.5).astype(np.int64)
          y = (CIRCLE_CENTER[1] - rho * np.sin(theta) + 0.5).astype(np.int64)
          # NOTE(review): x is used as the first index and y as the second,
          # exactly as in the original sampling loop - confirm this axis order
          # is intended for the label maps.
          return meter_image[x, y].astype(np.uint8)

      def convert_1d_data(self, meter_image):
          """Collapse the unrolled image into per-column pixel counts.

          Args:
              meter_image: 2-D label image, e.g. from ``creat_line_image``.

          Returns:
              Tuple ``(scale_data, pointer_data)`` of uint8 arrays with one
              entry per column: the number of pixels equal to 2 (counted as
              scale) and equal to 1 (counted as pointer) in that column.
          """
          meter_image = np.asarray(meter_image)
          pointer_data = (meter_image == 1).sum(axis=0).astype(np.uint8)
          scale_data = (meter_image == 2).sum(axis=0).astype(np.uint8)
          return scale_data, pointer_data

      def scale_mean_filtration(self, scale_data):
          """Zero, in place, every entry of ``scale_data`` below its mean.

          Args:
              scale_data: Mutable 1-D sequence of per-column scale counts.
          """
          mean_data = np.mean(scale_data)
          for col in range(len(scale_data)):
              if scale_data[col] < mean_data:
                  scale_data[col] = 0

      def get_meter_reader(self, scale_data, pointer_data):
          """Locate the scale marks and the pointer in the 1-D profiles.

          A run starts where two consecutive entries are positive and ends
          where two consecutive entries are zero; its location is the midpoint
          of the run. When several pointer runs exist the last one is used.

          Args:
              scale_data: 1-D per-column scale counts.
              pointer_data: 1-D per-column pointer counts.

          Returns:
              Dict with 'scale_num' (number of scale marks found), 'scales'
              (fractional count of scale divisions up to the pointer, -1 when
              the pointer lies in no interval between two marks) and 'ratio'
              (pointer position relative to the first and last mark, -1 when
              no mark was found).
          """
          width = min(len(scale_data), len(pointer_data))
          scale_flag = False
          pointer_flag = False
          one_scale_start = 0
          one_pointer_start = 0
          scale_location = list()
          pointer_location = 0
          for i in range(width - 1):
              if scale_data[i] > 0 and scale_data[i + 1] > 0:
                  if not scale_flag:
                      one_scale_start = i
                      scale_flag = True
              if scale_flag:
                  if scale_data[i] == 0 and scale_data[i + 1] == 0:
                      one_scale_end = i - 1
                      scale_location.append(
                          (one_scale_start + one_scale_end) / 2)
                      one_scale_start = 0
                      scale_flag = False
              if pointer_data[i] > 0 and pointer_data[i + 1] > 0:
                  if not pointer_flag:
                      one_pointer_start = i
                      pointer_flag = True
              if pointer_flag:
                  if pointer_data[i] == 0 and pointer_data[i + 1] == 0:
                      one_pointer_end = i - 1
                      pointer_location = (
                          one_pointer_start + one_pointer_end) / 2
                      one_pointer_start = 0
                      pointer_flag = False

          scale_num = len(scale_location)
          scales = -1
          ratio = -1
          if scale_num > 0:
              for i in range(scale_num - 1):
                  left = scale_location[i]
                  right = scale_location[i + 1]
                  if left <= pointer_location < right:
                      # 1e-05 guards against a zero-width interval.
                      scales = i + (pointer_location - left) / (
                          right - left + 1e-05) + 1
              ratio = (pointer_location - scale_location[0]) / (
                  scale_location[scale_num - 1] - scale_location[0] + 1e-05)
          result = {'scale_num': scale_num, 'scales': scales, 'ratio': ratio}
          return result
  def infer(args):
      """Run the meter reader on an image, a directory of images or a camera.

      ``args.image`` takes precedence over ``args.image_dir``. The camera is
      used only when no image was collected and ``args.use_camera`` is set.

      Args:
          args: Namespace as returned by ``parse_args``.

      Raises:
          Exception: If the given image or directory does not exist, if the
              given image is not a picture, or if the camera cannot be opened.
      """
      image_lists = list()
      if args.image is not None:
          if not osp.exists(args.image):
              raise Exception("Image {} does not exist.".format(args.image))
          if not is_pic(args.image):
              raise Exception("{} is not a picture.".format(args.image))
          image_lists.append(args.image)
      elif args.image_dir is not None:
          if not osp.exists(args.image_dir):
              raise Exception("Directory {} does not exist.".format(
                  args.image_dir))
          # Sorted so that images are processed in a reproducible order.
          for im_file in sorted(os.listdir(args.image_dir)):
              if not is_pic(im_file):
                  continue
              image_lists.append(osp.join(args.image_dir, im_file))

      meter_reader = MeterReader(args.detector_dir, args.segmenter_dir)
      if len(image_lists) > 0:
          for im_file in image_lists:
              meter_reader.predict(im_file, args.save_dir, args.use_erode,
                                   args.erode_kernel, args.score_threshold,
                                   args.seg_batch_size, args.seg_thread_num)
      elif args.use_camera:
          # The parser stores the --use_camera flag under `use_camera`.
          cap_video = cv2.VideoCapture(0)
          try:
              if not cap_video.isOpened():
                  raise Exception(
                      "Error opening video stream, please make sure the camera is working"
                  )
              while cap_video.isOpened():
                  ret, frame = cap_video.read()
                  if not ret:
                      break
                  meter_reader.predict(frame, args.save_dir, args.use_erode,
                                       args.erode_kernel,
                                       args.score_threshold,
                                       args.seg_batch_size,
                                       args.seg_thread_num)
                  # Pressing 'q' stops the capture loop.
                  if cv2.waitKey(1) & 0xFF == ord('q'):
                      break
          finally:
              # Release the device even when reading or prediction fails.
              cap_video.release()
  # Script entry point: parse the command line and run inference.
  if __name__ == '__main__':
      args = parse_args()
      infer(args)