utils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import os.path as osp
  16. import shutil
  17. from enum import Enum
  18. import traceback
  19. import chardet
  20. from PIL import Image
  21. import numpy as np
  22. import json
  23. from ..utils import set_folder_status, get_folder_status, DatasetStatus
  24. def copy_directory(src, dst, files=None):
  25. """从src目录copy文件至dst目录,
  26. 注意:拷贝前会先清空dst中的所有文件
  27. Args:
  28. src: 源目录路径
  29. dst: 目标目录路径
  30. files: 需要拷贝的文件列表(src的相对路径)
  31. """
  32. set_folder_status(dst, DatasetStatus.XCOPYING, os.getpid())
  33. if files is None:
  34. files = list_files(src)
  35. try:
  36. message = '{} {}'.format(os.getpid(), len(files))
  37. set_folder_status(dst, DatasetStatus.XCOPYING, message)
  38. if not osp.samefile(src, dst):
  39. for i, f in enumerate(files):
  40. items = osp.split(f)
  41. if len(items) > 2:
  42. continue
  43. if len(items) == 2:
  44. if not osp.isdir(osp.join(dst, items[0])):
  45. if osp.exists(osp.join(dst, items[0])):
  46. os.remove(osp.join(dst, items[0]))
  47. os.makedirs(osp.join(dst, items[0]))
  48. shutil.copy(osp.join(src, f), osp.join(dst, f))
  49. set_folder_status(dst, DatasetStatus.XCOPYDONE)
  50. except Exception as e:
  51. error_info = traceback.format_exc()
  52. set_folder_status(dst, DatasetStatus.XCOPYFAIL, error_info)
  53. def is_pic(filename):
  54. """ 判断文件是否为图片格式
  55. Args:
  56. filename: 文件路径
  57. """
  58. suffixes = {'JPEG', 'jpeg', 'JPG', 'jpg', 'BMP', 'bmp', 'PNG', 'png'}
  59. suffix = filename.strip().split('.')[-1]
  60. if suffix not in suffixes:
  61. return False
  62. return True
  63. def replace_ext(filename, new_ext):
  64. """ 替换文件后缀
  65. Args:
  66. filename: 文件路径
  67. new_ext: 需要替换的新的后缀
  68. """
  69. items = filename.split(".")
  70. items[-1] = new_ext
  71. new_filename = ".".join(items)
  72. return new_filename
  73. def get_encoding(filename):
  74. """ 获取文件编码方式
  75. Args:
  76. filename: 文件路径
  77. """
  78. f = open(filename, 'rb')
  79. data = f.read()
  80. file_encoding = chardet.detect(data).get('encoding')
  81. return file_encoding
  82. def pil_imread(file_path):
  83. """ 获取分割标注图片信息
  84. Args:
  85. filename: 文件路径
  86. """
  87. im = Image.open(file_path)
  88. return np.asarray(im)
  89. def check_list_txt(list_txts):
  90. """ 检查切分信息文件的格式
  91. Args:
  92. list_txts: 包含切分信息文件路径的list
  93. """
  94. for list_txt in list_txts:
  95. if not osp.exists(list_txt):
  96. continue
  97. with open(list_txt) as f:
  98. for line in f:
  99. items = line.strip().split()
  100. if len(items) != 2:
  101. raise Exception('{} 格式错误. 列表应包含两列,由空格分离。'.format(list_txt))
  102. def read_seg_ann(pngfile):
  103. """ 解析语义分割的标注png图片
  104. Args:
  105. pngfile: 包含标注信息的png图片路径
  106. """
  107. grt = pil_imread(pngfile)
  108. labels = list(np.unique(grt))
  109. if 255 in labels:
  110. labels.remove(255)
  111. return labels, grt.shape
  112. def read_coco_ann(img_id, coco, cid2cname, catid2clsid):
  113. img_anno = coco.loadImgs(img_id)[0]
  114. im_w = float(img_anno['width'])
  115. im_h = float(img_anno['height'])
  116. ins_anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=0)
  117. instances = coco.loadAnns(ins_anno_ids)
  118. bboxes = []
  119. for inst in instances:
  120. x, y, box_w, box_h = inst['bbox']
  121. x1 = max(0, x)
  122. y1 = max(0, y)
  123. x2 = min(im_w - 1, x1 + max(0, box_w - 1))
  124. y2 = min(im_h - 1, y1 + max(0, box_h - 1))
  125. if inst['area'] > 0 and x2 >= x1 and y2 >= y1:
  126. inst['clean_bbox'] = [x1, y1, x2, y2]
  127. bboxes.append(inst)
  128. else:
  129. raise Exception("标注文件存在错误")
  130. num_bbox = len(bboxes)
  131. gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
  132. gt_class = [""] * num_bbox
  133. gt_poly = [None] * num_bbox
  134. for i, box in enumerate(bboxes):
  135. catid = box['category_id']
  136. gt_class[i] = cid2cname[catid2clsid[catid]]
  137. gt_bbox[i, :] = box['clean_bbox']
  138. # is_crowd[i][0] = box['iscrowd']
  139. if 'segmentation' in box:
  140. gt_poly[i] = box['segmentation']
  141. anno_dict = {
  142. 'h': im_h,
  143. 'w': im_w,
  144. 'gt_class': gt_class,
  145. 'gt_bbox': gt_bbox,
  146. 'gt_poly': gt_poly,
  147. }
  148. return anno_dict
  149. def get_npy_from_coco_json(coco, npy_path, files):
  150. """ 从实例分割标注的json文件中,获取每张图片的信息,并存为npy格式
  151. Args:
  152. coco: 从json文件中解析出的标注信息
  153. npy_path: npy文件保存的地址
  154. files: 需要生成npy文件的目录
  155. """
  156. img_ids = coco.getImgIds()
  157. cat_ids = coco.getCatIds()
  158. anno_ids = coco.getAnnIds()
  159. catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
  160. cid2cname = dict({
  161. clsid: coco.loadCats(catid)[0]['name']
  162. for catid, clsid in catid2clsid.items()
  163. })
  164. iname2id = dict()
  165. for img_id in img_ids:
  166. img_name = osp.split(coco.loadImgs(img_id)[0]["file_name"])[-1]
  167. iname2id[img_name] = img_id
  168. if not osp.exists(npy_path):
  169. os.makedirs(npy_path)
  170. for img in files:
  171. img_id = iname2id[osp.split(img)[-1]]
  172. anno_dict = read_coco_ann(img_id, coco, cid2cname, catid2clsid)
  173. img_name = osp.split(img)[-1]
  174. npy_name = replace_ext(img_name, "npy")
  175. np.save(osp.join(npy_path, npy_name), anno_dict)
  176. def get_label_count(label_info):
  177. """ 根据存储的label_info字段,计算label_count字段
  178. Args:
  179. label_info: 存储的label_info
  180. """
  181. label_count = dict()
  182. for key in sorted(label_info):
  183. label_count[key] = len(label_info[key])
  184. return label_count
  185. class MyEncoder(json.JSONEncoder):
  186. # 调整json文件存储形式
  187. def default(self, obj):
  188. if isinstance(obj, np.integer):
  189. return int(obj)
  190. elif isinstance(obj, np.floating):
  191. return float(obj)
  192. elif isinstance(obj, np.ndarray):
  193. return obj.tolist()
  194. else:
  195. return super(MyEncoder, self).default(obj)