keypoint_coco.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import json
  18. import copy
  19. import pycocotools
  20. from pycocotools.coco import COCO
  21. from .dataset import DetDataset
  22. from paddlex.ppdet.core.workspace import register, serializable
  23. @serializable
  24. class KeypointBottomUpBaseDataset(DetDataset):
  25. """Base class for bottom-up datasets.
  26. All datasets should subclass it.
  27. All subclasses should overwrite:
  28. Methods:`_get_imganno`
  29. Args:
  30. dataset_dir (str): Root path to the dataset.
  31. anno_path (str): Relative path to the annotation file.
  32. image_dir (str): Path to a directory where images are held.
  33. Default: None.
  34. num_joints (int): keypoint numbers
  35. transform (composed(operators)): A sequence of data transforms.
  36. shard (list): [rank, worldsize], the distributed env params
  37. test_mode (bool): Store True when building test or
  38. validation dataset. Default: False.
  39. """
  40. def __init__(self,
  41. dataset_dir,
  42. image_dir,
  43. anno_path,
  44. num_joints,
  45. transform=[],
  46. shard=[0, 1],
  47. test_mode=False):
  48. super().__init__(dataset_dir, image_dir, anno_path)
  49. self.image_info = {}
  50. self.ann_info = {}
  51. self.img_prefix = os.path.join(dataset_dir, image_dir)
  52. self.transform = transform
  53. self.test_mode = test_mode
  54. self.ann_info['num_joints'] = num_joints
  55. self.img_ids = []
  56. def __len__(self):
  57. """Get dataset length."""
  58. return len(self.img_ids)
  59. def _get_imganno(self, idx):
  60. """Get anno for a single image."""
  61. raise NotImplementedError
  62. def __getitem__(self, idx):
  63. """Prepare image for training given the index."""
  64. records = copy.deepcopy(self._get_imganno(idx))
  65. records['image'] = cv2.imread(records['image_file'])
  66. records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
  67. records['mask'] = (records['mask'] + 0).astype('uint8')
  68. records = self.transform(records)
  69. return records
  70. def parse_dataset(self):
  71. return
  72. @register
  73. @serializable
  74. class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
  75. """COCO dataset for bottom-up pose estimation.
  76. The dataset loads raw features and apply specified transforms
  77. to return a dict containing the image tensors and other information.
  78. COCO keypoint indexes::
  79. 0: 'nose',
  80. 1: 'left_eye',
  81. 2: 'right_eye',
  82. 3: 'left_ear',
  83. 4: 'right_ear',
  84. 5: 'left_shoulder',
  85. 6: 'right_shoulder',
  86. 7: 'left_elbow',
  87. 8: 'right_elbow',
  88. 9: 'left_wrist',
  89. 10: 'right_wrist',
  90. 11: 'left_hip',
  91. 12: 'right_hip',
  92. 13: 'left_knee',
  93. 14: 'right_knee',
  94. 15: 'left_ankle',
  95. 16: 'right_ankle'
  96. Args:
  97. dataset_dir (str): Root path to the dataset.
  98. anno_path (str): Relative path to the annotation file.
  99. image_dir (str): Path to a directory where images are held.
  100. Default: None.
  101. num_joints (int): keypoint numbers
  102. transform (composed(operators)): A sequence of data transforms.
  103. shard (list): [rank, worldsize], the distributed env params
  104. test_mode (bool): Store True when building test or
  105. validation dataset. Default: False.
  106. """
  107. def __init__(self,
  108. dataset_dir,
  109. image_dir,
  110. anno_path,
  111. num_joints,
  112. transform=[],
  113. shard=[0, 1],
  114. test_mode=False):
  115. super().__init__(dataset_dir, image_dir, anno_path, num_joints,
  116. transform, shard, test_mode)
  117. ann_file = os.path.join(dataset_dir, anno_path)
  118. self.coco = COCO(ann_file)
  119. self.img_ids = self.coco.getImgIds()
  120. if not test_mode:
  121. self.img_ids = [
  122. img_id for img_id in self.img_ids
  123. if len(self.coco.getAnnIds(
  124. imgIds=img_id, iscrowd=None)) > 0
  125. ]
  126. blocknum = int(len(self.img_ids) / shard[1])
  127. self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
  128. + 1))]
  129. self.num_images = len(self.img_ids)
  130. self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
  131. self.dataset_name = 'coco'
  132. cat_ids = self.coco.getCatIds()
  133. self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  134. print(f'=> num_images: {self.num_images}')
  135. @staticmethod
  136. def _get_mapping_id_name(imgs):
  137. """
  138. Args:
  139. imgs (dict): dict of image info.
  140. Returns:
  141. tuple: Image name & id mapping dicts.
  142. - id2name (dict): Mapping image id to name.
  143. - name2id (dict): Mapping image name to id.
  144. """
  145. id2name = {}
  146. name2id = {}
  147. for image_id, image in imgs.items():
  148. file_name = image['file_name']
  149. id2name[image_id] = file_name
  150. name2id[file_name] = image_id
  151. return id2name, name2id
  152. def _get_imganno(self, idx):
  153. """Get anno for a single image.
  154. Args:
  155. idx (int): image idx
  156. Returns:
  157. dict: info for model training
  158. """
  159. coco = self.coco
  160. img_id = self.img_ids[idx]
  161. ann_ids = coco.getAnnIds(imgIds=img_id)
  162. anno = coco.loadAnns(ann_ids)
  163. mask = self._get_mask(anno, idx)
  164. anno = [
  165. obj for obj in anno
  166. if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0
  167. ]
  168. joints, orgsize = self._get_joints(anno, idx)
  169. db_rec = {}
  170. db_rec['im_id'] = img_id
  171. db_rec['image_file'] = os.path.join(self.img_prefix,
  172. self.id2name[img_id])
  173. db_rec['mask'] = mask
  174. db_rec['joints'] = joints
  175. db_rec['im_shape'] = orgsize
  176. return db_rec
  177. def _get_joints(self, anno, idx):
  178. """Get joints for all people in an image."""
  179. num_people = len(anno)
  180. joints = np.zeros(
  181. (num_people, self.ann_info['num_joints'], 3), dtype=np.float32)
  182. for i, obj in enumerate(anno):
  183. joints[i, :self.ann_info['num_joints'], :3] = \
  184. np.array(obj['keypoints']).reshape([-1, 3])
  185. img_info = self.coco.loadImgs(self.img_ids[idx])[0]
  186. joints[..., 0] /= img_info['width']
  187. joints[..., 1] /= img_info['height']
  188. orgsize = np.array([img_info['height'], img_info['width']])
  189. return joints, orgsize
  190. def _get_mask(self, anno, idx):
  191. """Get ignore masks to mask out losses."""
  192. coco = self.coco
  193. img_info = coco.loadImgs(self.img_ids[idx])[0]
  194. m = np.zeros((img_info['height'], img_info['width']), dtype=np.float32)
  195. for obj in anno:
  196. if 'segmentation' in obj:
  197. if obj['iscrowd']:
  198. rle = pycocotools.mask.frPyObjects(obj['segmentation'],
  199. img_info['height'],
  200. img_info['width'])
  201. m += pycocotools.mask.decode(rle)
  202. elif obj['num_keypoints'] == 0:
  203. rles = pycocotools.mask.frPyObjects(obj['segmentation'],
  204. img_info['height'],
  205. img_info['width'])
  206. for rle in rles:
  207. m += pycocotools.mask.decode(rle)
  208. return m < 0.5
  209. @register
  210. @serializable
  211. class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
  212. """CrowdPose dataset for bottom-up pose estimation.
  213. The dataset loads raw features and apply specified transforms
  214. to return a dict containing the image tensors and other information.
  215. CrowdPose keypoint indexes::
  216. 0: 'left_shoulder',
  217. 1: 'right_shoulder',
  218. 2: 'left_elbow',
  219. 3: 'right_elbow',
  220. 4: 'left_wrist',
  221. 5: 'right_wrist',
  222. 6: 'left_hip',
  223. 7: 'right_hip',
  224. 8: 'left_knee',
  225. 9: 'right_knee',
  226. 10: 'left_ankle',
  227. 11: 'right_ankle',
  228. 12: 'top_head',
  229. 13: 'neck'
  230. Args:
  231. dataset_dir (str): Root path to the dataset.
  232. anno_path (str): Relative path to the annotation file.
  233. image_dir (str): Path to a directory where images are held.
  234. Default: None.
  235. num_joints (int): keypoint numbers
  236. transform (composed(operators)): A sequence of data transforms.
  237. shard (list): [rank, worldsize], the distributed env params
  238. test_mode (bool): Store True when building test or
  239. validation dataset. Default: False.
  240. """
  241. def __init__(self,
  242. dataset_dir,
  243. image_dir,
  244. anno_path,
  245. num_joints,
  246. transform=[],
  247. shard=[0, 1],
  248. test_mode=False):
  249. super().__init__(dataset_dir, image_dir, anno_path, num_joints,
  250. transform, shard, test_mode)
  251. ann_file = os.path.join(dataset_dir, anno_path)
  252. self.coco = COCO(ann_file)
  253. self.img_ids = self.coco.getImgIds()
  254. if not test_mode:
  255. self.img_ids = [
  256. img_id for img_id in self.img_ids
  257. if len(self.coco.getAnnIds(
  258. imgIds=img_id, iscrowd=None)) > 0
  259. ]
  260. blocknum = int(len(self.img_ids) / shard[1])
  261. self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
  262. + 1))]
  263. self.num_images = len(self.img_ids)
  264. self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
  265. self.dataset_name = 'crowdpose'
  266. print('=> num_images: {}'.format(self.num_images))
  267. @serializable
  268. class KeypointTopDownBaseDataset(DetDataset):
  269. """Base class for top_down datasets.
  270. All datasets should subclass it.
  271. All subclasses should overwrite:
  272. Methods:`_get_db`
  273. Args:
  274. dataset_dir (str): Root path to the dataset.
  275. image_dir (str): Path to a directory where images are held.
  276. anno_path (str): Relative path to the annotation file.
  277. num_joints (int): keypoint numbers
  278. transform (composed(operators)): A sequence of data transforms.
  279. """
  280. def __init__(self,
  281. dataset_dir,
  282. image_dir,
  283. anno_path,
  284. num_joints,
  285. transform=[]):
  286. super().__init__(dataset_dir, image_dir, anno_path)
  287. self.image_info = {}
  288. self.ann_info = {}
  289. self.img_prefix = os.path.join(dataset_dir, image_dir)
  290. self.transform = transform
  291. self.ann_info['num_joints'] = num_joints
  292. self.db = []
  293. def __len__(self):
  294. """Get dataset length."""
  295. return len(self.db)
  296. def _get_db(self):
  297. """Get a sample"""
  298. raise NotImplementedError
  299. def __getitem__(self, idx):
  300. """Prepare sample for training given the index."""
  301. records = copy.deepcopy(self.db[idx])
  302. records['image'] = cv2.imread(records['image_file'], cv2.IMREAD_COLOR |
  303. cv2.IMREAD_IGNORE_ORIENTATION)
  304. records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
  305. records['score'] = records['score'] if 'score' in records else 1
  306. records = self.transform(records)
  307. # print('records', records)
  308. return records
  309. @register
  310. @serializable
  311. class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
  312. """COCO dataset for top-down pose estimation.
  313. The dataset loads raw features and apply specified transforms
  314. to return a dict containing the image tensors and other information.
  315. COCO keypoint indexes:
  316. 0: 'nose',
  317. 1: 'left_eye',
  318. 2: 'right_eye',
  319. 3: 'left_ear',
  320. 4: 'right_ear',
  321. 5: 'left_shoulder',
  322. 6: 'right_shoulder',
  323. 7: 'left_elbow',
  324. 8: 'right_elbow',
  325. 9: 'left_wrist',
  326. 10: 'right_wrist',
  327. 11: 'left_hip',
  328. 12: 'right_hip',
  329. 13: 'left_knee',
  330. 14: 'right_knee',
  331. 15: 'left_ankle',
  332. 16: 'right_ankle'
  333. Args:
  334. dataset_dir (str): Root path to the dataset.
  335. image_dir (str): Path to a directory where images are held.
  336. anno_path (str): Relative path to the annotation file.
  337. num_joints (int): Keypoint numbers
  338. trainsize (list):[w, h] Image target size
  339. transform (composed(operators)): A sequence of data transforms.
  340. bbox_file (str): Path to a detection bbox file
  341. Default: None.
  342. use_gt_bbox (bool): Whether to use ground truth bbox
  343. Default: True.
  344. pixel_std (int): The pixel std of the scale
  345. Default: 200.
  346. image_thre (float): The threshold to filter the detection box
  347. Default: 0.0.
  348. """
  349. def __init__(self,
  350. dataset_dir,
  351. image_dir,
  352. anno_path,
  353. num_joints,
  354. trainsize,
  355. transform=[],
  356. bbox_file=None,
  357. use_gt_bbox=True,
  358. pixel_std=200,
  359. image_thre=0.0):
  360. super().__init__(dataset_dir, image_dir, anno_path, num_joints,
  361. transform)
  362. self.bbox_file = bbox_file
  363. self.use_gt_bbox = use_gt_bbox
  364. self.trainsize = trainsize
  365. self.pixel_std = pixel_std
  366. self.image_thre = image_thre
  367. self.dataset_name = 'coco'
  368. def parse_dataset(self):
  369. if self.use_gt_bbox:
  370. self.db = self._load_coco_keypoint_annotations()
  371. else:
  372. self.db = self._load_coco_person_detection_results()
  373. def _load_coco_keypoint_annotations(self):
  374. coco = COCO(self.get_anno())
  375. img_ids = coco.getImgIds()
  376. gt_db = []
  377. for index in img_ids:
  378. im_ann = coco.loadImgs(index)[0]
  379. width = im_ann['width']
  380. height = im_ann['height']
  381. file_name = im_ann['file_name']
  382. im_id = int(im_ann["id"])
  383. annIds = coco.getAnnIds(imgIds=index, iscrowd=False)
  384. objs = coco.loadAnns(annIds)
  385. valid_objs = []
  386. for obj in objs:
  387. x, y, w, h = obj['bbox']
  388. x1 = np.max((0, x))
  389. y1 = np.max((0, y))
  390. x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
  391. y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
  392. if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
  393. obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
  394. valid_objs.append(obj)
  395. objs = valid_objs
  396. rec = []
  397. for obj in objs:
  398. if max(obj['keypoints']) == 0:
  399. continue
  400. joints = np.zeros(
  401. (self.ann_info['num_joints'], 3), dtype=np.float)
  402. joints_vis = np.zeros(
  403. (self.ann_info['num_joints'], 3), dtype=np.float)
  404. for ipt in range(self.ann_info['num_joints']):
  405. joints[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
  406. joints[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
  407. joints[ipt, 2] = 0
  408. t_vis = obj['keypoints'][ipt * 3 + 2]
  409. if t_vis > 1:
  410. t_vis = 1
  411. joints_vis[ipt, 0] = t_vis
  412. joints_vis[ipt, 1] = t_vis
  413. joints_vis[ipt, 2] = 0
  414. center, scale = self._box2cs(obj['clean_bbox'][:4])
  415. rec.append({
  416. 'image_file': os.path.join(self.img_prefix, file_name),
  417. 'center': center,
  418. 'scale': scale,
  419. 'joints': joints,
  420. 'joints_vis': joints_vis,
  421. 'im_id': im_id,
  422. })
  423. gt_db.extend(rec)
  424. return gt_db
  425. def _box2cs(self, box):
  426. x, y, w, h = box[:4]
  427. center = np.zeros((2), dtype=np.float32)
  428. center[0] = x + w * 0.5
  429. center[1] = y + h * 0.5
  430. aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]
  431. if w > aspect_ratio * h:
  432. h = w * 1.0 / aspect_ratio
  433. elif w < aspect_ratio * h:
  434. w = h * aspect_ratio
  435. scale = np.array(
  436. [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
  437. dtype=np.float32)
  438. if center[0] != -1:
  439. scale = scale * 1.25
  440. return center, scale
  441. def _load_coco_person_detection_results(self):
  442. all_boxes = None
  443. bbox_file_path = os.path.join(self.dataset_dir, self.bbox_file)
  444. with open(bbox_file_path, 'r') as f:
  445. all_boxes = json.load(f)
  446. if not all_boxes:
  447. print('=> Load %s fail!' % bbox_file_path)
  448. return None
  449. kpt_db = []
  450. for n_img in range(0, len(all_boxes)):
  451. det_res = all_boxes[n_img]
  452. if det_res['category_id'] != 1:
  453. continue
  454. file_name = det_res[
  455. 'filename'] if 'filename' in det_res else '%012d.jpg' % det_res[
  456. 'image_id']
  457. img_name = os.path.join(self.img_prefix, file_name)
  458. box = det_res['bbox']
  459. score = det_res['score']
  460. im_id = int(det_res['image_id'])
  461. if score < self.image_thre:
  462. continue
  463. center, scale = self._box2cs(box)
  464. joints = np.zeros((self.ann_info['num_joints'], 3), dtype=np.float)
  465. joints_vis = np.ones(
  466. (self.ann_info['num_joints'], 3), dtype=np.float)
  467. kpt_db.append({
  468. 'image_file': img_name,
  469. 'im_id': im_id,
  470. 'center': center,
  471. 'scale': scale,
  472. 'score': score,
  473. 'joints': joints,
  474. 'joints_vis': joints_vis,
  475. })
  476. return kpt_db
  477. @register
  478. @serializable
  479. class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
  480. """MPII dataset for topdown pose estimation.
  481. The dataset loads raw features and apply specified transforms
  482. to return a dict containing the image tensors and other information.
  483. MPII keypoint indexes::
  484. 0: 'right_ankle',
  485. 1: 'right_knee',
  486. 2: 'right_hip',
  487. 3: 'left_hip',
  488. 4: 'left_knee',
  489. 5: 'left_ankle',
  490. 6: 'pelvis',
  491. 7: 'thorax',
  492. 8: 'upper_neck',
  493. 9: 'head_top',
  494. 10: 'right_wrist',
  495. 11: 'right_elbow',
  496. 12: 'right_shoulder',
  497. 13: 'left_shoulder',
  498. 14: 'left_elbow',
  499. 15: 'left_wrist',
  500. Args:
  501. dataset_dir (str): Root path to the dataset.
  502. image_dir (str): Path to a directory where images are held.
  503. anno_path (str): Relative path to the annotation file.
  504. num_joints (int): Keypoint numbers
  505. trainsize (list):[w, h] Image target size
  506. transform (composed(operators)): A sequence of data transforms.
  507. """
  508. def __init__(self,
  509. dataset_dir,
  510. image_dir,
  511. anno_path,
  512. num_joints,
  513. transform=[]):
  514. super().__init__(dataset_dir, image_dir, anno_path, num_joints,
  515. transform)
  516. self.dataset_name = 'mpii'
  517. def parse_dataset(self):
  518. with open(self.get_anno()) as anno_file:
  519. anno = json.load(anno_file)
  520. gt_db = []
  521. for a in anno:
  522. image_name = a['image']
  523. im_id = a['image_id'] if 'image_id' in a else int(
  524. os.path.splitext(image_name)[0])
  525. c = np.array(a['center'], dtype=np.float)
  526. s = np.array([a['scale'], a['scale']], dtype=np.float)
  527. # Adjust center/scale slightly to avoid cropping limbs
  528. if c[0] != -1:
  529. c[1] = c[1] + 15 * s[1]
  530. s = s * 1.25
  531. c = c - 1
  532. joints = np.zeros((self.ann_info['num_joints'], 3), dtype=np.float)
  533. joints_vis = np.zeros(
  534. (self.ann_info['num_joints'], 3), dtype=np.float)
  535. if 'joints' in a:
  536. joints_ = np.array(a['joints'])
  537. joints_[:, 0:2] = joints_[:, 0:2] - 1
  538. joints_vis_ = np.array(a['joints_vis'])
  539. assert len(joints_) == self.ann_info[
  540. 'num_joints'], 'joint num diff: {} vs {}'.format(
  541. len(joints_), self.ann_info['num_joints'])
  542. joints[:, 0:2] = joints_[:, 0:2]
  543. joints_vis[:, 0] = joints_vis_[:]
  544. joints_vis[:, 1] = joints_vis_[:]
  545. gt_db.append({
  546. 'image_file': os.path.join(self.img_prefix, image_name),
  547. 'im_id': im_id,
  548. 'center': c,
  549. 'scale': s,
  550. 'joints': joints,
  551. 'joints_vis': joints_vis
  552. })
  553. self.db = gt_db