transforms.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import sys
  16. import cv2
  17. import copy
  18. import math
  19. import pyclipper
  20. import numpy as np
  21. from PIL import Image
  22. from shapely.geometry import Polygon
  23. from ....utils import logging
  24. from ...base.predictor.io.writers import ImageWriter
  25. from ...base.predictor.io.readers import ImageReader
  26. from ...base.predictor import BaseTransform
  27. from .keys import TextDetKeys as K
  28. __all__ = [
  29. "DetResizeForTest",
  30. "NormalizeImage",
  31. "DBPostProcess",
  32. "SaveTextDetResults",
  33. "PrintResult",
  34. ]
  35. class DetResizeForTest(BaseTransform):
  36. """DetResizeForTest"""
  37. def __init__(self, **kwargs):
  38. super(DetResizeForTest, self).__init__()
  39. self.resize_type = 0
  40. self.keep_ratio = False
  41. if "image_shape" in kwargs:
  42. self.image_shape = kwargs["image_shape"]
  43. self.resize_type = 1
  44. if "keep_ratio" in kwargs:
  45. self.keep_ratio = kwargs["keep_ratio"]
  46. elif "limit_side_len" in kwargs:
  47. self.limit_side_len = kwargs["limit_side_len"]
  48. self.limit_type = kwargs.get("limit_type", "min")
  49. elif "resize_long" in kwargs:
  50. self.resize_type = 2
  51. self.resize_long = kwargs.get("resize_long", 960)
  52. else:
  53. self.limit_side_len = 736
  54. self.limit_type = "min"
  55. def apply(self, data):
  56. """apply"""
  57. img = data[K.IMAGE]
  58. src_h, src_w, _ = img.shape
  59. if sum([src_h, src_w]) < 64:
  60. img = self.image_padding(img)
  61. if self.resize_type == 0:
  62. # img, shape = self.resize_image_type0(img)
  63. img, [ratio_h, ratio_w] = self.resize_image_type0(img)
  64. elif self.resize_type == 2:
  65. img, [ratio_h, ratio_w] = self.resize_image_type2(img)
  66. else:
  67. # img, shape = self.resize_image_type1(img)
  68. img, [ratio_h, ratio_w] = self.resize_image_type1(img)
  69. data[K.IMAGE] = img
  70. data[K.SHAPE] = np.array([src_h, src_w, ratio_h, ratio_w])
  71. return data
  72. @classmethod
  73. def get_input_keys(cls):
  74. """get input keys"""
  75. return [K.IMAGE]
  76. @classmethod
  77. def get_output_keys(cls):
  78. """get output keys"""
  79. return [K.IMAGE, K.SHAPE]
  80. def image_padding(self, im, value=0):
  81. """padding image"""
  82. h, w, c = im.shape
  83. im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
  84. im_pad[:h, :w, :] = im
  85. return im_pad
  86. def resize_image_type1(self, img):
  87. """resize the image"""
  88. resize_h, resize_w = self.image_shape
  89. ori_h, ori_w = img.shape[:2] # (h, w, c)
  90. if self.keep_ratio is True:
  91. resize_w = ori_w * resize_h / ori_h
  92. N = math.ceil(resize_w / 32)
  93. resize_w = N * 32
  94. ratio_h = float(resize_h) / ori_h
  95. ratio_w = float(resize_w) / ori_w
  96. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  97. # return img, np.array([ori_h, ori_w])
  98. return img, [ratio_h, ratio_w]
  99. def resize_image_type0(self, img):
  100. """
  101. resize image to a size multiple of 32 which is required by the network
  102. args:
  103. img(array): array with shape [h, w, c]
  104. return(tuple):
  105. img, (ratio_h, ratio_w)
  106. """
  107. limit_side_len = self.limit_side_len
  108. h, w, c = img.shape
  109. # limit the max side
  110. if self.limit_type == "max":
  111. if max(h, w) > limit_side_len:
  112. if h > w:
  113. ratio = float(limit_side_len) / h
  114. else:
  115. ratio = float(limit_side_len) / w
  116. else:
  117. ratio = 1.0
  118. elif self.limit_type == "min":
  119. if min(h, w) < limit_side_len:
  120. if h < w:
  121. ratio = float(limit_side_len) / h
  122. else:
  123. ratio = float(limit_side_len) / w
  124. else:
  125. ratio = 1.0
  126. elif self.limit_type == "resize_long":
  127. ratio = float(limit_side_len) / max(h, w)
  128. else:
  129. raise Exception("not support limit type, image ")
  130. resize_h = int(h * ratio)
  131. resize_w = int(w * ratio)
  132. resize_h = max(int(round(resize_h / 32) * 32), 32)
  133. resize_w = max(int(round(resize_w / 32) * 32), 32)
  134. try:
  135. if int(resize_w) <= 0 or int(resize_h) <= 0:
  136. return None, (None, None)
  137. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  138. except:
  139. logging.info(img.shape, resize_w, resize_h)
  140. sys.exit(0)
  141. ratio_h = resize_h / float(h)
  142. ratio_w = resize_w / float(w)
  143. return img, [ratio_h, ratio_w]
  144. def resize_image_type2(self, img):
  145. """resize image size"""
  146. h, w, _ = img.shape
  147. resize_w = w
  148. resize_h = h
  149. if resize_h > resize_w:
  150. ratio = float(self.resize_long) / resize_h
  151. else:
  152. ratio = float(self.resize_long) / resize_w
  153. resize_h = int(resize_h * ratio)
  154. resize_w = int(resize_w * ratio)
  155. max_stride = 128
  156. resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
  157. resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
  158. img = cv2.resize(img, (int(resize_w), int(resize_h)))
  159. ratio_h = resize_h / float(h)
  160. ratio_w = resize_w / float(w)
  161. return img, [ratio_h, ratio_w]
  162. class NormalizeImage(BaseTransform):
  163. """normalize image such as substract mean, divide std"""
  164. def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
  165. if isinstance(scale, str):
  166. scale = eval(scale)
  167. self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
  168. mean = mean if mean is not None else [0.485, 0.456, 0.406]
  169. std = std if std is not None else [0.229, 0.224, 0.225]
  170. shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
  171. self.mean = np.array(mean).reshape(shape).astype("float32")
  172. self.std = np.array(std).reshape(shape).astype("float32")
  173. def apply(self, data):
  174. """apply"""
  175. img = data[K.IMAGE]
  176. from PIL import Image
  177. if isinstance(img, Image.Image):
  178. img = np.array(img)
  179. assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
  180. data[K.IMAGE] = (img.astype("float32") * self.scale - self.mean) / self.std
  181. return data
  182. @classmethod
  183. def get_input_keys(cls):
  184. """get input keys"""
  185. return [K.IMAGE]
  186. @classmethod
  187. def get_output_keys(cls):
  188. """get output keys"""
  189. return [K.IMAGE]
  190. class DBPostProcess(BaseTransform):
  191. """
  192. The post process for Differentiable Binarization (DB).
  193. """
  194. def __init__(
  195. self,
  196. thresh=0.3,
  197. box_thresh=0.7,
  198. max_candidates=1000,
  199. unclip_ratio=2.0,
  200. use_dilation=False,
  201. score_mode="fast",
  202. box_type="quad",
  203. **kwargs
  204. ):
  205. self.thresh = thresh
  206. self.box_thresh = box_thresh
  207. self.max_candidates = max_candidates
  208. self.unclip_ratio = unclip_ratio
  209. self.min_size = 3
  210. self.score_mode = score_mode
  211. self.box_type = box_type
  212. assert score_mode in [
  213. "slow",
  214. "fast",
  215. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  216. self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
  217. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  218. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  219. bitmap = _bitmap
  220. height, width = bitmap.shape
  221. boxes = []
  222. scores = []
  223. contours, _ = cv2.findContours(
  224. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  225. )
  226. for contour in contours[: self.max_candidates]:
  227. epsilon = 0.002 * cv2.arcLength(contour, True)
  228. approx = cv2.approxPolyDP(contour, epsilon, True)
  229. points = approx.reshape((-1, 2))
  230. if points.shape[0] < 4:
  231. continue
  232. score = self.box_score_fast(pred, points.reshape(-1, 2))
  233. if self.box_thresh > score:
  234. continue
  235. if points.shape[0] > 2:
  236. box = self.unclip(points, self.unclip_ratio)
  237. if len(box) > 1:
  238. continue
  239. else:
  240. continue
  241. box = box.reshape(-1, 2)
  242. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  243. if sside < self.min_size + 2:
  244. continue
  245. box = np.array(box)
  246. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  247. box[:, 1] = np.clip(
  248. np.round(box[:, 1] / height * dest_height), 0, dest_height
  249. )
  250. boxes.append(box.tolist())
  251. scores.append(score)
  252. return boxes, scores
  253. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  254. """_bitmap: single map with shape (1, H, W), whose values are binarized as {0, 1}"""
  255. bitmap = _bitmap
  256. height, width = bitmap.shape
  257. outs = cv2.findContours(
  258. (bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
  259. )
  260. if len(outs) == 3:
  261. img, contours, _ = outs[0], outs[1], outs[2]
  262. elif len(outs) == 2:
  263. contours, _ = outs[0], outs[1]
  264. num_contours = min(len(contours), self.max_candidates)
  265. boxes = []
  266. scores = []
  267. for index in range(num_contours):
  268. contour = contours[index]
  269. points, sside = self.get_mini_boxes(contour)
  270. if sside < self.min_size:
  271. continue
  272. points = np.array(points)
  273. if self.score_mode == "fast":
  274. score = self.box_score_fast(pred, points.reshape(-1, 2))
  275. else:
  276. score = self.box_score_slow(pred, contour)
  277. if self.box_thresh > score:
  278. continue
  279. box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
  280. box, sside = self.get_mini_boxes(box)
  281. if sside < self.min_size + 2:
  282. continue
  283. box = np.array(box)
  284. box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
  285. box[:, 1] = np.clip(
  286. np.round(box[:, 1] / height * dest_height), 0, dest_height
  287. )
  288. boxes.append(box.astype(np.int16))
  289. scores.append(score)
  290. return np.array(boxes, dtype=np.int16), scores
  291. def unclip(self, box, unclip_ratio):
  292. """unclip"""
  293. poly = Polygon(box)
  294. distance = poly.area * unclip_ratio / poly.length
  295. offset = pyclipper.PyclipperOffset()
  296. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  297. expanded = np.array(offset.Execute(distance))
  298. return expanded
  299. def get_mini_boxes(self, contour):
  300. """get mini boxes"""
  301. bounding_box = cv2.minAreaRect(contour)
  302. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  303. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  304. if points[1][1] > points[0][1]:
  305. index_1 = 0
  306. index_4 = 1
  307. else:
  308. index_1 = 1
  309. index_4 = 0
  310. if points[3][1] > points[2][1]:
  311. index_2 = 2
  312. index_3 = 3
  313. else:
  314. index_2 = 3
  315. index_3 = 2
  316. box = [points[index_1], points[index_2], points[index_3], points[index_4]]
  317. return box, min(bounding_box[1])
  318. def box_score_fast(self, bitmap, _box):
  319. """box_score_fast: use bbox mean score as the mean score"""
  320. h, w = bitmap.shape[:2]
  321. box = _box.copy()
  322. xmin = np.clip(np.floor(box[:, 0].min()).astype("int"), 0, w - 1)
  323. xmax = np.clip(np.ceil(box[:, 0].max()).astype("int"), 0, w - 1)
  324. ymin = np.clip(np.floor(box[:, 1].min()).astype("int"), 0, h - 1)
  325. ymax = np.clip(np.ceil(box[:, 1].max()).astype("int"), 0, h - 1)
  326. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  327. box[:, 0] = box[:, 0] - xmin
  328. box[:, 1] = box[:, 1] - ymin
  329. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
  330. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  331. def box_score_slow(self, bitmap, contour):
  332. """box_score_slow: use polyon mean score as the mean score"""
  333. h, w = bitmap.shape[:2]
  334. contour = contour.copy()
  335. contour = np.reshape(contour, (-1, 2))
  336. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  337. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  338. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  339. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  340. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  341. contour[:, 0] = contour[:, 0] - xmin
  342. contour[:, 1] = contour[:, 1] - ymin
  343. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
  344. return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0]
  345. def apply(self, data):
  346. """apply"""
  347. pred = data[K.PROB_MAP]
  348. shape_list = [data[K.SHAPE]]
  349. pred = pred[0][:, 0, :, :]
  350. segmentation = pred > self.thresh
  351. boxes_batch = []
  352. for batch_index in range(pred.shape[0]):
  353. src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
  354. if self.dilation_kernel is not None:
  355. mask = cv2.dilate(
  356. np.array(segmentation[batch_index]).astype(np.uint8),
  357. self.dilation_kernel,
  358. )
  359. else:
  360. mask = segmentation[batch_index]
  361. if self.box_type == "poly":
  362. boxes, scores = self.polygons_from_bitmap(
  363. pred[batch_index], mask, src_w, src_h
  364. )
  365. elif self.box_type == "quad":
  366. boxes, scores = self.boxes_from_bitmap(
  367. pred[batch_index], mask, src_w, src_h
  368. )
  369. else:
  370. raise ValueError("box_type can only be one of ['quad', 'poly']")
  371. data[K.DT_POLYS] = boxes
  372. data[K.DT_SCORES] = scores
  373. return data
  374. @classmethod
  375. def get_input_keys(cls):
  376. """get input keys"""
  377. return [K.PROB_MAP]
  378. @classmethod
  379. def get_output_keys(cls):
  380. """get output keys"""
  381. return [K.DT_POLYS, K.DT_SCORES]
  382. class CropByPolys(BaseTransform):
  383. """Crop Image by Polys"""
  384. def __init__(self, det_box_type="quad"):
  385. super().__init__()
  386. self.det_box_type = det_box_type
  387. def apply(self, data):
  388. """apply"""
  389. ori_im = data[K.ORI_IM]
  390. # TODO
  391. # dt_boxes = self.sorted_boxes(data[K.DT_POLYS])
  392. dt_boxes = np.array(data[K.DT_POLYS])
  393. img_crop_list = []
  394. for bno in range(len(dt_boxes)):
  395. tmp_box = copy.deepcopy(dt_boxes[bno])
  396. if self.det_box_type == "quad":
  397. img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
  398. else:
  399. img_crop = self.get_minarea_rect_crop(ori_im, tmp_box)
  400. img_crop_list.append(img_crop)
  401. data[K.SUB_IMGS] = img_crop_list
  402. return data
  403. @classmethod
  404. def get_input_keys(cls):
  405. """get input keys"""
  406. return [K.IM_PATH, K.DT_POLYS]
  407. @classmethod
  408. def get_output_keys(cls):
  409. """get output keys"""
  410. return [K.SUB_IMGS]
  411. def sorted_boxes(self, dt_boxes):
  412. """
  413. Sort text boxes in order from top to bottom, left to right
  414. args:
  415. dt_boxes(array):detected text boxes with shape [4, 2]
  416. return:
  417. sorted boxes(array) with shape [4, 2]
  418. """
  419. dt_boxes = np.array(dt_boxes)
  420. num_boxes = dt_boxes.shape[0]
  421. sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
  422. _boxes = list(sorted_boxes)
  423. for i in range(num_boxes - 1):
  424. for j in range(i, -1, -1):
  425. if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and (
  426. _boxes[j + 1][0][0] < _boxes[j][0][0]
  427. ):
  428. tmp = _boxes[j]
  429. _boxes[j] = _boxes[j + 1]
  430. _boxes[j + 1] = tmp
  431. else:
  432. break
  433. return _boxes
  434. def get_minarea_rect_crop(self, img, points):
  435. """get_minarea_rect_crop"""
  436. bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
  437. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  438. index_a, index_b, index_c, index_d = 0, 1, 2, 3
  439. if points[1][1] > points[0][1]:
  440. index_a = 0
  441. index_d = 1
  442. else:
  443. index_a = 1
  444. index_d = 0
  445. if points[3][1] > points[2][1]:
  446. index_b = 2
  447. index_c = 3
  448. else:
  449. index_b = 3
  450. index_c = 2
  451. box = [points[index_a], points[index_b], points[index_c], points[index_d]]
  452. crop_img = self.get_rotate_crop_image(img, np.array(box))
  453. return crop_img
  454. def get_rotate_crop_image(self, img, points):
  455. """
  456. img_height, img_width = img.shape[0:2]
  457. left = int(np.min(points[:, 0]))
  458. right = int(np.max(points[:, 0]))
  459. top = int(np.min(points[:, 1]))
  460. bottom = int(np.max(points[:, 1]))
  461. img_crop = img[top:bottom, left:right, :].copy()
  462. points[:, 0] = points[:, 0] - left
  463. points[:, 1] = points[:, 1] - top
  464. """
  465. assert len(points) == 4, "shape of points must be 4*2"
  466. img_crop_width = int(
  467. max(
  468. np.linalg.norm(points[0] - points[1]),
  469. np.linalg.norm(points[2] - points[3]),
  470. )
  471. )
  472. img_crop_height = int(
  473. max(
  474. np.linalg.norm(points[0] - points[3]),
  475. np.linalg.norm(points[1] - points[2]),
  476. )
  477. )
  478. pts_std = np.float32(
  479. [
  480. [0, 0],
  481. [img_crop_width, 0],
  482. [img_crop_width, img_crop_height],
  483. [0, img_crop_height],
  484. ]
  485. )
  486. M = cv2.getPerspectiveTransform(points, pts_std)
  487. dst_img = cv2.warpPerspective(
  488. img,
  489. M,
  490. (img_crop_width, img_crop_height),
  491. borderMode=cv2.BORDER_REPLICATE,
  492. flags=cv2.INTER_CUBIC,
  493. )
  494. dst_img_height, dst_img_width = dst_img.shape[0:2]
  495. if dst_img_height * 1.0 / dst_img_width >= 1.5:
  496. dst_img = np.rot90(dst_img)
  497. return dst_img
  498. class SaveTextDetResults(BaseTransform):
  499. """Save Text Det Results"""
  500. def __init__(self, save_dir):
  501. super().__init__()
  502. self.save_dir = save_dir
  503. # We use pillow backend to save both numpy arrays and PIL Image objects
  504. self._writer = ImageWriter(backend="opencv")
  505. def apply(self, data):
  506. """apply"""
  507. if self.save_dir is None:
  508. logging.warning(
  509. "The `save_dir` has been set to None, so the text detection result won't to be saved."
  510. )
  511. return data
  512. fn = os.path.basename(data["input_path"])
  513. save_path = os.path.join(self.save_dir, fn)
  514. bbox_res = data[K.DT_POLYS]
  515. vis_img = self.draw_rectangle(data[K.IM_PATH], bbox_res)
  516. self._writer.write(save_path, vis_img)
  517. return data
  518. @classmethod
  519. def get_input_keys(cls):
  520. """get input keys"""
  521. return [K.IM_PATH, K.DT_POLYS, K.DT_SCORES]
  522. @classmethod
  523. def get_output_keys(cls):
  524. """get output keys"""
  525. return []
  526. def draw_rectangle(self, img_path, boxes):
  527. """draw rectangle"""
  528. boxes = np.array(boxes)
  529. img = cv2.imread(img_path)
  530. img_show = img.copy()
  531. for box in boxes.astype(int):
  532. box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
  533. cv2.polylines(img_show, [box], True, (0, 0, 255), 2)
  534. return img_show
  535. class PrintResult(BaseTransform):
  536. """Print Result Transform"""
  537. def apply(self, data):
  538. """apply"""
  539. logging.info("The prediction result is:")
  540. logging.info(data[K.DT_POLYS])
  541. return data
  542. @classmethod
  543. def get_input_keys(cls):
  544. """get input keys"""
  545. return [K.DT_SCORES]
  546. @classmethod
  547. def get_output_keys(cls):
  548. """get output keys"""
  549. return []
  550. # DT_SCORES = 'dt_scores'
  551. # DT_POLYS = 'dt_polys'