keypoint_coco.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import json

import cv2
import numpy as np
import pycocotools
import pycocotools.mask  # frPyObjects / decode live in this submodule
from pycocotools.coco import COCO

from .dataset import DetDataset
from paddlex.ppdet.core.workspace import register, serializable

@serializable
class KeypointBottomUpBaseDataset(DetDataset):
    """Base class for bottom-up datasets. Adapted from
    https://github.com/open-mmlab/mmpose

    All datasets should subclass it.
    All subclasses should overwrite:
        Methods: `_get_imganno`

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, worldsize], the distributed env params.
        test_mode (bool): Set True when building a test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path)
        self.image_info = {}
        self.ann_info = {}

        self.img_prefix = os.path.join(dataset_dir, image_dir)
        self.transform = transform
        self.test_mode = test_mode

        self.ann_info['num_joints'] = num_joints
        self.img_ids = []

    def parse_dataset(self):
        pass

    def __len__(self):
        """Get dataset length."""
        return len(self.img_ids)

    def _get_imganno(self, idx):
        """Get anno for a single image."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Prepare image for training given the index."""
        records = copy.deepcopy(self._get_imganno(idx))
        records['image'] = cv2.imread(records['image_file'])
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        records['mask'] = (records['mask'] + 0).astype('uint8')
        records = self.transform(records)
        return records
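
# __getitem__ above expects each record returned by a subclass's _get_imganno
# to provide at least 'image_file' and 'mask'; the COCO subclass below
# additionally fills 'im_id', 'joints' and 'im_shape'.
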
@register
@serializable
class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
    """COCO dataset for bottom-up pose estimation. Adapted from
    https://github.com/open-mmlab/mmpose

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    COCO keypoint indexes::

        0: 'nose',
        1: 'left_eye',
        2: 'right_eye',
        3: 'left_ear',
        4: 'right_ear',
        5: 'left_shoulder',
        6: 'right_shoulder',
        7: 'left_elbow',
        8: 'right_elbow',
        9: 'left_wrist',
        10: 'right_wrist',
        11: 'left_hip',
        12: 'right_hip',
        13: 'left_knee',
        14: 'right_knee',
        15: 'left_ankle',
        16: 'right_ankle'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, worldsize], the distributed env params.
        test_mode (bool): Set True when building a test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform, shard, test_mode)

        self.ann_file = os.path.join(dataset_dir, anno_path)
        self.shard = shard
        self.test_mode = test_mode
    def parse_dataset(self):
        self.coco = COCO(self.ann_file)
        self.img_ids = self.coco.getImgIds()
        if not self.test_mode:
            # For training, keep only images that carry at least one annotation.
            self.img_ids = [
                img_id for img_id in self.img_ids
                if len(self.coco.getAnnIds(
                    imgIds=img_id, iscrowd=None)) > 0
            ]
        # Split the image ids into `worldsize` contiguous blocks and keep the
        # block that belongs to this rank.
        blocknum = int(len(self.img_ids) / self.shard[1])
        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
            self.shard[0] + 1))]
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
        self.dataset_name = 'coco'

        cat_ids = self.coco.getCatIds()
        self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        print('=> num_images: {}'.format(self.num_images))
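
    # Sharding sketch (illustrative numbers, not from this file): with 1000
    # retained image ids and shard=[1, 4], blocknum is 250 and this rank keeps
    # img_ids[250:500]. Any remainder left by the integer division is dropped
    # on every rank.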
    @staticmethod
    def _get_mapping_id_name(imgs):
        """
        Args:
            imgs (dict): dict of image info.

        Returns:
            tuple: Image name & id mapping dicts.

                - id2name (dict): Mapping image id to name.
                - name2id (dict): Mapping image name to id.
        """
        id2name = {}
        name2id = {}
        for image_id, image in imgs.items():
            file_name = image['file_name']
            id2name[image_id] = file_name
            name2id[file_name] = image_id

        return id2name, name2id
    def _get_imganno(self, idx):
        """Get anno for a single image.

        Args:
            idx (int): image idx

        Returns:
            dict: info for model training
        """
        coco = self.coco
        img_id = self.img_ids[idx]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anno = coco.loadAnns(ann_ids)

        mask = self._get_mask(anno, idx)
        anno = [
            obj for obj in anno
            if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0
        ]

        joints, orgsize = self._get_joints(anno, idx)

        db_rec = {}
        db_rec['im_id'] = img_id
        db_rec['image_file'] = os.path.join(self.img_prefix,
                                            self.id2name[img_id])
        db_rec['mask'] = mask
        db_rec['joints'] = joints
        db_rec['im_shape'] = orgsize

        return db_rec

    def _get_joints(self, anno, idx):
        """Get joints for all people in an image."""
        num_people = len(anno)

        joints = np.zeros(
            (num_people, self.ann_info['num_joints'], 3), dtype=np.float32)

        for i, obj in enumerate(anno):
            joints[i, :self.ann_info['num_joints'], :3] = \
                np.array(obj['keypoints']).reshape([-1, 3])

        img_info = self.coco.loadImgs(self.img_ids[idx])[0]
        # Normalize x/y to [0, 1] by image width/height; the original
        # (height, width) is returned alongside for later transforms.
        joints[..., 0] /= img_info['width']
        joints[..., 1] /= img_info['height']
        orgsize = np.array([img_info['height'], img_info['width']])

        return joints, orgsize
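
    # Example of the normalization above (hypothetical numbers): a keypoint at
    # pixel (320, 240) in a 640x480 image is stored as (0.5, 0.5); the third
    # channel, the COCO visibility flag, is left untouched.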
    def _get_mask(self, anno, idx):
        """Get ignore masks to mask out losses."""
        coco = self.coco
        img_info = coco.loadImgs(self.img_ids[idx])[0]

        m = np.zeros((img_info['height'], img_info['width']), dtype=np.float32)

        for obj in anno:
            if 'segmentation' in obj:
                if obj['iscrowd']:
                    rle = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                       img_info['height'],
                                                       img_info['width'])
                    m += pycocotools.mask.decode(rle)
                elif obj['num_keypoints'] == 0:
                    rles = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                        img_info['height'],
                                                        img_info['width'])
                    for rle in rles:
                        m += pycocotools.mask.decode(rle)

        # True where a pixel should contribute to the loss: crowd regions and
        # people without labeled keypoints are masked out.
        return m < 0.5
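
# A minimal usage sketch (the paths below are illustrative assumptions, not
# part of this file; `transform` must be a composed, callable pipeline since
# __getitem__ invokes self.transform(records)):
#
#   dataset = KeypointBottomUpCocoDataset(
#       dataset_dir='dataset/coco',
#       image_dir='val2017',
#       anno_path='annotations/person_keypoints_val2017.json',
#       num_joints=17,
#       transform=my_compose)  # hypothetical composed operator pipeline
#   dataset.parse_dataset()
#   rec = dataset[0]  # dict with 'image', 'mask', 'joints', 'im_shape', 'im_id'
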
@register
@serializable
class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
    """CrowdPose dataset for bottom-up pose estimation. Adapted from
    https://github.com/open-mmlab/mmpose

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    CrowdPose keypoint indexes::

        0: 'left_shoulder',
        1: 'right_shoulder',
        2: 'left_elbow',
        3: 'right_elbow',
        4: 'left_wrist',
        5: 'right_wrist',
        6: 'left_hip',
        7: 'right_hip',
        8: 'left_knee',
        9: 'right_knee',
        10: 'left_ankle',
        11: 'right_ankle',
        12: 'top_head',
        13: 'neck'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, worldsize], the distributed env params.
        test_mode (bool): Set True when building a test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform, shard, test_mode)

        self.ann_file = os.path.join(dataset_dir, anno_path)
        self.shard = shard
        self.test_mode = test_mode

    def parse_dataset(self):
        self.coco = COCO(self.ann_file)
        self.img_ids = self.coco.getImgIds()
        if not self.test_mode:
            self.img_ids = [
                img_id for img_id in self.img_ids
                if len(self.coco.getAnnIds(
                    imgIds=img_id, iscrowd=None)) > 0
            ]
        blocknum = int(len(self.img_ids) / self.shard[1])
        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
            self.shard[0] + 1))]
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
        self.dataset_name = 'crowdpose'
        print('=> num_images: {}'.format(self.num_images))

@serializable
class KeypointTopDownBaseDataset(DetDataset):
    """Base class for top-down datasets.

    All datasets should subclass it.
    All subclasses should overwrite:
        Methods: `_get_db`

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[]):
        super().__init__(dataset_dir, image_dir, anno_path)
        self.image_info = {}
        self.ann_info = {}

        self.img_prefix = os.path.join(dataset_dir, image_dir)
        self.transform = transform

        self.ann_info['num_joints'] = num_joints
        self.db = []

    def __len__(self):
        """Get dataset length."""
        return len(self.db)

    def _get_db(self):
        """Get a sample."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Prepare sample for training given the index."""
        records = copy.deepcopy(self.db[idx])
        records['image'] = cv2.imread(records['image_file'], cv2.IMREAD_COLOR |
                                      cv2.IMREAD_IGNORE_ORIENTATION)
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        records['score'] = records['score'] if 'score' in records else 1
        records = self.transform(records)
        return records

@register
@serializable
class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
    """COCO dataset for top-down pose estimation. Adapted from
    https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
    Copyright (c) Microsoft, under the MIT License.

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    COCO keypoint indexes::

        0: 'nose',
        1: 'left_eye',
        2: 'right_eye',
        3: 'left_ear',
        4: 'right_ear',
        5: 'left_shoulder',
        6: 'right_shoulder',
        7: 'left_elbow',
        8: 'right_elbow',
        9: 'left_wrist',
        10: 'right_wrist',
        11: 'left_hip',
        12: 'right_hip',
        13: 'left_knee',
        14: 'right_knee',
        15: 'left_ankle',
        16: 'right_ankle'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        trainsize (list): [w, h], the target image size for training.
        transform (composed(operators)): A sequence of data transforms.
        bbox_file (str): Path to a detection bbox file.
            Default: None.
        use_gt_bbox (bool): Whether to use ground-truth bboxes.
            Default: True.
        pixel_std (int): The pixel std used to normalize the scale.
            Default: 200.
        image_thre (float): The score threshold for filtering detection boxes.
            Default: 0.0.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 trainsize,
                 transform=[],
                 bbox_file=None,
                 use_gt_bbox=True,
                 pixel_std=200,
                 image_thre=0.0):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform)

        self.bbox_file = bbox_file
        self.use_gt_bbox = use_gt_bbox
        self.trainsize = trainsize
        self.pixel_std = pixel_std
        self.image_thre = image_thre
        self.dataset_name = 'coco'
    def parse_dataset(self):
        if self.use_gt_bbox:
            self.db = self._load_coco_keypoint_annotations()
        else:
            self.db = self._load_coco_person_detection_results()

    def _load_coco_keypoint_annotations(self):
        coco = COCO(self.get_anno())
        img_ids = coco.getImgIds()
        gt_db = []
        for index in img_ids:
            im_ann = coco.loadImgs(index)[0]
            width = im_ann['width']
            height = im_ann['height']
            file_name = im_ann['file_name']
            im_id = int(im_ann["id"])

            annIds = coco.getAnnIds(imgIds=index, iscrowd=False)
            objs = coco.loadAnns(annIds)

            # Clip boxes to the image and drop degenerate ones.
            valid_objs = []
            for obj in objs:
                x, y, w, h = obj['bbox']
                x1 = np.max((0, x))
                y1 = np.max((0, y))
                x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
                y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
                if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
                    obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
                    valid_objs.append(obj)
            objs = valid_objs

            rec = []
            for obj in objs:
                if max(obj['keypoints']) == 0:
                    continue

                joints = np.zeros(
                    (self.ann_info['num_joints'], 3), dtype=np.float32)
                joints_vis = np.zeros(
                    (self.ann_info['num_joints'], 3), dtype=np.float32)
                for ipt in range(self.ann_info['num_joints']):
                    joints[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
                    joints[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
                    joints[ipt, 2] = 0
                    t_vis = obj['keypoints'][ipt * 3 + 2]
                    if t_vis > 1:
                        t_vis = 1
                    joints_vis[ipt, 0] = t_vis
                    joints_vis[ipt, 1] = t_vis
                    joints_vis[ipt, 2] = 0

                center, scale = self._box2cs(obj['clean_bbox'][:4])
                rec.append({
                    'image_file': os.path.join(self.img_prefix, file_name),
                    'center': center,
                    'scale': scale,
                    'joints': joints,
                    'joints_vis': joints_vis,
                    'im_id': im_id,
                })
            gt_db.extend(rec)

        return gt_db
    def _box2cs(self, box):
        x, y, w, h = box[:4]
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]

        # Pad the box to the training aspect ratio, then express its size in
        # units of pixel_std, enlarged by 25% to include some context.
        if w > aspect_ratio * h:
            h = w * 1.0 / aspect_ratio
        elif w < aspect_ratio * h:
            w = h * aspect_ratio
        scale = np.array(
            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
            dtype=np.float32)
        if center[0] != -1:
            scale = scale * 1.25

        return center, scale
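
    # Worked example for _box2cs (hypothetical numbers): with
    # trainsize=[288, 384] and pixel_std=200, box [100, 50, 80, 160] gives
    # center (140, 130); aspect_ratio = 0.75, so w is widened to 120, and
    # scale = [120/200, 160/200] * 1.25 = [0.75, 1.0].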
    def _load_coco_person_detection_results(self):
        all_boxes = None
        bbox_file_path = os.path.join(self.dataset_dir, self.bbox_file)
        with open(bbox_file_path, 'r') as f:
            all_boxes = json.load(f)

        if not all_boxes:
            print('=> Failed to load %s!' % bbox_file_path)
            return None

        kpt_db = []
        for n_img in range(0, len(all_boxes)):
            det_res = all_boxes[n_img]
            # Keep only person detections (COCO category id 1) above the
            # score threshold.
            if det_res['category_id'] != 1:
                continue
            file_name = det_res[
                'filename'] if 'filename' in det_res else '%012d.jpg' % det_res[
                    'image_id']
            img_name = os.path.join(self.img_prefix, file_name)
            box = det_res['bbox']
            score = det_res['score']
            im_id = int(det_res['image_id'])

            if score < self.image_thre:
                continue

            center, scale = self._box2cs(box)
            joints = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            joints_vis = np.ones(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            kpt_db.append({
                'image_file': img_name,
                'im_id': im_id,
                'center': center,
                'scale': scale,
                'score': score,
                'joints': joints,
                'joints_vis': joints_vis,
            })

        return kpt_db
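
# Expected bbox_file format (an assumption based on the fields read above):
# a JSON list of detections such as
#   [{"image_id": 397133, "category_id": 1,
#     "bbox": [x, y, w, h], "score": 0.99}, ...]
# with an optional "filename" key overriding the '%012d.jpg' naming.
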
@register
@serializable
class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
    """MPII dataset for top-down pose estimation. Adapted from
    https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
    Copyright (c) Microsoft, under the MIT License.

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    MPII keypoint indexes::

        0: 'right_ankle',
        1: 'right_knee',
        2: 'right_hip',
        3: 'left_hip',
        4: 'left_knee',
        5: 'left_ankle',
        6: 'pelvis',
        7: 'thorax',
        8: 'upper_neck',
        9: 'head_top',
        10: 'right_wrist',
        11: 'right_elbow',
        12: 'right_shoulder',
        13: 'left_shoulder',
        14: 'left_elbow',
        15: 'left_wrist'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[]):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform)
        self.dataset_name = 'mpii'

    def parse_dataset(self):
        with open(self.get_anno()) as anno_file:
            anno = json.load(anno_file)

        gt_db = []
        for a in anno:
            image_name = a['image']
            im_id = a['image_id'] if 'image_id' in a else int(
                os.path.splitext(image_name)[0])

            c = np.array(a['center'], dtype=np.float32)
            s = np.array([a['scale'], a['scale']], dtype=np.float32)

            # Adjust center/scale slightly to avoid cropping limbs.
            if c[0] != -1:
                c[1] = c[1] + 15 * s[1]
                s = s * 1.25

            # MPII annotations are 1-based; convert to 0-based indexing.
            c = c - 1

            joints = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            joints_vis = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            if 'joints' in a:
                joints_ = np.array(a['joints'])
                joints_[:, 0:2] = joints_[:, 0:2] - 1
                joints_vis_ = np.array(a['joints_vis'])
                assert len(joints_) == self.ann_info[
                    'num_joints'], 'joint num diff: {} vs {}'.format(
                        len(joints_), self.ann_info['num_joints'])

                joints[:, 0:2] = joints_[:, 0:2]
                joints_vis[:, 0] = joints_vis_[:]
                joints_vis[:, 1] = joints_vis_[:]

            gt_db.append({
                'image_file': os.path.join(self.img_prefix, image_name),
                'im_id': im_id,
                'center': c,
                'scale': s,
                'joints': joints,
                'joints_vis': joints_vis
            })
        print("number of samples: {}".format(len(gt_db)))
        self.db = gt_db
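
# Sketch of one entry in the MPII JSON annotation file consumed above (field
# names are the ones read by parse_dataset; the values are illustrative):
#   {"image": "005808361.jpg", "center": [594.0, 257.0], "scale": 3.8,
#    "joints": [[x, y], ...16 pairs...], "joints_vis": [1, 1, 0, ...]}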