keypoint_coco.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import numpy as np
  17. import json
  18. import copy
  19. import pycocotools
  20. from pycocotools.coco import COCO
  21. from .dataset import DetDataset
  22. from paddlex.ppdet.core.workspace import register, serializable
  23. @serializable
  24. class KeypointBottomUpBaseDataset(DetDataset):
  25. """Base class for bottom-up datasets. Adapted from
  26. https://github.com/open-mmlab/mmpose
  27. All datasets should subclass it.
  28. All subclasses should overwrite:
  29. Methods:`_get_imganno`
  30. Args:
  31. dataset_dir (str): Root path to the dataset.
  32. anno_path (str): Relative path to the annotation file.
  33. image_dir (str): Path to a directory where images are held.
  34. Default: None.
  35. num_joints (int): keypoint numbers
  36. transform (composed(operators)): A sequence of data transforms.
  37. shard (list): [rank, worldsize], the distributed env params
  38. test_mode (bool): Store True when building test or
  39. validation dataset. Default: False.
  40. """
  41. def __init__(self,
  42. dataset_dir,
  43. image_dir,
  44. anno_path,
  45. num_joints,
  46. transform=[],
  47. shard=[0, 1],
  48. test_mode=False):
  49. super().__init__(dataset_dir, image_dir, anno_path)
  50. self.image_info = {}
  51. self.ann_info = {}
  52. self.img_prefix = os.path.join(dataset_dir, image_dir)
  53. self.transform = transform
  54. self.test_mode = test_mode
  55. self.ann_info['num_joints'] = num_joints
  56. self.img_ids = []
  57. def __len__(self):
  58. """Get dataset length."""
  59. return len(self.img_ids)
  60. def _get_imganno(self, idx):
  61. """Get anno for a single image."""
  62. raise NotImplementedError
  63. def __getitem__(self, idx):
  64. """Prepare image for training given the index."""
  65. records = copy.deepcopy(self._get_imganno(idx))
  66. records['image'] = cv2.imread(records['image_file'])
  67. records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
  68. records['mask'] = (records['mask'] + 0).astype('uint8')
  69. records = self.transform(records)
  70. return records
  71. def parse_dataset(self):
  72. return
@register
@serializable
class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
    """COCO dataset for bottom-up pose estimation. Adapted from
    https://github.com/open-mmlab/mmpose

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    COCO keypoint indexes::

        0: 'nose',
        1: 'left_eye',
        2: 'right_eye',
        3: 'left_ear',
        4: 'right_ear',
        5: 'left_shoulder',
        6: 'right_shoulder',
        7: 'left_elbow',
        8: 'right_elbow',
        9: 'left_wrist',
        10: 'right_wrist',
        11: 'left_hip',
        12: 'right_hip',
        13: 'left_knee',
        14: 'right_knee',
        15: 'left_ankle',
        16: 'right_ankle'

    Args:
        dataset_dir (str): Root path to the dataset.
        anno_path (str): Relative path to the annotation file.
        image_dir (str): Path to a directory where images are held.
            Default: None.
        num_joints (int): keypoint numbers
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, worldsize], the distributed env params
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform, shard, test_mode)

        ann_file = os.path.join(dataset_dir, anno_path)
        self.coco = COCO(ann_file)

        self.img_ids = self.coco.getImgIds()
        # For training, drop images that have no annotations at all.
        if not test_mode:
            self.img_ids = [
                img_id for img_id in self.img_ids
                if len(self.coco.getAnnIds(
                    imgIds=img_id, iscrowd=None)) > 0
            ]
        # Split the image list evenly across distributed workers:
        # this rank keeps the contiguous slice [rank*block, (rank+1)*block).
        # Any remainder images (< worldsize) are dropped.
        blocknum = int(len(self.img_ids) / shard[1])
        self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
                                                                       + 1))]
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
        self.dataset_name = 'coco'

        cat_ids = self.coco.getCatIds()
        # Map COCO category ids to contiguous class indices.
        self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        print(f'=> num_images: {self.num_images}')

    @staticmethod
    def _get_mapping_id_name(imgs):
        """Build bidirectional image-id <-> file-name mappings.

        Args:
            imgs (dict): dict of image info, keyed by image id.

        Returns:
            tuple: Image name & id mapping dicts.

                - id2name (dict): Mapping image id to name.
                - name2id (dict): Mapping image name to id.
        """
        id2name = {}
        name2id = {}
        for image_id, image in imgs.items():
            file_name = image['file_name']
            id2name[image_id] = file_name
            name2id[file_name] = image_id

        return id2name, name2id

    def _get_imganno(self, idx):
        """Get anno for a single image.

        Args:
            idx (int): image idx

        Returns:
            dict: info for model training (im_id, image_file, mask,
            joints, im_shape)
        """
        coco = self.coco
        img_id = self.img_ids[idx]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anno = coco.loadAnns(ann_ids)

        # Mask is built from the FULL annotation list (crowd regions and
        # keypoint-less people are masked out of the loss) BEFORE filtering,
        # so the order of these two steps matters.
        mask = self._get_mask(anno, idx)
        anno = [
            obj for obj in anno
            if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0
        ]

        joints, orgsize = self._get_joints(anno, idx)

        db_rec = {}
        db_rec['im_id'] = img_id
        db_rec['image_file'] = os.path.join(self.img_prefix,
                                            self.id2name[img_id])
        db_rec['mask'] = mask
        db_rec['joints'] = joints
        db_rec['im_shape'] = orgsize

        return db_rec

    def _get_joints(self, anno, idx):
        """Get joints for all people in an image.

        Returns (num_people, num_joints, 3) float32 joints with x/y
        normalized to [0, 1] by image width/height, plus the original
        [height, width] size.
        """
        num_people = len(anno)
        joints = np.zeros(
            (num_people, self.ann_info['num_joints'], 3), dtype=np.float32)

        for i, obj in enumerate(anno):
            # COCO keypoints are a flat [x1, y1, v1, x2, y2, v2, ...] list.
            joints[i, :self.ann_info['num_joints'], :3] = \
                np.array(obj['keypoints']).reshape([-1, 3])

        img_info = self.coco.loadImgs(self.img_ids[idx])[0]
        joints[..., 0] /= img_info['width']
        joints[..., 1] /= img_info['height']
        orgsize = np.array([img_info['height'], img_info['width']])

        return joints, orgsize

    def _get_mask(self, anno, idx):
        """Get ignore masks to mask out losses.

        Returns a boolean (H, W) array that is True where the loss should
        be KEPT (pixel not covered by a crowd region or a keypoint-less
        person) — note the final `m < 0.5` inversion.
        """
        coco = self.coco
        img_info = coco.loadImgs(self.img_ids[idx])[0]

        m = np.zeros((img_info['height'], img_info['width']), dtype=np.float32)

        for obj in anno:
            if 'segmentation' in obj:
                if obj['iscrowd']:
                    # NOTE(review): `pycocotools.mask` is reachable only
                    # because pycocotools.coco imports it internally —
                    # fragile but working; confirm before refactoring imports.
                    rle = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                       img_info['height'],
                                                       img_info['width'])
                    m += pycocotools.mask.decode(rle)
                elif obj['num_keypoints'] == 0:
                    rles = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                        img_info['height'],
                                                        img_info['width'])
                    for rle in rles:
                        m += pycocotools.mask.decode(rle)

        return m < 0.5
  211. @register
  212. @serializable
  213. class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
  214. """CrowdPose dataset for bottom-up pose estimation. Adapted from
  215. https://github.com/open-mmlab/mmpose
  216. The dataset loads raw features and apply specified transforms
  217. to return a dict containing the image tensors and other information.
  218. CrowdPose keypoint indexes::
  219. 0: 'left_shoulder',
  220. 1: 'right_shoulder',
  221. 2: 'left_elbow',
  222. 3: 'right_elbow',
  223. 4: 'left_wrist',
  224. 5: 'right_wrist',
  225. 6: 'left_hip',
  226. 7: 'right_hip',
  227. 8: 'left_knee',
  228. 9: 'right_knee',
  229. 10: 'left_ankle',
  230. 11: 'right_ankle',
  231. 12: 'top_head',
  232. 13: 'neck'
  233. Args:
  234. dataset_dir (str): Root path to the dataset.
  235. anno_path (str): Relative path to the annotation file.
  236. image_dir (str): Path to a directory where images are held.
  237. Default: None.
  238. num_joints (int): keypoint numbers
  239. transform (composed(operators)): A sequence of data transforms.
  240. shard (list): [rank, worldsize], the distributed env params
  241. test_mode (bool): Store True when building test or
  242. validation dataset. Default: False.
  243. """
  244. def __init__(self,
  245. dataset_dir,
  246. image_dir,
  247. anno_path,
  248. num_joints,
  249. transform=[],
  250. shard=[0, 1],
  251. test_mode=False):
  252. super().__init__(dataset_dir, image_dir, anno_path, num_joints,
  253. transform, shard, test_mode)
  254. ann_file = os.path.join(dataset_dir, anno_path)
  255. self.coco = COCO(ann_file)
  256. self.img_ids = self.coco.getImgIds()
  257. if not test_mode:
  258. self.img_ids = [
  259. img_id for img_id in self.img_ids
  260. if len(self.coco.getAnnIds(
  261. imgIds=img_id, iscrowd=None)) > 0
  262. ]
  263. blocknum = int(len(self.img_ids) / shard[1])
  264. self.img_ids = self.img_ids[(blocknum * shard[0]):(blocknum * (shard[0]
  265. + 1))]
  266. self.num_images = len(self.img_ids)
  267. self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
  268. self.dataset_name = 'crowdpose'
  269. print('=> num_images: {}'.format(self.num_images))
  270. @serializable
  271. class KeypointTopDownBaseDataset(DetDataset):
  272. """Base class for top_down datasets.
  273. All datasets should subclass it.
  274. All subclasses should overwrite:
  275. Methods:`_get_db`
  276. Args:
  277. dataset_dir (str): Root path to the dataset.
  278. image_dir (str): Path to a directory where images are held.
  279. anno_path (str): Relative path to the annotation file.
  280. num_joints (int): keypoint numbers
  281. transform (composed(operators)): A sequence of data transforms.
  282. """
  283. def __init__(self,
  284. dataset_dir,
  285. image_dir,
  286. anno_path,
  287. num_joints,
  288. transform=[]):
  289. super().__init__(dataset_dir, image_dir, anno_path)
  290. self.image_info = {}
  291. self.ann_info = {}
  292. self.img_prefix = os.path.join(dataset_dir, image_dir)
  293. self.transform = transform
  294. self.ann_info['num_joints'] = num_joints
  295. self.db = []
  296. def __len__(self):
  297. """Get dataset length."""
  298. return len(self.db)
  299. def _get_db(self):
  300. """Get a sample"""
  301. raise NotImplementedError
  302. def __getitem__(self, idx):
  303. """Prepare sample for training given the index."""
  304. records = copy.deepcopy(self.db[idx])
  305. records['image'] = cv2.imread(records['image_file'], cv2.IMREAD_COLOR |
  306. cv2.IMREAD_IGNORE_ORIENTATION)
  307. records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
  308. records['score'] = records['score'] if 'score' in records else 1
  309. records = self.transform(records)
  310. # print('records', records)
  311. return records
  312. @register
  313. @serializable
  314. class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
  315. """COCO dataset for top-down pose estimation. Adapted from
  316. https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
  317. Copyright (c) Microsoft, under the MIT License.
  318. The dataset loads raw features and apply specified transforms
  319. to return a dict containing the image tensors and other information.
  320. COCO keypoint indexes:
  321. 0: 'nose',
  322. 1: 'left_eye',
  323. 2: 'right_eye',
  324. 3: 'left_ear',
  325. 4: 'right_ear',
  326. 5: 'left_shoulder',
  327. 6: 'right_shoulder',
  328. 7: 'left_elbow',
  329. 8: 'right_elbow',
  330. 9: 'left_wrist',
  331. 10: 'right_wrist',
  332. 11: 'left_hip',
  333. 12: 'right_hip',
  334. 13: 'left_knee',
  335. 14: 'right_knee',
  336. 15: 'left_ankle',
  337. 16: 'right_ankle'
  338. Args:
  339. dataset_dir (str): Root path to the dataset.
  340. image_dir (str): Path to a directory where images are held.
  341. anno_path (str): Relative path to the annotation file.
  342. num_joints (int): Keypoint numbers
  343. trainsize (list):[w, h] Image target size
  344. transform (composed(operators)): A sequence of data transforms.
  345. bbox_file (str): Path to a detection bbox file
  346. Default: None.
  347. use_gt_bbox (bool): Whether to use ground truth bbox
  348. Default: True.
  349. pixel_std (int): The pixel std of the scale
  350. Default: 200.
  351. image_thre (float): The threshold to filter the detection box
  352. Default: 0.0.
  353. """
  354. def __init__(self,
  355. dataset_dir,
  356. image_dir,
  357. anno_path,
  358. num_joints,
  359. trainsize,
  360. transform=[],
  361. bbox_file=None,
  362. use_gt_bbox=True,
  363. pixel_std=200,
  364. image_thre=0.0):
  365. super().__init__(dataset_dir, image_dir, anno_path, num_joints,
  366. transform)
  367. self.bbox_file = bbox_file
  368. self.use_gt_bbox = use_gt_bbox
  369. self.trainsize = trainsize
  370. self.pixel_std = pixel_std
  371. self.image_thre = image_thre
  372. self.dataset_name = 'coco'
  373. def parse_dataset(self):
  374. if self.use_gt_bbox:
  375. self.db = self._load_coco_keypoint_annotations()
  376. else:
  377. self.db = self._load_coco_person_detection_results()
  378. def _load_coco_keypoint_annotations(self):
  379. coco = COCO(self.get_anno())
  380. img_ids = coco.getImgIds()
  381. gt_db = []
  382. for index in img_ids:
  383. im_ann = coco.loadImgs(index)[0]
  384. width = im_ann['width']
  385. height = im_ann['height']
  386. file_name = im_ann['file_name']
  387. im_id = int(im_ann["id"])
  388. annIds = coco.getAnnIds(imgIds=index, iscrowd=False)
  389. objs = coco.loadAnns(annIds)
  390. valid_objs = []
  391. for obj in objs:
  392. x, y, w, h = obj['bbox']
  393. x1 = np.max((0, x))
  394. y1 = np.max((0, y))
  395. x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
  396. y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
  397. if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
  398. obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
  399. valid_objs.append(obj)
  400. objs = valid_objs
  401. rec = []
  402. for obj in objs:
  403. if max(obj['keypoints']) == 0:
  404. continue
  405. joints = np.zeros(
  406. (self.ann_info['num_joints'], 3), dtype=np.float)
  407. joints_vis = np.zeros(
  408. (self.ann_info['num_joints'], 3), dtype=np.float)
  409. for ipt in range(self.ann_info['num_joints']):
  410. joints[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
  411. joints[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
  412. joints[ipt, 2] = 0
  413. t_vis = obj['keypoints'][ipt * 3 + 2]
  414. if t_vis > 1:
  415. t_vis = 1
  416. joints_vis[ipt, 0] = t_vis
  417. joints_vis[ipt, 1] = t_vis
  418. joints_vis[ipt, 2] = 0
  419. center, scale = self._box2cs(obj['clean_bbox'][:4])
  420. rec.append({
  421. 'image_file': os.path.join(self.img_prefix, file_name),
  422. 'center': center,
  423. 'scale': scale,
  424. 'joints': joints,
  425. 'joints_vis': joints_vis,
  426. 'im_id': im_id,
  427. })
  428. gt_db.extend(rec)
  429. return gt_db
  430. def _box2cs(self, box):
  431. x, y, w, h = box[:4]
  432. center = np.zeros((2), dtype=np.float32)
  433. center[0] = x + w * 0.5
  434. center[1] = y + h * 0.5
  435. aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]
  436. if w > aspect_ratio * h:
  437. h = w * 1.0 / aspect_ratio
  438. elif w < aspect_ratio * h:
  439. w = h * aspect_ratio
  440. scale = np.array(
  441. [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
  442. dtype=np.float32)
  443. if center[0] != -1:
  444. scale = scale * 1.25
  445. return center, scale
  446. def _load_coco_person_detection_results(self):
  447. all_boxes = None
  448. bbox_file_path = os.path.join(self.dataset_dir, self.bbox_file)
  449. with open(bbox_file_path, 'r') as f:
  450. all_boxes = json.load(f)
  451. if not all_boxes:
  452. print('=> Load %s fail!' % bbox_file_path)
  453. return None
  454. kpt_db = []
  455. for n_img in range(0, len(all_boxes)):
  456. det_res = all_boxes[n_img]
  457. if det_res['category_id'] != 1:
  458. continue
  459. file_name = det_res[
  460. 'filename'] if 'filename' in det_res else '%012d.jpg' % det_res[
  461. 'image_id']
  462. img_name = os.path.join(self.img_prefix, file_name)
  463. box = det_res['bbox']
  464. score = det_res['score']
  465. im_id = int(det_res['image_id'])
  466. if score < self.image_thre:
  467. continue
  468. center, scale = self._box2cs(box)
  469. joints = np.zeros((self.ann_info['num_joints'], 3), dtype=np.float)
  470. joints_vis = np.ones(
  471. (self.ann_info['num_joints'], 3), dtype=np.float)
  472. kpt_db.append({
  473. 'image_file': img_name,
  474. 'im_id': im_id,
  475. 'center': center,
  476. 'scale': scale,
  477. 'score': score,
  478. 'joints': joints,
  479. 'joints_vis': joints_vis,
  480. })
  481. return kpt_db
  482. @register
  483. @serializable
  484. class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
  485. """MPII dataset for topdown pose estimation. Adapted from
  486. https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
  487. Copyright (c) Microsoft, under the MIT License.
  488. The dataset loads raw features and apply specified transforms
  489. to return a dict containing the image tensors and other information.
  490. MPII keypoint indexes::
  491. 0: 'right_ankle',
  492. 1: 'right_knee',
  493. 2: 'right_hip',
  494. 3: 'left_hip',
  495. 4: 'left_knee',
  496. 5: 'left_ankle',
  497. 6: 'pelvis',
  498. 7: 'thorax',
  499. 8: 'upper_neck',
  500. 9: 'head_top',
  501. 10: 'right_wrist',
  502. 11: 'right_elbow',
  503. 12: 'right_shoulder',
  504. 13: 'left_shoulder',
  505. 14: 'left_elbow',
  506. 15: 'left_wrist',
  507. Args:
  508. dataset_dir (str): Root path to the dataset.
  509. image_dir (str): Path to a directory where images are held.
  510. anno_path (str): Relative path to the annotation file.
  511. num_joints (int): Keypoint numbers
  512. trainsize (list):[w, h] Image target size
  513. transform (composed(operators)): A sequence of data transforms.
  514. """
  515. def __init__(self,
  516. dataset_dir,
  517. image_dir,
  518. anno_path,
  519. num_joints,
  520. transform=[]):
  521. super().__init__(dataset_dir, image_dir, anno_path, num_joints,
  522. transform)
  523. self.dataset_name = 'mpii'
  524. def parse_dataset(self):
  525. with open(self.get_anno()) as anno_file:
  526. anno = json.load(anno_file)
  527. gt_db = []
  528. for a in anno:
  529. image_name = a['image']
  530. im_id = a['image_id'] if 'image_id' in a else int(
  531. os.path.splitext(image_name)[0])
  532. c = np.array(a['center'], dtype=np.float)
  533. s = np.array([a['scale'], a['scale']], dtype=np.float)
  534. # Adjust center/scale slightly to avoid cropping limbs
  535. if c[0] != -1:
  536. c[1] = c[1] + 15 * s[1]
  537. s = s * 1.25
  538. c = c - 1
  539. joints = np.zeros((self.ann_info['num_joints'], 3), dtype=np.float)
  540. joints_vis = np.zeros(
  541. (self.ann_info['num_joints'], 3), dtype=np.float)
  542. if 'joints' in a:
  543. joints_ = np.array(a['joints'])
  544. joints_[:, 0:2] = joints_[:, 0:2] - 1
  545. joints_vis_ = np.array(a['joints_vis'])
  546. assert len(joints_) == self.ann_info[
  547. 'num_joints'], 'joint num diff: {} vs {}'.format(
  548. len(joints_), self.ann_info['num_joints'])
  549. joints[:, 0:2] = joints_[:, 0:2]
  550. joints_vis[:, 0] = joints_vis_[:]
  551. joints_vis[:, 1] = joints_vis_[:]
  552. gt_db.append({
  553. 'image_file': os.path.join(self.img_prefix, image_name),
  554. 'im_id': im_id,
  555. 'center': c,
  556. 'scale': s,
  557. 'joints': joints,
  558. 'joints_vis': joints_vis
  559. })
  560. print("number length: {}".format(len(gt_db)))
  561. self.db = gt_db