det_transforms.py 72 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. try:
  15. from collections.abc import Sequence
  16. except Exception:
  17. from collections import Sequence
  18. import random
  19. import os.path as osp
  20. import numpy as np
  21. import cv2
  22. from PIL import Image, ImageEnhance
  23. from .imgaug_support import execute_imgaug
  24. from .ops import *
  25. from .box_utils import *
  26. import paddlex.utils.logging as logging
  27. class DetTransform:
  28. """检测数据处理基类
  29. """
  30. def __init__(self):
  31. pass
  32. class Compose(DetTransform):
  33. """根据数据预处理/增强列表对输入数据进行操作。
  34. 所有操作的输入图像流形状均是[H, W, C],其中H为图像高,W为图像宽,C为图像通道数。
  35. Args:
  36. transforms (list): 数据预处理/增强列表。
  37. Raises:
  38. TypeError: 形参数据类型不满足需求。
  39. ValueError: 数据长度不匹配。
  40. """
  41. def __init__(self, transforms):
  42. if not isinstance(transforms, list):
  43. raise TypeError('The transforms must be a list!')
  44. if len(transforms) < 1:
  45. raise ValueError('The length of transforms ' + \
  46. 'must be equal or larger than 1!')
  47. self.transforms = transforms
  48. self.batch_transforms = None
  49. self.use_mixup = False
  50. for t in self.transforms:
  51. if type(t).__name__ == 'MixupImage':
  52. self.use_mixup = True
  53. # 检查transforms里面的操作,目前支持PaddleX定义的或者是imgaug操作
  54. for op in self.transforms:
  55. if not isinstance(op, DetTransform):
  56. import imgaug.augmenters as iaa
  57. if not isinstance(op, iaa.Augmenter):
  58. raise Exception(
  59. "Elements in transforms should be defined in 'paddlex.det.transforms' or class of imgaug.augmenters.Augmenter, see docs here: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/"
  60. )
  61. def __call__(self, im, im_info=None, label_info=None):
  62. """
  63. Args:
  64. im (str/np.ndarray): 图像路径/图像np.ndarray数据。
  65. im_info (dict): 存储与图像相关的信息,dict中的字段如下:
  66. - im_id (np.ndarray): 图像序列号,形状为(1,)。
  67. - image_shape (np.ndarray): 图像原始大小,形状为(2,),
  68. image_shape[0]为高,image_shape[1]为宽。
  69. - mixup (list): list为[im, im_info, label_info],分别对应
  70. 与当前图像进行mixup的图像np.ndarray数据、图像相关信息、标注框相关信息;
  71. 注意,当前epoch若无需进行mixup,则无该字段。
  72. label_info (dict): 存储与标注框相关的信息,dict中的字段如下:
  73. - gt_bbox (np.ndarray): 真实标注框坐标[x1, y1, x2, y2],形状为(n, 4),
  74. 其中n代表真实标注框的个数。
  75. - gt_class (np.ndarray): 每个真实标注框对应的类别序号,形状为(n, 1),
  76. 其中n代表真实标注框的个数。
  77. - gt_score (np.ndarray): 每个真实标注框对应的混合得分,形状为(n, 1),
  78. 其中n代表真实标注框的个数。
  79. - gt_poly (list): 每个真实标注框内的多边形分割区域,每个分割区域由点的x、y坐标组成,
  80. 长度为n,其中n代表真实标注框的个数。
  81. - is_crowd (np.ndarray): 每个真实标注框中是否是一组对象,形状为(n, 1),
  82. 其中n代表真实标注框的个数。
  83. - difficult (np.ndarray): 每个真实标注框中的对象是否为难识别对象,形状为(n, 1),
  84. 其中n代表真实标注框的个数。
  85. Returns:
  86. tuple: 根据网络所需字段所组成的tuple;
  87. 字段由transforms中的最后一个数据预处理操作决定。
  88. """
  89. def decode_image(im_file, im_info, label_info):
  90. if im_info is None:
  91. im_info = dict()
  92. if isinstance(im_file, np.ndarray):
  93. if len(im_file.shape) != 3:
  94. raise Exception(
  95. "im should be 3-dimensions, but now is {}-dimensions".
  96. format(len(im_file.shape)))
  97. im = im_file
  98. else:
  99. try:
  100. im = cv2.imread(im_file).astype('float32')
  101. except:
  102. raise TypeError('Can\'t read The image file {}!'.format(
  103. im_file))
  104. im = im.astype('float32')
  105. im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
  106. # make default im_info with [h, w, 1]
  107. im_info['im_resize_info'] = np.array(
  108. [im.shape[0], im.shape[1], 1.], dtype=np.float32)
  109. im_info['image_shape'] = np.array([im.shape[0],
  110. im.shape[1]]).astype('int32')
  111. if not self.use_mixup:
  112. if 'mixup' in im_info:
  113. del im_info['mixup']
  114. # decode mixup image
  115. if 'mixup' in im_info:
  116. im_info['mixup'] = \
  117. decode_image(im_info['mixup'][0],
  118. im_info['mixup'][1],
  119. im_info['mixup'][2])
  120. if label_info is None:
  121. return (im, im_info)
  122. else:
  123. return (im, im_info, label_info)
  124. outputs = decode_image(im, im_info, label_info)
  125. im = outputs[0]
  126. im_info = outputs[1]
  127. if len(outputs) == 3:
  128. label_info = outputs[2]
  129. for op in self.transforms:
  130. if im is None:
  131. return None
  132. if isinstance(op, DetTransform):
  133. outputs = op(im, im_info, label_info)
  134. im = outputs[0]
  135. else:
  136. im = execute_imgaug(op, im)
  137. if label_info is not None:
  138. outputs = (im, im_info, label_info)
  139. else:
  140. outputs = (im, im_info)
  141. return outputs
  142. def add_augmenters(self, augmenters):
  143. if not isinstance(augmenters, list):
  144. raise Exception(
  145. "augmenters should be list type in func add_augmenters()")
  146. transform_names = [type(x).__name__ for x in self.transforms]
  147. for aug in augmenters:
  148. if type(aug).__name__ in transform_names:
  149. logging.error(
  150. "{} is already in ComposedTransforms, need to remove it from add_augmenters().".
  151. format(type(aug).__name__))
  152. self.transforms = augmenters + self.transforms
  153. class ResizeByShort(DetTransform):
  154. """根据图像的短边调整图像大小(resize)。
  155. 1. 获取图像的长边和短边长度。
  156. 2. 根据短边与short_size的比例,计算长边的目标长度,
  157. 此时高、宽的resize比例为short_size/原图短边长度。
  158. 3. 如果max_size>0,调整resize比例:
  159. 如果长边的目标长度>max_size,则高、宽的resize比例为max_size/原图长边长度。
  160. 4. 根据调整大小的比例对图像进行resize。
  161. Args:
  162. target_size (int): 短边目标长度。默认为800。
  163. max_size (int): 长边目标长度的最大限制。默认为1333。
  164. Raises:
  165. TypeError: 形参数据类型不满足需求。
  166. """
  167. def __init__(self, short_size=800, max_size=1333):
  168. self.max_size = int(max_size)
  169. if not isinstance(short_size, int):
  170. raise TypeError(
  171. "Type of short_size is invalid. Must be Integer, now is {}".
  172. format(type(short_size)))
  173. self.short_size = short_size
  174. if not (isinstance(self.max_size, int)):
  175. raise TypeError("max_size: input type is invalid.")
  176. def __call__(self, im, im_info=None, label_info=None):
  177. """
  178. Args:
  179. im (numnp.ndarraypy): 图像np.ndarray数据。
  180. im_info (dict, 可选): 存储与图像相关的信息。
  181. label_info (dict, 可选): 存储与标注框相关的信息。
  182. Returns:
  183. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  184. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  185. 存储与标注框相关信息的字典。
  186. 其中,im_info更新字段为:
  187. - im_resize_info (np.ndarray): resize后的图像高、resize后的图像宽、resize后的图像相对原始图的缩放比例
  188. 三者组成的np.ndarray,形状为(3,)。
  189. Raises:
  190. TypeError: 形参数据类型不满足需求。
  191. ValueError: 数据长度不匹配。
  192. """
  193. if im_info is None:
  194. im_info = dict()
  195. if not isinstance(im, np.ndarray):
  196. raise TypeError("ResizeByShort: image type is not numpy.")
  197. if len(im.shape) != 3:
  198. raise ValueError('ResizeByShort: image is not 3-dimensional.')
  199. im_short_size = min(im.shape[0], im.shape[1])
  200. im_long_size = max(im.shape[0], im.shape[1])
  201. scale = float(self.short_size) / im_short_size
  202. if self.max_size > 0 and np.round(scale *
  203. im_long_size) > self.max_size:
  204. scale = float(self.max_size) / float(im_long_size)
  205. resized_width = int(round(im.shape[1] * scale))
  206. resized_height = int(round(im.shape[0] * scale))
  207. im_resize_info = [resized_height, resized_width, scale]
  208. im = cv2.resize(
  209. im, (resized_width, resized_height),
  210. interpolation=cv2.INTER_LINEAR)
  211. im_info['im_resize_info'] = np.array(im_resize_info).astype(np.float32)
  212. if label_info is None:
  213. return (im, im_info)
  214. else:
  215. return (im, im_info, label_info)
  216. class Padding(DetTransform):
  217. """1.将图像的长和宽padding至coarsest_stride的倍数。如输入图像为[300, 640],
  218. `coarest_stride`为32,则由于300不为32的倍数,因此在图像最右和最下使用0值
  219. 进行padding,最终输出图像为[320, 640]。
  220. 2.或者,将图像的长和宽padding到target_size指定的shape,如输入的图像为[300,640],
  221. a. `target_size` = 960,在图像最右和最下使用0值进行padding,最终输出
  222. 图像为[960, 960]。
  223. b. `target_size` = [640, 960],在图像最右和最下使用0值进行padding,最终
  224. 输出图像为[640, 960]。
  225. 1. 如果coarsest_stride为1,target_size为None则直接返回。
  226. 2. 获取图像的高H、宽W。
  227. 3. 计算填充后图像的高H_new、宽W_new。
  228. 4. 构建大小为(H_new, W_new, 3)像素值为0的np.ndarray,
  229. 并将原图的np.ndarray粘贴于左上角。
  230. Args:
  231. coarsest_stride (int): 填充后的图像长、宽为该参数的倍数,默认为1。
  232. target_size (int|list|tuple): 填充后的图像长、宽,默认为None,coarset_stride优先级更高。
  233. Raises:
  234. TypeError: 形参`target_size`数据类型不满足需求。
  235. ValueError: 形参`target_size`为(list|tuple)时,长度不满足需求。
  236. """
  237. def __init__(self, coarsest_stride=1, target_size=None):
  238. self.coarsest_stride = coarsest_stride
  239. if target_size is not None:
  240. if not isinstance(target_size, int):
  241. if not isinstance(target_size, tuple) and not isinstance(
  242. target_size, list):
  243. raise TypeError(
  244. "Padding: Type of target_size must in (int|list|tuple)."
  245. )
  246. elif len(target_size) != 2:
  247. raise ValueError(
  248. "Padding: Length of target_size must equal 2.")
  249. self.target_size = target_size
  250. def __call__(self, im, im_info=None, label_info=None):
  251. """
  252. Args:
  253. im (numnp.ndarraypy): 图像np.ndarray数据。
  254. im_info (dict, 可选): 存储与图像相关的信息。
  255. label_info (dict, 可选): 存储与标注框相关的信息。
  256. Returns:
  257. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  258. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  259. 存储与标注框相关信息的字典。
  260. Raises:
  261. TypeError: 形参数据类型不满足需求。
  262. ValueError: 数据长度不匹配。
  263. ValueError: coarsest_stride,target_size需有且只有一个被指定。
  264. ValueError: target_size小于原图的大小。
  265. """
  266. if im_info is None:
  267. im_info = dict()
  268. if not isinstance(im, np.ndarray):
  269. raise TypeError("Padding: image type is not numpy.")
  270. if len(im.shape) != 3:
  271. raise ValueError('Padding: image is not 3-dimensional.')
  272. im_h, im_w, im_c = im.shape[:]
  273. if isinstance(self.target_size, int):
  274. padding_im_h = self.target_size
  275. padding_im_w = self.target_size
  276. elif isinstance(self.target_size, list) or isinstance(self.target_size,
  277. tuple):
  278. padding_im_w = self.target_size[0]
  279. padding_im_h = self.target_size[1]
  280. elif self.coarsest_stride > 0:
  281. padding_im_h = int(
  282. np.ceil(im_h / self.coarsest_stride) * self.coarsest_stride)
  283. padding_im_w = int(
  284. np.ceil(im_w / self.coarsest_stride) * self.coarsest_stride)
  285. else:
  286. raise ValueError(
  287. "coarsest_stridei(>1) or target_size(list|int) need setting in Padding transform"
  288. )
  289. pad_height = padding_im_h - im_h
  290. pad_width = padding_im_w - im_w
  291. if pad_height < 0 or pad_width < 0:
  292. raise ValueError(
  293. 'the size of image should be less than target_size, but the size of image ({}, {}), is larger than target_size ({}, {})'
  294. .format(im_w, im_h, padding_im_w, padding_im_h))
  295. padding_im = np.zeros(
  296. (padding_im_h, padding_im_w, im_c), dtype=np.float32)
  297. padding_im[:im_h, :im_w, :] = im
  298. if label_info is None:
  299. return (padding_im, im_info)
  300. else:
  301. return (padding_im, im_info, label_info)
  302. class Resize(DetTransform):
  303. """调整图像大小(resize)。
  304. - 当目标大小(target_size)类型为int时,根据插值方式,
  305. 将图像resize为[target_size, target_size]。
  306. - 当目标大小(target_size)类型为list或tuple时,根据插值方式,
  307. 将图像resize为target_size。
  308. 注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
  309. Args:
  310. target_size (int/list/tuple): 短边目标长度。默认为608。
  311. interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
  312. ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"LINEAR"。
  313. Raises:
  314. TypeError: 形参数据类型不满足需求。
  315. ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
  316. 'AREA', 'LANCZOS4', 'RANDOM']中。
  317. """
  318. # The interpolation mode
  319. interp_dict = {
  320. 'NEAREST': cv2.INTER_NEAREST,
  321. 'LINEAR': cv2.INTER_LINEAR,
  322. 'CUBIC': cv2.INTER_CUBIC,
  323. 'AREA': cv2.INTER_AREA,
  324. 'LANCZOS4': cv2.INTER_LANCZOS4
  325. }
  326. def __init__(self, target_size=608, interp='LINEAR'):
  327. self.interp = interp
  328. if not (interp == "RANDOM" or interp in self.interp_dict):
  329. raise ValueError("interp should be one of {}".format(
  330. self.interp_dict.keys()))
  331. if isinstance(target_size, list) or isinstance(target_size, tuple):
  332. if len(target_size) != 2:
  333. raise TypeError(
  334. 'when target is list or tuple, it should include 2 elements, but it is {}'
  335. .format(target_size))
  336. elif not isinstance(target_size, int):
  337. raise TypeError(
  338. "Type of target_size is invalid. Must be Integer or List or tuple, now is {}"
  339. .format(type(target_size)))
  340. self.target_size = target_size
  341. def __call__(self, im, im_info=None, label_info=None):
  342. """
  343. Args:
  344. im (np.ndarray): 图像np.ndarray数据。
  345. im_info (dict, 可选): 存储与图像相关的信息。
  346. label_info (dict, 可选): 存储与标注框相关的信息。
  347. Returns:
  348. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  349. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  350. 存储与标注框相关信息的字典。
  351. Raises:
  352. TypeError: 形参数据类型不满足需求。
  353. ValueError: 数据长度不匹配。
  354. """
  355. if im_info is None:
  356. im_info = dict()
  357. if not isinstance(im, np.ndarray):
  358. raise TypeError("Resize: image type is not numpy.")
  359. if len(im.shape) != 3:
  360. raise ValueError('Resize: image is not 3-dimensional.')
  361. if self.interp == "RANDOM":
  362. interp = random.choice(list(self.interp_dict.keys()))
  363. else:
  364. interp = self.interp
  365. im = resize(im, self.target_size, self.interp_dict[interp])
  366. if label_info is None:
  367. return (im, im_info)
  368. else:
  369. return (im, im_info, label_info)
  370. class RandomHorizontalFlip(DetTransform):
  371. """随机翻转图像、标注框、分割信息,模型训练时的数据增强操作。
  372. 1. 随机采样一个0-1之间的小数,当小数小于水平翻转概率时,
  373. 执行2-4步操作,否则直接返回。
  374. 2. 水平翻转图像。
  375. 3. 计算翻转后的真实标注框的坐标,更新label_info中的gt_bbox信息。
  376. 4. 计算翻转后的真实分割区域的坐标,更新label_info中的gt_poly信息。
  377. Args:
  378. prob (float): 随机水平翻转的概率。默认为0.5。
  379. Raises:
  380. TypeError: 形参数据类型不满足需求。
  381. """
  382. def __init__(self, prob=0.5):
  383. self.prob = prob
  384. if not isinstance(self.prob, float):
  385. raise TypeError("RandomHorizontalFlip: input type is invalid.")
  386. def __call__(self, im, im_info=None, label_info=None):
  387. """
  388. Args:
  389. im (np.ndarray): 图像np.ndarray数据。
  390. im_info (dict, 可选): 存储与图像相关的信息。
  391. label_info (dict, 可选): 存储与标注框相关的信息。
  392. Returns:
  393. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  394. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  395. 存储与标注框相关信息的字典。
  396. 其中,im_info更新字段为:
  397. - gt_bbox (np.ndarray): 水平翻转后的标注框坐标[x1, y1, x2, y2],形状为(n, 4),
  398. 其中n代表真实标注框的个数。
  399. - gt_poly (list): 水平翻转后的多边形分割区域的x、y坐标,长度为n,
  400. 其中n代表真实标注框的个数。
  401. Raises:
  402. TypeError: 形参数据类型不满足需求。
  403. ValueError: 数据长度不匹配。
  404. """
  405. if not isinstance(im, np.ndarray):
  406. raise TypeError(
  407. "RandomHorizontalFlip: image is not a numpy array.")
  408. if len(im.shape) != 3:
  409. raise ValueError(
  410. "RandomHorizontalFlip: image is not 3-dimensional.")
  411. if im_info is None or label_info is None:
  412. raise TypeError(
  413. 'Cannot do RandomHorizontalFlip! ' +
  414. 'Becasuse the im_info and label_info can not be None!')
  415. if 'gt_bbox' not in label_info:
  416. raise TypeError('Cannot do RandomHorizontalFlip! ' + \
  417. 'Becasuse gt_bbox is not in label_info!')
  418. image_shape = im_info['image_shape']
  419. gt_bbox = label_info['gt_bbox']
  420. height = image_shape[0]
  421. width = image_shape[1]
  422. if np.random.uniform(0, 1) < self.prob:
  423. im = horizontal_flip(im)
  424. if gt_bbox.shape[0] == 0:
  425. if label_info is None:
  426. return (im, im_info)
  427. else:
  428. return (im, im_info, label_info)
  429. label_info['gt_bbox'] = box_horizontal_flip(gt_bbox, width)
  430. if 'gt_poly' in label_info and \
  431. len(label_info['gt_poly']) != 0:
  432. label_info['gt_poly'] = segms_horizontal_flip(
  433. label_info['gt_poly'], height, width)
  434. if label_info is None:
  435. return (im, im_info)
  436. else:
  437. return (im, im_info, label_info)
  438. class Normalize(DetTransform):
  439. """对图像进行标准化。
  440. 1. 归一化图像到到区间[0.0, 1.0]。
  441. 2. 对图像进行减均值除以标准差操作。
  442. Args:
  443. mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
  444. std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
  445. Raises:
  446. TypeError: 形参数据类型不满足需求。
  447. """
  448. def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
  449. self.mean = mean
  450. self.std = std
  451. if not (isinstance(self.mean, list) and isinstance(self.std, list)):
  452. raise TypeError("NormalizeImage: input type is invalid.")
  453. from functools import reduce
  454. if reduce(lambda x, y: x * y, self.std) == 0:
  455. raise TypeError('NormalizeImage: std is invalid!')
  456. def __call__(self, im, im_info=None, label_info=None):
  457. """
  458. Args:
  459. im (numnp.ndarraypy): 图像np.ndarray数据。
  460. im_info (dict, 可选): 存储与图像相关的信息。
  461. label_info (dict, 可选): 存储与标注框相关的信息。
  462. Returns:
  463. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  464. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  465. 存储与标注框相关信息的字典。
  466. """
  467. mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
  468. std = np.array(self.std)[np.newaxis, np.newaxis, :]
  469. im = normalize(im, mean, std)
  470. if label_info is None:
  471. return (im, im_info)
  472. else:
  473. return (im, im_info, label_info)
  474. class RandomDistort(DetTransform):
  475. """以一定的概率对图像进行随机像素内容变换,模型训练时的数据增强操作
  476. 1. 对变换的操作顺序进行随机化操作。
  477. 2. 按照1中的顺序以一定的概率在范围[-range, range]对图像进行随机像素内容变换。
  478. Args:
  479. brightness_range (float): 明亮度因子的范围。默认为0.5。
  480. brightness_prob (float): 随机调整明亮度的概率。默认为0.5。
  481. contrast_range (float): 对比度因子的范围。默认为0.5。
  482. contrast_prob (float): 随机调整对比度的概率。默认为0.5。
  483. saturation_range (float): 饱和度因子的范围。默认为0.5。
  484. saturation_prob (float): 随机调整饱和度的概率。默认为0.5。
  485. hue_range (int): 色调因子的范围。默认为18。
  486. hue_prob (float): 随机调整色调的概率。默认为0.5。
  487. """
  488. def __init__(self,
  489. brightness_range=0.5,
  490. brightness_prob=0.5,
  491. contrast_range=0.5,
  492. contrast_prob=0.5,
  493. saturation_range=0.5,
  494. saturation_prob=0.5,
  495. hue_range=18,
  496. hue_prob=0.5):
  497. self.brightness_range = brightness_range
  498. self.brightness_prob = brightness_prob
  499. self.contrast_range = contrast_range
  500. self.contrast_prob = contrast_prob
  501. self.saturation_range = saturation_range
  502. self.saturation_prob = saturation_prob
  503. self.hue_range = hue_range
  504. self.hue_prob = hue_prob
  505. def __call__(self, im, im_info=None, label_info=None):
  506. """
  507. Args:
  508. im (np.ndarray): 图像np.ndarray数据。
  509. im_info (dict, 可选): 存储与图像相关的信息。
  510. label_info (dict, 可选): 存储与标注框相关的信息。
  511. Returns:
  512. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  513. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  514. 存储与标注框相关信息的字典。
  515. """
  516. brightness_lower = 1 - self.brightness_range
  517. brightness_upper = 1 + self.brightness_range
  518. contrast_lower = 1 - self.contrast_range
  519. contrast_upper = 1 + self.contrast_range
  520. saturation_lower = 1 - self.saturation_range
  521. saturation_upper = 1 + self.saturation_range
  522. hue_lower = -self.hue_range
  523. hue_upper = self.hue_range
  524. ops = [brightness, contrast, saturation, hue]
  525. random.shuffle(ops)
  526. params_dict = {
  527. 'brightness': {
  528. 'brightness_lower': brightness_lower,
  529. 'brightness_upper': brightness_upper
  530. },
  531. 'contrast': {
  532. 'contrast_lower': contrast_lower,
  533. 'contrast_upper': contrast_upper
  534. },
  535. 'saturation': {
  536. 'saturation_lower': saturation_lower,
  537. 'saturation_upper': saturation_upper
  538. },
  539. 'hue': {
  540. 'hue_lower': hue_lower,
  541. 'hue_upper': hue_upper
  542. }
  543. }
  544. prob_dict = {
  545. 'brightness': self.brightness_prob,
  546. 'contrast': self.contrast_prob,
  547. 'saturation': self.saturation_prob,
  548. 'hue': self.hue_prob
  549. }
  550. for id in range(4):
  551. params = params_dict[ops[id].__name__]
  552. prob = prob_dict[ops[id].__name__]
  553. params['im'] = im
  554. if np.random.uniform(0, 1) < prob:
  555. im = ops[id](**params)
  556. im = im.astype('float32')
  557. if label_info is None:
  558. return (im, im_info)
  559. else:
  560. return (im, im_info, label_info)
  561. class MixupImage(DetTransform):
  562. """对图像进行mixup操作,模型训练时的数据增强操作,目前仅YOLOv3模型支持该transform。
  563. 当label_info中不存在mixup字段时,直接返回,否则进行下述操作:
  564. 1. 从随机beta分布中抽取出随机因子factor。
  565. 2.
  566. - 当factor>=1.0时,去除label_info中的mixup字段,直接返回。
  567. - 当factor<=0.0时,直接返回label_info中的mixup字段,并在label_info中去除该字段。
  568. - 其余情况,执行下述操作:
  569. (1)原图像乘以factor,mixup图像乘以(1-factor),叠加2个结果。
  570. (2)拼接原图像标注框和mixup图像标注框。
  571. (3)拼接原图像标注框类别和mixup图像标注框类别。
  572. (4)原图像标注框混合得分乘以factor,mixup图像标注框混合得分乘以(1-factor),叠加2个结果。
  573. 3. 更新im_info中的image_shape信息。
  574. Args:
  575. alpha (float): 随机beta分布的下限。默认为1.5。
  576. beta (float): 随机beta分布的上限。默认为1.5。
  577. mixup_epoch (int): 在前mixup_epoch轮使用mixup增强操作;当该参数为-1时,该策略不会生效。
  578. 默认为-1。
  579. Raises:
  580. ValueError: 数据长度不匹配。
  581. """
  582. def __init__(self, alpha=1.5, beta=1.5, mixup_epoch=-1):
  583. self.alpha = alpha
  584. self.beta = beta
  585. if self.alpha <= 0.0:
  586. raise ValueError("alpha shold be positive in MixupImage")
  587. if self.beta <= 0.0:
  588. raise ValueError("beta shold be positive in MixupImage")
  589. self.mixup_epoch = mixup_epoch
  590. def _mixup_img(self, img1, img2, factor):
  591. h = max(img1.shape[0], img2.shape[0])
  592. w = max(img1.shape[1], img2.shape[1])
  593. img = np.zeros((h, w, img1.shape[2]), 'float32')
  594. img[:img1.shape[0], :img1.shape[1], :] = \
  595. img1.astype('float32') * factor
  596. img[:img2.shape[0], :img2.shape[1], :] += \
  597. img2.astype('float32') * (1.0 - factor)
  598. return img.astype('float32')
  599. def __call__(self, im, im_info=None, label_info=None):
  600. """
  601. Args:
  602. im (np.ndarray): 图像np.ndarray数据。
  603. im_info (dict, 可选): 存储与图像相关的信息。
  604. label_info (dict, 可选): 存储与标注框相关的信息。
  605. Returns:
  606. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  607. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  608. 存储与标注框相关信息的字典。
  609. 其中,im_info更新字段为:
  610. - image_shape (np.ndarray): mixup后的图像高、宽二者组成的np.ndarray,形状为(2,)。
  611. im_info删除的字段:
  612. - mixup (list): 与当前字段进行mixup的图像相关信息。
  613. label_info更新字段为:
  614. - gt_bbox (np.ndarray): mixup后真实标注框坐标,形状为(n, 4),
  615. 其中n代表真实标注框的个数。
  616. - gt_class (np.ndarray): mixup后每个真实标注框对应的类别序号,形状为(n, 1),
  617. 其中n代表真实标注框的个数。
  618. - gt_score (np.ndarray): mixup后每个真实标注框对应的混合得分,形状为(n, 1),
  619. 其中n代表真实标注框的个数。
  620. Raises:
  621. TypeError: 形参数据类型不满足需求。
  622. """
  623. if im_info is None:
  624. raise TypeError('Cannot do MixupImage! ' +
  625. 'Becasuse the im_info can not be None!')
  626. if 'mixup' not in im_info:
  627. if label_info is None:
  628. return (im, im_info)
  629. else:
  630. return (im, im_info, label_info)
  631. factor = np.random.beta(self.alpha, self.beta)
  632. factor = max(0.0, min(1.0, factor))
  633. if im_info['epoch'] > self.mixup_epoch \
  634. or factor >= 1.0:
  635. im_info.pop('mixup')
  636. if label_info is None:
  637. return (im, im_info)
  638. else:
  639. return (im, im_info, label_info)
  640. if factor <= 0.0:
  641. return im_info.pop('mixup')
  642. im = self._mixup_img(im, im_info['mixup'][0], factor)
  643. if label_info is None:
  644. raise TypeError('Cannot do MixupImage! ' +
  645. 'Becasuse the label_info can not be None!')
  646. if 'gt_bbox' not in label_info or \
  647. 'gt_class' not in label_info or \
  648. 'gt_score' not in label_info:
  649. raise TypeError('Cannot do MixupImage! ' + \
  650. 'Becasuse gt_bbox/gt_class/gt_score is not in label_info!')
  651. gt_bbox1 = label_info['gt_bbox']
  652. gt_bbox2 = im_info['mixup'][2]['gt_bbox']
  653. gt_class1 = label_info['gt_class']
  654. gt_class2 = im_info['mixup'][2]['gt_class']
  655. gt_score1 = label_info['gt_score']
  656. gt_score2 = im_info['mixup'][2]['gt_score']
  657. if 'gt_poly' in label_info:
  658. gt_poly1 = label_info['gt_poly']
  659. gt_poly2 = im_info['mixup'][2]['gt_poly']
  660. is_crowd1 = label_info['is_crowd']
  661. is_crowd2 = im_info['mixup'][2]['is_crowd']
  662. if 0 not in gt_class1 and 0 not in gt_class2:
  663. gt_bbox = np.concatenate((gt_bbox1, gt_bbox2), axis=0)
  664. gt_class = np.concatenate((gt_class1, gt_class2), axis=0)
  665. gt_score = np.concatenate(
  666. (gt_score1 * factor, gt_score2 * (1. - factor)), axis=0)
  667. if 'gt_poly' in label_info:
  668. label_info['gt_poly'] = gt_poly1 + gt_poly2
  669. is_crowd = np.concatenate((is_crowd1, is_crowd2), axis=0)
  670. elif 0 in gt_class1:
  671. gt_bbox = gt_bbox2
  672. gt_class = gt_class2
  673. gt_score = gt_score2 * (1. - factor)
  674. if 'gt_poly' in label_info:
  675. label_info['gt_poly'] = gt_poly2
  676. is_crowd = is_crowd2
  677. else:
  678. gt_bbox = gt_bbox1
  679. gt_class = gt_class1
  680. gt_score = gt_score1 * factor
  681. if 'gt_poly' in label_info:
  682. label_info['gt_poly'] = gt_poly1
  683. is_crowd = is_crowd1
  684. label_info['gt_bbox'] = gt_bbox
  685. label_info['gt_score'] = gt_score
  686. label_info['gt_class'] = gt_class
  687. label_info['is_crowd'] = is_crowd
  688. im_info['image_shape'] = np.array([im.shape[0],
  689. im.shape[1]]).astype('int32')
  690. im_info.pop('mixup')
  691. if label_info is None:
  692. return (im, im_info)
  693. else:
  694. return (im, im_info, label_info)
  695. class RandomExpand(DetTransform):
  696. """随机扩张图像,模型训练时的数据增强操作。
  697. 1. 随机选取扩张比例(扩张比例大于1时才进行扩张)。
  698. 2. 计算扩张后图像大小。
  699. 3. 初始化像素值为输入填充值的图像,并将原图像随机粘贴于该图像上。
  700. 4. 根据原图像粘贴位置换算出扩张后真实标注框的位置坐标。
  701. 5. 根据原图像粘贴位置换算出扩张后真实分割区域的位置坐标。
  702. Args:
  703. ratio (float): 图像扩张的最大比例。默认为4.0。
  704. prob (float): 随机扩张的概率。默认为0.5。
  705. fill_value (list): 扩张图像的初始填充值(0-255)。默认为[123.675, 116.28, 103.53]。
  706. """
  707. def __init__(self,
  708. ratio=4.,
  709. prob=0.5,
  710. fill_value=[123.675, 116.28, 103.53]):
  711. super(RandomExpand, self).__init__()
  712. assert ratio > 1.01, "expand ratio must be larger than 1.01"
  713. self.ratio = ratio
  714. self.prob = prob
  715. assert isinstance(fill_value, Sequence), \
  716. "fill value must be sequence"
  717. if not isinstance(fill_value, tuple):
  718. fill_value = tuple(fill_value)
  719. self.fill_value = fill_value
  720. def __call__(self, im, im_info=None, label_info=None):
  721. """
  722. Args:
  723. im (np.ndarray): 图像np.ndarray数据。
  724. im_info (dict, 可选): 存储与图像相关的信息。
  725. label_info (dict, 可选): 存储与标注框相关的信息。
  726. Returns:
  727. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  728. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  729. 存储与标注框相关信息的字典。
  730. 其中,im_info更新字段为:
  731. - image_shape (np.ndarray): 扩张后的图像高、宽二者组成的np.ndarray,形状为(2,)。
  732. label_info更新字段为:
  733. - gt_bbox (np.ndarray): 随机扩张后真实标注框坐标,形状为(n, 4),
  734. 其中n代表真实标注框的个数。
  735. - gt_class (np.ndarray): 随机扩张后每个真实标注框对应的类别序号,形状为(n, 1),
  736. 其中n代表真实标注框的个数。
  737. Raises:
  738. TypeError: 形参数据类型不满足需求。
  739. """
  740. if im_info is None or label_info is None:
  741. raise TypeError(
  742. 'Cannot do RandomExpand! ' +
  743. 'Becasuse the im_info and label_info can not be None!')
  744. if 'gt_bbox' not in label_info or \
  745. 'gt_class' not in label_info:
  746. raise TypeError('Cannot do RandomExpand! ' + \
  747. 'Becasuse gt_bbox/gt_class is not in label_info!')
  748. if np.random.uniform(0., 1.) > self.prob:
  749. return (im, im_info, label_info)
  750. if 'gt_class' in label_info and 0 in label_info['gt_class']:
  751. return (im, im_info, label_info)
  752. image_shape = im_info['image_shape']
  753. height = int(image_shape[0])
  754. width = int(image_shape[1])
  755. expand_ratio = np.random.uniform(1., self.ratio)
  756. h = int(height * expand_ratio)
  757. w = int(width * expand_ratio)
  758. if not h > height or not w > width:
  759. return (im, im_info, label_info)
  760. y = np.random.randint(0, h - height)
  761. x = np.random.randint(0, w - width)
  762. canvas = np.ones((h, w, 3), dtype=np.float32)
  763. canvas *= np.array(self.fill_value, dtype=np.float32)
  764. canvas[y:y + height, x:x + width, :] = im
  765. im_info['image_shape'] = np.array([h, w]).astype('int32')
  766. if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0:
  767. label_info['gt_bbox'] += np.array([x, y] * 2, dtype=np.float32)
  768. if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
  769. label_info['gt_poly'] = expand_segms(label_info['gt_poly'], x, y,
  770. height, width, expand_ratio)
  771. return (canvas, im_info, label_info)
  772. class RandomCrop(DetTransform):
  773. """随机裁剪图像。
  774. 1. 若allow_no_crop为True,则在thresholds加入’no_crop’。
  775. 2. 随机打乱thresholds。
  776. 3. 遍历thresholds中各元素:
  777. (1) 如果当前thresh为’no_crop’,则返回原始图像和标注信息。
  778. (2) 随机取出aspect_ratio和scaling中的值并由此计算出候选裁剪区域的高、宽、起始点。
  779. (3) 计算真实标注框与候选裁剪区域IoU,若全部真实标注框的IoU都小于thresh,则继续第3步。
  780. (4) 如果cover_all_box为True且存在真实标注框的IoU小于thresh,则继续第3步。
  781. (5) 筛选出位于候选裁剪区域内的真实标注框,若有效框的个数为0,则继续第3步,否则进行第4步。
  782. 4. 换算有效真值标注框相对候选裁剪区域的位置坐标。
  783. 5. 换算有效分割区域相对候选裁剪区域的位置坐标。
  784. Args:
  785. aspect_ratio (list): 裁剪后短边缩放比例的取值范围,以[min, max]形式表示。默认值为[.5, 2.]。
  786. thresholds (list): 判断裁剪候选区域是否有效所需的IoU阈值取值列表。默认值为[.0, .1, .3, .5, .7, .9]。
  787. scaling (list): 裁剪面积相对原面积的取值范围,以[min, max]形式表示。默认值为[.3, 1.]。
  788. num_attempts (int): 在放弃寻找有效裁剪区域前尝试的次数。默认值为50。
  789. allow_no_crop (bool): 是否允许未进行裁剪。默认值为True。
  790. cover_all_box (bool): 是否要求所有的真实标注框都必须在裁剪区域内。默认值为False。
  791. """
  792. def __init__(self,
  793. aspect_ratio=[.5, 2.],
  794. thresholds=[.0, .1, .3, .5, .7, .9],
  795. scaling=[.3, 1.],
  796. num_attempts=50,
  797. allow_no_crop=True,
  798. cover_all_box=False):
  799. self.aspect_ratio = aspect_ratio
  800. self.thresholds = thresholds
  801. self.scaling = scaling
  802. self.num_attempts = num_attempts
  803. self.allow_no_crop = allow_no_crop
  804. self.cover_all_box = cover_all_box
  805. def __call__(self, im, im_info=None, label_info=None):
  806. """
  807. Args:
  808. im (np.ndarray): 图像np.ndarray数据。
  809. im_info (dict, 可选): 存储与图像相关的信息。
  810. label_info (dict, 可选): 存储与标注框相关的信息。
  811. Returns:
  812. tuple: 当label_info为空时,返回的tuple为(im, im_info),分别对应图像np.ndarray数据、存储与图像相关信息的字典;
  813. 当label_info不为空时,返回的tuple为(im, im_info, label_info),分别对应图像np.ndarray数据、
  814. 存储与标注框相关信息的字典。
  815. 其中,im_info更新字段为:
  816. - image_shape (np.ndarray): 扩裁剪的图像高、宽二者组成的np.ndarray,形状为(2,)。
  817. label_info更新字段为:
  818. - gt_bbox (np.ndarray): 随机裁剪后真实标注框坐标,形状为(n, 4),
  819. 其中n代表真实标注框的个数。
  820. - gt_class (np.ndarray): 随机裁剪后每个真实标注框对应的类别序号,形状为(n, 1),
  821. 其中n代表真实标注框的个数。
  822. - gt_score (np.ndarray): 随机裁剪后每个真实标注框对应的混合得分,形状为(n, 1),
  823. 其中n代表真实标注框的个数。
  824. Raises:
  825. TypeError: 形参数据类型不满足需求。
  826. """
  827. if im_info is None or label_info is None:
  828. raise TypeError(
  829. 'Cannot do RandomCrop! ' +
  830. 'Becasuse the im_info and label_info can not be None!')
  831. if 'gt_bbox' not in label_info or \
  832. 'gt_class' not in label_info:
  833. raise TypeError('Cannot do RandomCrop! ' + \
  834. 'Becasuse gt_bbox/gt_class is not in label_info!')
  835. if len(label_info['gt_bbox']) == 0:
  836. return (im, im_info, label_info)
  837. if 'gt_class' in label_info and 0 in label_info['gt_class']:
  838. return (im, im_info, label_info)
  839. image_shape = im_info['image_shape']
  840. w = image_shape[1]
  841. h = image_shape[0]
  842. gt_bbox = label_info['gt_bbox']
  843. thresholds = list(self.thresholds)
  844. if self.allow_no_crop:
  845. thresholds.append('no_crop')
  846. np.random.shuffle(thresholds)
  847. for thresh in thresholds:
  848. if thresh == 'no_crop':
  849. return (im, im_info, label_info)
  850. found = False
  851. for i in range(self.num_attempts):
  852. scale = np.random.uniform(*self.scaling)
  853. min_ar, max_ar = self.aspect_ratio
  854. aspect_ratio = np.random.uniform(
  855. max(min_ar, scale**2), min(max_ar, scale**-2))
  856. crop_h = int(h * scale / np.sqrt(aspect_ratio))
  857. crop_w = int(w * scale * np.sqrt(aspect_ratio))
  858. crop_y = np.random.randint(0, h - crop_h)
  859. crop_x = np.random.randint(0, w - crop_w)
  860. crop_box = [crop_x, crop_y, crop_x + crop_w, crop_y + crop_h]
  861. iou = iou_matrix(
  862. gt_bbox, np.array(
  863. [crop_box], dtype=np.float32))
  864. if iou.max() < thresh:
  865. continue
  866. if self.cover_all_box and iou.min() < thresh:
  867. continue
  868. cropped_box, valid_ids = crop_box_with_center_constraint(
  869. gt_bbox, np.array(
  870. crop_box, dtype=np.float32))
  871. if valid_ids.size > 0:
  872. found = True
  873. break
  874. if found:
  875. if 'gt_poly' in label_info and len(label_info['gt_poly']) > 0:
  876. crop_polys = crop_segms(
  877. label_info['gt_poly'],
  878. valid_ids,
  879. np.array(
  880. crop_box, dtype=np.int64),
  881. h,
  882. w)
  883. if [] in crop_polys:
  884. delete_id = list()
  885. valid_polys = list()
  886. for id, crop_poly in enumerate(crop_polys):
  887. if crop_poly == []:
  888. delete_id.append(id)
  889. else:
  890. valid_polys.append(crop_poly)
  891. valid_ids = np.delete(valid_ids, delete_id)
  892. if len(valid_polys) == 0:
  893. return (im, im_info, label_info)
  894. label_info['gt_poly'] = valid_polys
  895. else:
  896. label_info['gt_poly'] = crop_polys
  897. im = crop_image(im, crop_box)
  898. label_info['gt_bbox'] = np.take(cropped_box, valid_ids, axis=0)
  899. label_info['gt_class'] = np.take(
  900. label_info['gt_class'], valid_ids, axis=0)
  901. im_info['image_shape'] = np.array(
  902. [crop_box[3] - crop_box[1],
  903. crop_box[2] - crop_box[0]]).astype('int32')
  904. if 'gt_score' in label_info:
  905. label_info['gt_score'] = np.take(
  906. label_info['gt_score'], valid_ids, axis=0)
  907. if 'is_crowd' in label_info:
  908. label_info['is_crowd'] = np.take(
  909. label_info['is_crowd'], valid_ids, axis=0)
  910. return (im, im_info, label_info)
  911. return (im, im_info, label_info)
  912. class ArrangeFasterRCNN(DetTransform):
  913. """获取FasterRCNN模型训练/验证/预测所需信息。
  914. Args:
  915. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  916. Raises:
  917. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
  918. """
  919. def __init__(self, mode=None):
  920. if mode not in ['train', 'eval', 'test', 'quant']:
  921. raise ValueError(
  922. "mode must be in ['train', 'eval', 'test', 'quant']!")
  923. self.mode = mode
  924. def __call__(self, im, im_info=None, label_info=None):
  925. """
  926. Args:
  927. im (np.ndarray): 图像np.ndarray数据。
  928. im_info (dict, 可选): 存储与图像相关的信息。
  929. label_info (dict, 可选): 存储与标注框相关的信息。
  930. Returns:
  931. tuple: 当mode为'train'时,返回(im, im_resize_info, gt_bbox, gt_class, is_crowd),分别对应
  932. 图像np.ndarray数据、图像相当对于原图的resize信息、真实标注框、真实标注框对应的类别、真实标注框内是否是一组对象;
  933. 当mode为'eval'时,返回(im, im_resize_info, im_id, im_shape, gt_bbox, gt_class, is_difficult),
  934. 分别对应图像np.ndarray数据、图像相当对于原图的resize信息、图像id、图像大小信息、真实标注框、真实标注框对应的类别、
  935. 真实标注框是否为难识别对象;当mode为'test'或'quant'时,返回(im, im_resize_info, im_shape),分别对应图像np.ndarray数据、
  936. 图像相当对于原图的resize信息、图像大小信息。
  937. Raises:
  938. TypeError: 形参数据类型不满足需求。
  939. ValueError: 数据长度不匹配。
  940. """
  941. im = permute(im, False)
  942. if self.mode == 'train':
  943. if im_info is None or label_info is None:
  944. raise TypeError(
  945. 'Cannot do ArrangeFasterRCNN! ' +
  946. 'Becasuse the im_info and label_info can not be None!')
  947. if len(label_info['gt_bbox']) != len(label_info['gt_class']):
  948. raise ValueError("gt num mismatch: bbox and class.")
  949. im_resize_info = im_info['im_resize_info']
  950. gt_bbox = label_info['gt_bbox']
  951. gt_class = label_info['gt_class']
  952. is_crowd = label_info['is_crowd']
  953. outputs = (im, im_resize_info, gt_bbox, gt_class, is_crowd)
  954. elif self.mode == 'eval':
  955. if im_info is None or label_info is None:
  956. raise TypeError(
  957. 'Cannot do ArrangeFasterRCNN! ' +
  958. 'Becasuse the im_info and label_info can not be None!')
  959. im_resize_info = im_info['im_resize_info']
  960. im_id = im_info['im_id']
  961. im_shape = np.array(
  962. (im_info['image_shape'][0], im_info['image_shape'][1], 1),
  963. dtype=np.float32)
  964. gt_bbox = label_info['gt_bbox']
  965. gt_class = label_info['gt_class']
  966. is_difficult = label_info['difficult']
  967. outputs = (im, im_resize_info, im_id, im_shape, gt_bbox, gt_class,
  968. is_difficult)
  969. else:
  970. if im_info is None:
  971. raise TypeError('Cannot do ArrangeFasterRCNN! ' +
  972. 'Becasuse the im_info can not be None!')
  973. im_resize_info = im_info['im_resize_info']
  974. im_shape = np.array(
  975. (im_info['image_shape'][0], im_info['image_shape'][1], 1),
  976. dtype=np.float32)
  977. outputs = (im, im_resize_info, im_shape)
  978. return outputs
  979. class ArrangeMaskRCNN(DetTransform):
  980. """获取MaskRCNN模型训练/验证/预测所需信息。
  981. Args:
  982. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  983. Raises:
  984. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
  985. """
  986. def __init__(self, mode=None):
  987. if mode not in ['train', 'eval', 'test', 'quant']:
  988. raise ValueError(
  989. "mode must be in ['train', 'eval', 'test', 'quant']!")
  990. self.mode = mode
  991. def __call__(self, im, im_info=None, label_info=None):
  992. """
  993. Args:
  994. im (np.ndarray): 图像np.ndarray数据。
  995. im_info (dict, 可选): 存储与图像相关的信息。
  996. label_info (dict, 可选): 存储与标注框相关的信息。
  997. Returns:
  998. tuple: 当mode为'train'时,返回(im, im_resize_info, gt_bbox, gt_class, is_crowd, gt_masks),分别对应
  999. 图像np.ndarray数据、图像相当对于原图的resize信息、真实标注框、真实标注框对应的类别、真实标注框内是否是一组对象、
  1000. 真实分割区域;当mode为'eval'时,返回(im, im_resize_info, im_id, im_shape),分别对应图像np.ndarray数据、
  1001. 图像相当对于原图的resize信息、图像id、图像大小信息;当mode为'test'或'quant'时,返回(im, im_resize_info, im_shape),
  1002. 分别对应图像np.ndarray数据、图像相当对于原图的resize信息、图像大小信息。
  1003. Raises:
  1004. TypeError: 形参数据类型不满足需求。
  1005. ValueError: 数据长度不匹配。
  1006. """
  1007. im = permute(im, False)
  1008. if self.mode == 'train':
  1009. if im_info is None or label_info is None:
  1010. raise TypeError(
  1011. 'Cannot do ArrangeTrainMaskRCNN! ' +
  1012. 'Becasuse the im_info and label_info can not be None!')
  1013. if len(label_info['gt_bbox']) != len(label_info['gt_class']):
  1014. raise ValueError("gt num mismatch: bbox and class.")
  1015. im_resize_info = im_info['im_resize_info']
  1016. gt_bbox = label_info['gt_bbox']
  1017. gt_class = label_info['gt_class']
  1018. is_crowd = label_info['is_crowd']
  1019. assert 'gt_poly' in label_info
  1020. segms = label_info['gt_poly']
  1021. if len(segms) != 0:
  1022. assert len(segms) == is_crowd.shape[0]
  1023. gt_masks = []
  1024. valid = True
  1025. for i in range(len(segms)):
  1026. segm = segms[i]
  1027. gt_segm = []
  1028. if is_crowd[i]:
  1029. gt_segm.append([[0, 0]])
  1030. else:
  1031. for poly in segm:
  1032. if len(poly) == 0:
  1033. valid = False
  1034. break
  1035. gt_segm.append(np.array(poly).reshape(-1, 2))
  1036. if (not valid) or len(gt_segm) == 0:
  1037. break
  1038. gt_masks.append(gt_segm)
  1039. outputs = (im, im_resize_info, gt_bbox, gt_class, is_crowd,
  1040. gt_masks)
  1041. else:
  1042. if im_info is None:
  1043. raise TypeError('Cannot do ArrangeMaskRCNN! ' +
  1044. 'Becasuse the im_info can not be None!')
  1045. im_resize_info = im_info['im_resize_info']
  1046. im_shape = np.array(
  1047. (im_info['image_shape'][0], im_info['image_shape'][1], 1),
  1048. dtype=np.float32)
  1049. if self.mode == 'eval':
  1050. im_id = im_info['im_id']
  1051. outputs = (im, im_resize_info, im_id, im_shape)
  1052. else:
  1053. outputs = (im, im_resize_info, im_shape)
  1054. return outputs
  1055. class ArrangeYOLOv3(DetTransform):
  1056. """获取YOLOv3模型训练/验证/预测所需信息。
  1057. Args:
  1058. mode (str): 指定数据用于何种用途,取值范围为['train', 'eval', 'test', 'quant']。
  1059. Raises:
  1060. ValueError: mode的取值不在['train', 'eval', 'test', 'quant']之内。
  1061. """
  1062. def __init__(self, mode=None):
  1063. if mode not in ['train', 'eval', 'test', 'quant']:
  1064. raise ValueError(
  1065. "mode must be in ['train', 'eval', 'test', 'quant']!")
  1066. self.mode = mode
  1067. def __call__(self, im, im_info=None, label_info=None):
  1068. """
  1069. Args:
  1070. im (np.ndarray): 图像np.ndarray数据。
  1071. im_info (dict, 可选): 存储与图像相关的信息。
  1072. label_info (dict, 可选): 存储与标注框相关的信息。
  1073. Returns:
  1074. tuple: 当mode为'train'时,返回(im, gt_bbox, gt_class, gt_score, im_shape),分别对应
  1075. 图像np.ndarray数据、真实标注框、真实标注框对应的类别、真实标注框混合得分、图像大小信息;
  1076. 当mode为'eval'时,返回(im, im_shape, im_id, gt_bbox, gt_class, difficult),
  1077. 分别对应图像np.ndarray数据、图像大小信息、图像id、真实标注框、真实标注框对应的类别、
  1078. 真实标注框是否为难识别对象;当mode为'test'或'quant'时,返回(im, im_shape),
  1079. 分别对应图像np.ndarray数据、图像大小信息。
  1080. Raises:
  1081. TypeError: 形参数据类型不满足需求。
  1082. ValueError: 数据长度不匹配。
  1083. """
  1084. im = permute(im, False)
  1085. if self.mode == 'train':
  1086. if im_info is None or label_info is None:
  1087. raise TypeError(
  1088. 'Cannot do ArrangeYolov3! ' +
  1089. 'Becasuse the im_info and label_info can not be None!')
  1090. im_shape = im_info['image_shape']
  1091. if len(label_info['gt_bbox']) != len(label_info['gt_class']):
  1092. raise ValueError("gt num mismatch: bbox and class.")
  1093. if len(label_info['gt_bbox']) != len(label_info['gt_score']):
  1094. raise ValueError("gt num mismatch: bbox and score.")
  1095. gt_bbox = np.zeros((50, 4), dtype=im.dtype)
  1096. gt_class = np.zeros((50, ), dtype=np.int32)
  1097. gt_score = np.zeros((50, ), dtype=im.dtype)
  1098. gt_num = min(50, len(label_info['gt_bbox']))
  1099. if gt_num > 0:
  1100. label_info['gt_class'][:gt_num, 0] = label_info[
  1101. 'gt_class'][:gt_num, 0] - 1
  1102. if -1 not in label_info['gt_class']:
  1103. gt_bbox[:gt_num, :] = label_info['gt_bbox'][:gt_num, :]
  1104. gt_class[:gt_num] = label_info['gt_class'][:gt_num, 0]
  1105. gt_score[:gt_num] = label_info['gt_score'][:gt_num, 0]
  1106. # parse [x1, y1, x2, y2] to [x, y, w, h]
  1107. gt_bbox[:, 2:4] = gt_bbox[:, 2:4] - gt_bbox[:, :2]
  1108. gt_bbox[:, :2] = gt_bbox[:, :2] + gt_bbox[:, 2:4] / 2.
  1109. outputs = (im, gt_bbox, gt_class, gt_score, im_shape)
  1110. elif self.mode == 'eval':
  1111. if im_info is None or label_info is None:
  1112. raise TypeError(
  1113. 'Cannot do ArrangeYolov3! ' +
  1114. 'Becasuse the im_info and label_info can not be None!')
  1115. im_shape = im_info['image_shape']
  1116. if len(label_info['gt_bbox']) != len(label_info['gt_class']):
  1117. raise ValueError("gt num mismatch: bbox and class.")
  1118. im_id = im_info['im_id']
  1119. gt_bbox = np.zeros((50, 4), dtype=im.dtype)
  1120. gt_class = np.zeros((50, ), dtype=np.int32)
  1121. difficult = np.zeros((50, ), dtype=np.int32)
  1122. gt_num = min(50, len(label_info['gt_bbox']))
  1123. if gt_num > 0:
  1124. label_info['gt_class'][:gt_num, 0] = label_info[
  1125. 'gt_class'][:gt_num, 0] - 1
  1126. gt_bbox[:gt_num, :] = label_info['gt_bbox'][:gt_num, :]
  1127. gt_class[:gt_num] = label_info['gt_class'][:gt_num, 0]
  1128. difficult[:gt_num] = label_info['difficult'][:gt_num, 0]
  1129. outputs = (im, im_shape, im_id, gt_bbox, gt_class, difficult)
  1130. else:
  1131. if im_info is None:
  1132. raise TypeError('Cannot do ArrangeYolov3! ' +
  1133. 'Becasuse the im_info can not be None!')
  1134. im_shape = im_info['image_shape']
  1135. outputs = (im, im_shape)
  1136. return outputs
  1137. class ComposedRCNNTransforms(Compose):
  1138. """ RCNN模型(faster-rcnn/mask-rcnn)图像处理流程,具体如下,
  1139. 训练阶段:
  1140. 1. 随机以0.5的概率将图像水平翻转
  1141. 2. 图像归一化
  1142. 3. 图像按比例Resize,scale计算方式如下
  1143. scale = min_max_size[0] / short_size_of_image
  1144. if max_size_of_image * scale > min_max_size[1]:
  1145. scale = min_max_size[1] / max_size_of_image
  1146. 4. 将3步骤的长宽进行padding,使得长宽为32的倍数
  1147. 验证阶段:
  1148. 1. 图像归一化
  1149. 2. 图像按比例Resize,scale计算方式同上训练阶段
  1150. 3. 将2步骤的长宽进行padding,使得长宽为32的倍数
  1151. Args:
  1152. mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
  1153. min_max_size(list): 图像在缩放时,最小边和最大边的约束条件
  1154. mean(list): 图像均值
  1155. std(list): 图像方差
  1156. random_horizontal_flip(bool): 是否以0.5的概率使用随机水平翻转增强,该仅在mode为`train`时生效,默认为True
  1157. """
  1158. def __init__(self,
  1159. mode,
  1160. min_max_size=[800, 1333],
  1161. mean=[0.485, 0.456, 0.406],
  1162. std=[0.229, 0.224, 0.225],
  1163. random_horizontal_flip=True):
  1164. if mode == 'train':
  1165. # 训练时的transforms,包含数据增强
  1166. transforms = [
  1167. Normalize(
  1168. mean=mean, std=std), ResizeByShort(
  1169. short_size=min_max_size[0], max_size=min_max_size[1]),
  1170. Padding(coarsest_stride=32)
  1171. ]
  1172. if random_horizontal_flip:
  1173. transforms.insert(0, RandomHorizontalFlip())
  1174. else:
  1175. # 验证/预测时的transforms
  1176. transforms = [
  1177. Normalize(
  1178. mean=mean, std=std), ResizeByShort(
  1179. short_size=min_max_size[0], max_size=min_max_size[1]),
  1180. Padding(coarsest_stride=32)
  1181. ]
  1182. super(ComposedRCNNTransforms, self).__init__(transforms)
  1183. class ComposedYOLOv3Transforms(Compose):
  1184. """YOLOv3模型的图像预处理流程,具体如下,
  1185. 训练阶段:
  1186. 1. 在前mixup_epoch轮迭代中,使用MixupImage策略,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#mixupimage
  1187. 2. 对图像进行随机扰动,包括亮度,对比度,饱和度和色调
  1188. 3. 随机扩充图像,见https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/det_transforms.html#randomexpand
  1189. 4. 随机裁剪图像
  1190. 5. 将4步骤的输出图像Resize成shape参数的大小
  1191. 6. 随机0.5的概率水平翻转图像
  1192. 7. 图像归一化
  1193. 验证/预测阶段:
  1194. 1. 将图像Resize成shape参数大小
  1195. 2. 图像归一化
  1196. Args:
  1197. mode(str): 图像处理流程所处阶段,训练/验证/预测,分别对应'train', 'eval', 'test'
  1198. shape(list): 输入模型中图像的大小,输入模型的图像会被Resize成此大小
  1199. mixup_epoch(int): 模型训练过程中,前mixup_epoch会使用mixup策略, 若设为-1,则表示不使用该策略
  1200. mean(list): 图像均值
  1201. std(list): 图像方差
  1202. random_distort(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机扰动图像,默认为True
  1203. random_expand(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机扩张图像,默认为True
  1204. random_crop(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机裁剪图像,默认为True
  1205. random_horizontal_flip(bool): 数据增强方式,参数仅在mode为`train`时生效,表示是否在训练过程中随机水平翻转图像,默认为True
  1206. """
  1207. def __init__(self,
  1208. mode,
  1209. shape=[608, 608],
  1210. mixup_epoch=250,
  1211. mean=[0.485, 0.456, 0.406],
  1212. std=[0.229, 0.224, 0.225],
  1213. random_distort=True,
  1214. random_expand=True,
  1215. random_crop=True,
  1216. random_horizontal_flip=True):
  1217. width = shape
  1218. if isinstance(shape, list):
  1219. if shape[0] != shape[1]:
  1220. raise Exception(
  1221. "In YOLOv3 model, width and height should be equal")
  1222. width = shape[0]
  1223. if width % 32 != 0:
  1224. raise Exception(
  1225. "In YOLOv3 model, width and height should be multiple of 32, e.g 224、256、320...."
  1226. )
  1227. if mode == 'train':
  1228. # 训练时的transforms,包含数据增强
  1229. transforms = [
  1230. MixupImage(mixup_epoch=mixup_epoch), Resize(
  1231. target_size=width, interp='RANDOM'), Normalize(
  1232. mean=mean, std=std)
  1233. ]
  1234. if random_horizontal_flip:
  1235. transforms.insert(1, RandomHorizontalFlip())
  1236. if random_crop:
  1237. transforms.insert(1, RandomCrop())
  1238. if random_expand:
  1239. transforms.insert(1, RandomExpand())
  1240. if random_distort:
  1241. transforms.insert(1, RandomDistort())
  1242. else:
  1243. # 验证/预测时的transforms
  1244. transforms = [
  1245. Resize(
  1246. target_size=width, interp='CUBIC'), Normalize(
  1247. mean=mean, std=std)
  1248. ]
  1249. super(ComposedYOLOv3Transforms, self).__init__(transforms)
  1250. class BatchRandomShape(DetTransform):
  1251. """调整图像大小(resize)。
  1252. 对batch数据中的每张图像全部resize到random_shapes中任意一个大小。
  1253. 注意:当插值方式为“RANDOM”时,则随机选取一种插值方式进行resize。
  1254. Args:
  1255. random_shapes (list): resize大小选择列表。
  1256. 默认为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
  1257. interp (str): resize的插值方式,与opencv的插值方式对应,取值范围为
  1258. ['NEAREST', 'LINEAR', 'CUBIC', 'AREA', 'LANCZOS4', 'RANDOM']。默认为"RANDOM"。
  1259. Raises:
  1260. ValueError: 插值方式不在['NEAREST', 'LINEAR', 'CUBIC',
  1261. 'AREA', 'LANCZOS4', 'RANDOM']中。
  1262. """
  1263. # The interpolation mode
  1264. interp_dict = {
  1265. 'NEAREST': cv2.INTER_NEAREST,
  1266. 'LINEAR': cv2.INTER_LINEAR,
  1267. 'CUBIC': cv2.INTER_CUBIC,
  1268. 'AREA': cv2.INTER_AREA,
  1269. 'LANCZOS4': cv2.INTER_LANCZOS4
  1270. }
  1271. def __init__(
  1272. self,
  1273. random_shapes=[320, 352, 384, 416, 448, 480, 512, 544, 576, 608],
  1274. interp='RANDOM'):
  1275. if not (interp == "RANDOM" or interp in self.interp_dict):
  1276. raise ValueError("interp should be one of {}".format(
  1277. self.interp_dict.keys()))
  1278. self.random_shapes = random_shapes
  1279. self.interp = interp
  1280. def __call__(self, batch_data):
  1281. """
  1282. Args:
  1283. batch_data (list): 由与图像相关的各种信息组成的batch数据。
  1284. Returns:
  1285. list: 由与图像相关的各种信息组成的batch数据。
  1286. """
  1287. shape = np.random.choice(self.random_shapes)
  1288. if self.interp == "RANDOM":
  1289. interp = random.choice(list(self.interp_dict.keys()))
  1290. else:
  1291. interp = self.interp
  1292. for data_id, data in enumerate(batch_data):
  1293. data_list = list(data)
  1294. im = data_list[0]
  1295. im = np.swapaxes(im, 1, 0)
  1296. im = np.swapaxes(im, 1, 2)
  1297. im = resize(im, shape, self.interp_dict[interp])
  1298. im = np.swapaxes(im, 1, 2)
  1299. im = np.swapaxes(im, 1, 0)
  1300. data_list[0] = im
  1301. batch_data[data_id] = tuple(data_list)
  1302. return batch_data
  1303. class GenerateYoloTarget(object):
  1304. """生成YOLOv3的ground truth(真实标注框)在不同特征层的位置转换信息。
  1305. 该transform只在YOLOv3计算细粒度loss时使用。
  1306. Args:
  1307. anchors (list|tuple): anchor框的宽度和高度。
  1308. anchor_masks (list|tuple): 在计算损失时,使用anchor的mask索引。
  1309. num_classes (int): 类别数。默认为80。
  1310. iou_thresh (float): iou阈值,当anchor和真实标注框的iou大于该阈值时,计入target。默认为1.0。
  1311. """
  1312. def __init__(self,
  1313. anchors,
  1314. anchor_masks,
  1315. downsample_ratios,
  1316. num_classes=80,
  1317. iou_thresh=1.):
  1318. super(GenerateYoloTarget, self).__init__()
  1319. self.anchors = anchors
  1320. self.anchor_masks = anchor_masks
  1321. self.downsample_ratios = downsample_ratios
  1322. self.num_classes = num_classes
  1323. self.iou_thresh = iou_thresh
  1324. def __call__(self, batch_data):
  1325. """
  1326. Args:
  1327. batch_data (list): 由与图像相关的各种信息组成的batch数据。
  1328. Returns:
  1329. list: 由与图像相关的各种信息组成的batch数据。
  1330. 其中,每个数据新添加的字段为:
  1331. - target0 (np.ndarray): YOLOv3的ground truth在特征层0的位置转换信息,
  1332. 形状为(特征层0的anchor数量, 6+类别数, 特征层0的h, 特征层0的w)。
  1333. - target1 (np.ndarray): YOLOv3的ground truth在特征层1的位置转换信息,
  1334. 形状为(特征层1的anchor数量, 6+类别数, 特征层1的h, 特征层1的w)。
  1335. - ...
  1336. -targetn (np.ndarray): YOLOv3的ground truth在特征层n的位置转换信息,
  1337. 形状为(特征层n的anchor数量, 6+类别数, 特征层n的h, 特征层n的w)。
  1338. n的是大小由anchor_masks的长度决定。
  1339. """
  1340. im = batch_data[0][0]
  1341. h = im.shape[1]
  1342. w = im.shape[2]
  1343. an_hw = np.array(self.anchors) / np.array([[w, h]])
  1344. for data_id, data in enumerate(batch_data):
  1345. gt_bbox = data[1]
  1346. gt_class = data[2]
  1347. gt_score = data[3]
  1348. im_shape = data[4]
  1349. origin_h = float(im_shape[0])
  1350. origin_w = float(im_shape[1])
  1351. data_list = list(data)
  1352. for i, (
  1353. mask, downsample_ratio
  1354. ) in enumerate(zip(self.anchor_masks, self.downsample_ratios)):
  1355. grid_h = int(h / downsample_ratio)
  1356. grid_w = int(w / downsample_ratio)
  1357. target = np.zeros(
  1358. (len(mask), 6 + self.num_classes, grid_h, grid_w),
  1359. dtype=np.float32)
  1360. for b in range(gt_bbox.shape[0]):
  1361. gx = gt_bbox[b, 0] / float(origin_w)
  1362. gy = gt_bbox[b, 1] / float(origin_h)
  1363. gw = gt_bbox[b, 2] / float(origin_w)
  1364. gh = gt_bbox[b, 3] / float(origin_h)
  1365. cls = gt_class[b]
  1366. score = gt_score[b]
  1367. if gw <= 0. or gh <= 0. or score <= 0.:
  1368. continue
  1369. # find best match anchor index
  1370. best_iou = 0.
  1371. best_idx = -1
  1372. for an_idx in range(an_hw.shape[0]):
  1373. iou = jaccard_overlap(
  1374. [0., 0., gw, gh],
  1375. [0., 0., an_hw[an_idx, 0], an_hw[an_idx, 1]])
  1376. if iou > best_iou:
  1377. best_iou = iou
  1378. best_idx = an_idx
  1379. gi = int(gx * grid_w)
  1380. gj = int(gy * grid_h)
  1381. # gtbox should be regresed in this layes if best match
  1382. # anchor index in anchor mask of this layer
  1383. if best_idx in mask:
  1384. best_n = mask.index(best_idx)
  1385. # x, y, w, h, scale
  1386. target[best_n, 0, gj, gi] = gx * grid_w - gi
  1387. target[best_n, 1, gj, gi] = gy * grid_h - gj
  1388. target[best_n, 2, gj, gi] = np.log(
  1389. gw * w / self.anchors[best_idx][0])
  1390. target[best_n, 3, gj, gi] = np.log(
  1391. gh * h / self.anchors[best_idx][1])
  1392. target[best_n, 4, gj, gi] = 2.0 - gw * gh
  1393. # objectness record gt_score
  1394. target[best_n, 5, gj, gi] = score
  1395. # classification
  1396. target[best_n, 6 + cls, gj, gi] = 1.
  1397. # For non-matched anchors, calculate the target if the iou
  1398. # between anchor and gt is larger than iou_thresh
  1399. if self.iou_thresh < 1:
  1400. for idx, mask_i in enumerate(mask):
  1401. if mask_i == best_idx: continue
  1402. iou = jaccard_overlap(
  1403. [0., 0., gw, gh],
  1404. [0., 0., an_hw[mask_i, 0], an_hw[mask_i, 1]])
  1405. if iou > self.iou_thresh:
  1406. # x, y, w, h, scale
  1407. target[idx, 0, gj, gi] = gx * grid_w - gi
  1408. target[idx, 1, gj, gi] = gy * grid_h - gj
  1409. target[idx, 2, gj, gi] = np.log(
  1410. gw * w / self.anchors[mask_i][0])
  1411. target[idx, 3, gj, gi] = np.log(
  1412. gh * h / self.anchors[mask_i][1])
  1413. target[idx, 4, gj, gi] = 2.0 - gw * gh
  1414. # objectness record gt_score
  1415. target[idx, 5, gj, gi] = score
  1416. # classification
  1417. target[idx, 6 + cls, gj, gi] = 1.
  1418. data_list.append(target)
  1419. batch_data[data_id] = tuple(data_list)
  1420. return batch_data