x2imagenet.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import os
  16. import os.path as osp
  17. import shutil
  18. from paddlex.utils import is_pic, get_encoding
  19. class X2ImageNet(object):
  20. def __init__(self):
  21. pass
  22. def convert(self, image_dir, json_dir, dataset_save_dir):
  23. """转换。
  24. Args:
  25. image_dir (str): 图像文件存放的路径。
  26. json_dir (str): 与每张图像对应的json文件的存放路径。
  27. dataset_save_dir (str): 转换后数据集存放路径。
  28. """
  29. assert osp.exists(image_dir), "The image folder does not exist!"
  30. assert osp.exists(json_dir), "The json folder does not exist!"
  31. if not osp.exists(dataset_save_dir):
  32. os.makedirs(dataset_save_dir)
  33. assert len(os.listdir(
  34. dataset_save_dir)) == 0, "The save folder must be empty!"
  35. for img_name in os.listdir(image_dir):
  36. img_name_part = osp.splitext(img_name)[0]
  37. json_file = osp.join(json_dir, img_name_part + ".json")
  38. if not osp.exists(json_file):
  39. continue
  40. with open(
  41. json_file, mode="r",
  42. encoding=get_encoding(json_file)) as j:
  43. json_info = self.get_json_info(j)
  44. for output in json_info:
  45. cls_name = output['name']
  46. new_image_dir = osp.join(dataset_save_dir, cls_name)
  47. if not osp.exists(new_image_dir):
  48. os.makedirs(new_image_dir)
  49. if is_pic(img_name):
  50. shutil.copyfile(
  51. osp.join(image_dir, img_name),
  52. osp.join(new_image_dir, img_name))
  53. class EasyData2ImageNet(X2ImageNet):
  54. """将使用EasyData标注的分类数据集转换为ImageNet数据集。
  55. """
  56. def __init__(self):
  57. super(EasyData2ImageNet, self).__init__()
  58. def get_json_info(self, json_file):
  59. json_info = json.load(json_file)
  60. json_info = json_info['labels']
  61. return json_info
  62. class JingLing2ImageNet(X2ImageNet):
  63. """将使用标注精灵标注的分类数据集转换为ImageNet数据集。
  64. """
  65. def __init__(self):
  66. super(X2ImageNet, self).__init__()
  67. def get_json_info(self, json_file):
  68. json_info = json.load(json_file)
  69. json_info = json_info['outputs']['object']
  70. return json_info