# coding: utf8
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import numpy as np
import math
import cv2
import argparse
from paddlex.seg import transforms
import paddlex as pdx

# The input size of the segmenter (METER_SHAPE x METER_SHAPE), which is also
# the size of a cropped circular meter.
METER_SHAPE = 512
# Center of a circular meter
CIRCLE_CENTER = [256, 256]
# Radius of a circular meter
CIRCLE_RADIUS = 250
PI = 3.1415926536
# During postprocessing, the annulus between radius 130 and radius 250 of a
# circular meter is unwrapped into a rectangle, so the height of the rectangle
# is 250 - 130 = 120.
LINE_HEIGHT = 120
# The width of the rectangle is 1570, i.e. the perimeter of the circular meter
# (2 * PI * CIRCLE_RADIUS is roughly 1570).
LINE_WIDTH = 1570
# The type of a meter is decided by a threshold: if the number of scale marks
# found in a meter is greater than the threshold, the meter uses the first
# entry of METER_CONFIG; otherwise, the second.
TYPE_THRESHOLD = 40
# The configuration of each meter type: value of one scale interval, full
# range, and unit.
METER_CONFIG = [{
    'scale_value': 25.0 / 50.0,
    'range': 25.0,
    'unit': "(MPa)"
}, {
    'scale_value': 1.6 / 32.0,
    'range': 1.6,
    'unit': "(MPa)"
}]
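# Worked example of how a readout follows from these constants: the first meter
# type has a range of 25.0 MPa split into 50 intervals, so one interval is
# 25.0 / 50.0 = 0.5 MPa, and a pointer sitting at scale position 20.0 reads
# 20.0 * 0.5 = 10.0 MPa.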


def parse_args():
    parser = argparse.ArgumentParser(description='Meter Reader Inference')
    parser.add_argument(
        '--detector_dir',
        dest='detector_dir',
        help='The directory of the model used for detection',
        type=str)
    parser.add_argument(
        '--segmenter_dir',
        dest='segmenter_dir',
        help='The directory of the model used for segmentation',
        type=str)
    parser.add_argument(
        '--image_dir',
        dest='image_dir',
        help='The directory of images to be inferred',
        type=str,
        default=None)
    parser.add_argument(
        '--image',
        dest='image',
        help='The image to be inferred',
        type=str,
        default=None)
    parser.add_argument(
        '--use_camera',
        dest='use_camera',
        help='Whether to use a camera as input',
        action='store_true')
    parser.add_argument(
        '--camera_id',
        dest='camera_id',
        type=int,
        help='The camera id',
        default=0)
    parser.add_argument(
        '--use_erode',
        dest='use_erode',
        help='Whether to erode the predicted label map',
        action='store_true')
    parser.add_argument(
        '--erode_kernel',
        dest='erode_kernel',
        help='Erode kernel size',
        type=int,
        default=4)
    parser.add_argument(
        '--save_dir',
        dest='save_dir',
        help='The directory for saving the inference results',
        type=str,
        default='./output/result')
    parser.add_argument(
        '--score_threshold',
        dest='score_threshold',
        help='Detected bboxes whose scores are lower than this threshold are filtered out',
        type=float,
        default=0.5)
    parser.add_argument(
        '--seg_batch_size',
        dest='seg_batch_size',
        help='Segmentation batch size',
        type=int,
        default=2)
    return parser.parse_args()
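
# Example command line (a sketch; the model directories and image path are
# placeholders to be replaced with real paths):
#
#   python reader_infer.py --detector_dir /path/to/det_inference_model \
#       --segmenter_dir /path/to/seg_inference_model \
#       --image /path/to/meter.jpg --use_erode --save_dir ./output/result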


def is_pic(img_name):
    valid_suffix = ['JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png']
    suffix = img_name.split('.')[-1]
    if suffix not in valid_suffix:
        return False
    return True


class MeterReader:
    """Find the meters in images and provide a digital readout of each meter.

    Args:
        detector_dir (str): directory of the detector.
        segmenter_dir (str): directory of the segmenter.

    """

    def __init__(self, detector_dir, segmenter_dir):
        if not osp.exists(detector_dir):
            raise Exception("Model path {} does not exist".format(
                detector_dir))
        if not osp.exists(segmenter_dir):
            raise Exception("Model path {} does not exist".format(
                segmenter_dir))
        self.detector = pdx.load_model(detector_dir)
        self.segmenter = pdx.load_model(segmenter_dir)
        # Images are resized to (METER_SHAPE, METER_SHAPE) before being fed to
        # the segmenter, so the transform is composed of normalization only.
        self.seg_transforms = transforms.Compose([transforms.Normalize()])

    def predict(self,
                im_file,
                save_dir='./',
                use_erode=True,
                erode_kernel=4,
                score_threshold=0.5,
                seg_batch_size=2):
        """Detect meters in an image, segment scales and pointers in the detected
        meters, and postprocess the segmentation results into a digital readout
        based on the pointer location relative to the scales.

        Args:
            im_file (str): the path of an image to be predicted.
            save_dir (str): the directory for saving the visualized prediction. Default: './'.
            use_erode (bool, optional): whether to erode the label map output from the
                segmenter with a specific structuring element. Default: True.
            erode_kernel (int, optional): size of the structuring element used for erosion.
                Default: 4.
            score_threshold (float, optional): detected meters whose scores are not lower than
                `score_threshold` will be fed into the following segmenter. Default: 0.5.
            seg_batch_size (int, optional): batch size of meters used for segmentation.
                Default: 2.

        """
        if isinstance(im_file, str):
            im = cv2.imread(im_file).astype('float32')
        else:
            im = im_file.copy()
        # Get detection results
        det_results = self.detector.predict(im)
        # Filter out bboxes whose scores are lower than score_threshold
        filtered_results = list()
        for res in det_results:
            if res['score'] > score_threshold:
                filtered_results.append(res)

        resized_meters = list()
        for res in filtered_results:
            # Crop the bbox area
            xmin, ymin, w, h = res['bbox']
            xmin = max(0, int(xmin))
            ymin = max(0, int(ymin))
            xmax = min(im.shape[1], int(xmin + w - 1))
            ymax = min(im.shape[0], int(ymin + h - 1))
            sub_image = im[ymin:(ymax + 1), xmin:(xmax + 1), :]

            # Resize the cropped meter to (METER_SHAPE, METER_SHAPE)
            meter_shape = sub_image.shape
            scale_x = float(METER_SHAPE) / float(meter_shape[1])
            scale_y = float(METER_SHAPE) / float(meter_shape[0])
            meter_meter = cv2.resize(
                sub_image,
                None,
                None,
                fx=scale_x,
                fy=scale_y,
                interpolation=cv2.INTER_LINEAR)
            meter_meter = meter_meter.astype('float32')
            resized_meters.append(meter_meter)

        meter_num = len(resized_meters)
        seg_results = list()
        for i in range(0, meter_num, seg_batch_size):
            im_size = min(meter_num, i + seg_batch_size)
            meter_images = list()
            for j in range(i, im_size):
                meter_images.append(resized_meters[j])
            # Segment scales and pointer in each meter area
            result = self.segmenter.batch_predict(
                transforms=self.seg_transforms, img_file_list=meter_images)
            # Erode the predicted label map of each meter
            if use_erode:
                kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
                for k in range(len(result)):
                    result[k]['label_map'] = cv2.erode(result[k]['label_map'],
                                                       kernel)
            seg_results.extend(result)

        # Postprocess to get the pointer location relative to the scales
        results = list()
        for i, seg_result in enumerate(seg_results):
            result = self.read_process(seg_result['label_map'])
            results.append(result)

        # Provide a digital readout according to the pointer location relative
        # to the scales
        meter_values = list()
        for i, result in enumerate(results):
            if result['scale_num'] > TYPE_THRESHOLD:
                value = result['scales'] * METER_CONFIG[0]['scale_value']
            else:
                value = result['scales'] * METER_CONFIG[1]['scale_value']
            meter_values.append(value)
            print("-- Meter {} -- result: {} --\n".format(i, value))

        # Visualize the results
        visual_results = list()
        for i, res in enumerate(filtered_results):
            # Reuse `score` to carry the meter value so it is drawn on the image
            res['score'] = meter_values[i]
            visual_results.append(res)
        pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)

    def read_process(self, label_maps):
        """Get the pointer location relative to the scales.

        Args:
            label_maps (np.array): the label map output from the segmenter for one meter.

        """
        # Convert the circular meter into a rectangular meter
        line_images = self.creat_line_image(label_maps)
        # Get two one-dimensional arrays in which 0 represents background and a
        # value > 0 represents a scale mark or the pointer
        scale_data, pointer_data = self.convert_1d_data(line_images)
        # Filter out scale data whose value is lower than the mean value
        self.scale_mean_filtration(scale_data)
        # Get the number of scale marks, the pointer location relative to the
        # scales, and the ratio between the distance from the pointer to the
        # starting scale and the distance from the ending scale to the starting scale
        result = self.get_meter_reader(scale_data, pointer_data)
        return result

    def creat_line_image(self, meter_image):
        """Convert the circular meter into a rectangular meter.

        The minimum scale value is at the bottom left and the maximum scale value
        is at the bottom right, so the vertical downward axis is the starting axis
        and the unwrapping rotates around the meter center counterclockwise.

        Args:
            meter_image (np.array): the label map output from the segmenter for one meter.

        Returns:
            line_image (np.array): a rectangular meter.

        """
        line_image = np.zeros((LINE_HEIGHT, LINE_WIDTH), dtype=np.uint8)
        for row in range(LINE_HEIGHT):
            for col in range(LINE_WIDTH):
                theta = PI * 2 / LINE_WIDTH * (col + 1)
                rho = CIRCLE_RADIUS - row - 1
                y = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
                x = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
                line_image[row, col] = meter_image[y, x]
        return line_image

    def convert_1d_data(self, meter_image):
        """Get two one-dimensional arrays in which 0 represents background and a
        value > 0 represents a scale mark or the pointer, from the rectangular meter.

        Args:
            meter_image (np.array): the two-dimensional rectangular meter output
                from creat_line_image().

        Returns:
            scale_data (np.array): a one-dimensional array in which 0 represents
                background and a value > 0 represents a scale mark.
            pointer_data (np.array): a one-dimensional array in which 0 represents
                background and a value > 0 represents the pointer.

        """
        scale_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
        pointer_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
        # For each column, accumulate the number of rows whose label is 1 (pointer)
        # and the number of rows whose label is 2 (scale) along the height axis.
        for col in range(LINE_WIDTH):
            for row in range(LINE_HEIGHT):
                if meter_image[row, col] == 1:
                    pointer_data[col] += 1
                elif meter_image[row, col] == 2:
                    scale_data[col] += 1
        return scale_data, pointer_data

    def scale_mean_filtration(self, scale_data):
        """Set the elements of the scale data that are lower than its mean value to 0.

        Args:
            scale_data (np.array): the scale data output from convert_1d_data().

        """
        mean_data = np.mean(scale_data)
        for col in range(LINE_WIDTH):
            if scale_data[col] < mean_data:
                scale_data[col] = 0

    def get_meter_reader(self, scale_data, pointer_data):
        """Calculate the number of scale marks, the pointer location relative to the
        scales, and the ratio between the distance from the pointer to the starting
        scale and the distance from the ending scale to the starting scale.

        For example, a pointer lying exactly on the 20th scale mark yields a
        'scales' value of 20.0.

        Args:
            scale_data (np.array): scale data output from scale_mean_filtration().
            pointer_data (np.array): pointer data output from convert_1d_data().

        Returns:
            dict (keys: 'scale_num', 'scales', 'ratio'):
                'scale_num' (int): the number of scale marks;
                'scales' (float): the pointer location relative to the scales;
                'ratio' (float): the ratio between the distance from the pointer to
                    the starting scale and the distance from the ending scale to the
                    starting scale.

        """
        scale_flag = False
        pointer_flag = False
        one_scale_start = 0
        one_scale_end = 0
        one_pointer_start = 0
        one_pointer_end = 0
        scale_location = list()
        pointer_location = 0
        for i in range(LINE_WIDTH - 1):
            if scale_data[i] > 0 and scale_data[i + 1] > 0:
                if scale_flag == False:
                    one_scale_start = i
                    scale_flag = True
            if scale_flag:
                if scale_data[i] == 0 and scale_data[i + 1] == 0:
                    one_scale_end = i - 1
                    one_scale_location = (one_scale_start + one_scale_end) / 2
                    scale_location.append(one_scale_location)
                    one_scale_start = 0
                    one_scale_end = 0
                    scale_flag = False
            if pointer_data[i] > 0 and pointer_data[i + 1] > 0:
                if pointer_flag == False:
                    one_pointer_start = i
                    pointer_flag = True
            if pointer_flag:
                if pointer_data[i] == 0 and pointer_data[i + 1] == 0:
                    one_pointer_end = i - 1
                    pointer_location = (
                        one_pointer_start + one_pointer_end) / 2
                    one_pointer_start = 0
                    one_pointer_end = 0
                    pointer_flag = False

        scale_num = len(scale_location)
        scales = -1
        ratio = -1
        if scale_num > 0:
            for i in range(scale_num - 1):
                if scale_location[i] <= pointer_location < scale_location[i + 1]:
                    scales = i + (pointer_location - scale_location[i]) / (
                        scale_location[i + 1] - scale_location[i] + 1e-05) + 1
            ratio = (pointer_location - scale_location[0]) / (
                scale_location[scale_num - 1] - scale_location[0] + 1e-05)
        result = {'scale_num': scale_num, 'scales': scales, 'ratio': ratio}
        return result
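
# A minimal sketch of using MeterReader programmatically (the model directories
# and image path below are placeholders, not files shipped with this script):
#
#   reader = MeterReader('./det_inference_model', './seg_inference_model')
#   reader.predict('./meter.jpg', save_dir='./output/result', use_erode=True)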


def infer(args):
    image_lists = list()
    if args.image is not None:
        if not osp.exists(args.image):
            raise Exception("Image {} does not exist.".format(args.image))
        if not is_pic(args.image):
            raise Exception("{} is not a picture.".format(args.image))
        image_lists.append(args.image)
    elif args.image_dir is not None:
        if not osp.exists(args.image_dir):
            raise Exception("Directory {} does not exist.".format(
                args.image_dir))
        for im_file in os.listdir(args.image_dir):
            if not is_pic(im_file):
                continue
            im_file = osp.join(args.image_dir, im_file)
            image_lists.append(im_file)

    meter_reader = MeterReader(args.detector_dir, args.segmenter_dir)
    if len(image_lists) > 0:
        for im_file in image_lists:
            meter_reader.predict(im_file, args.save_dir, args.use_erode,
                                 args.erode_kernel, args.score_threshold,
                                 args.seg_batch_size)
    elif args.use_camera:
        cap_video = cv2.VideoCapture(args.camera_id)
        if not cap_video.isOpened():
            raise Exception(
                "Error opening video stream, please make sure the camera is working"
            )
        while cap_video.isOpened():
            ret, frame = cap_video.read()
            if ret:
                meter_reader.predict(frame, args.save_dir, args.use_erode,
                                     args.erode_kernel, args.score_threshold,
                                     args.seg_batch_size)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break
        cap_video.release()


if __name__ == '__main__':
    args = parse_args()
    infer(args)