predictor.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Union
  15. import numpy as np
  16. from ....modules.text_detection.model_list import MODELS
  17. from ....utils.func_register import FuncRegister
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..base import BasePredictor
  21. from ..common import ToBatch, ToCHWImage
  22. from .processors import DBPostProcess, DetResizeForTest, NormalizeImage
  23. from .result import TextDetResult
  24. class TextDetPredictor(BasePredictor):
  25. entities = MODELS
  26. _FUNC_MAP = {}
  27. register = FuncRegister(_FUNC_MAP)
  28. def __init__(
  29. self,
  30. limit_side_len: Union[int, None] = None,
  31. limit_type: Union[str, None] = None,
  32. thresh: Union[float, None] = None,
  33. box_thresh: Union[float, None] = None,
  34. unclip_ratio: Union[float, None] = None,
  35. input_shape=None,
  36. *args,
  37. **kwargs
  38. ):
  39. super().__init__(*args, **kwargs)
  40. self.limit_side_len = limit_side_len
  41. self.limit_type = limit_type
  42. self.thresh = thresh
  43. self.box_thresh = box_thresh
  44. self.unclip_ratio = unclip_ratio
  45. self.input_shape = input_shape
  46. self.pre_tfs, self.infer, self.post_op = self._build()
  47. def _build_batch_sampler(self):
  48. return ImageBatchSampler()
  49. def _get_result_class(self):
  50. return TextDetResult
  51. def _build(self):
  52. pre_tfs = {"Read": ReadImage(format="RGB")}
  53. for cfg in self.config["PreProcess"]["transform_ops"]:
  54. tf_key = list(cfg.keys())[0]
  55. func = self._FUNC_MAP[tf_key]
  56. args = cfg.get(tf_key, {})
  57. name, op = func(self, **args) if args else func(self)
  58. if op:
  59. pre_tfs[name] = op
  60. pre_tfs["ToBatch"] = ToBatch()
  61. infer = self.create_static_infer()
  62. post_op = self.build_postprocess(**self.config["PostProcess"])
  63. return pre_tfs, infer, post_op
  64. def process(
  65. self,
  66. batch_data: List[Union[str, np.ndarray]],
  67. limit_side_len: Union[int, None] = None,
  68. limit_type: Union[str, None] = None,
  69. thresh: Union[float, None] = None,
  70. box_thresh: Union[float, None] = None,
  71. unclip_ratio: Union[float, None] = None,
  72. ):
  73. batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data.instances)
  74. batch_imgs, batch_shapes = self.pre_tfs["Resize"](
  75. imgs=batch_raw_imgs,
  76. limit_side_len=limit_side_len or self.limit_side_len,
  77. limit_type=limit_type or self.limit_type,
  78. )
  79. batch_imgs = self.pre_tfs["Normalize"](imgs=batch_imgs)
  80. batch_imgs = self.pre_tfs["ToCHW"](imgs=batch_imgs)
  81. x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
  82. batch_preds = self.infer(x=x)
  83. polys, scores = self.post_op(
  84. batch_preds,
  85. batch_shapes,
  86. thresh=thresh or self.thresh,
  87. box_thresh=box_thresh or self.box_thresh,
  88. unclip_ratio=unclip_ratio or self.unclip_ratio,
  89. )
  90. return {
  91. "input_path": batch_data.input_paths,
  92. "page_index": batch_data.page_indexes,
  93. "input_img": batch_raw_imgs,
  94. "dt_polys": polys,
  95. "dt_scores": scores,
  96. }
  97. @register("DecodeImage")
  98. def build_readimg(self, channel_first, img_mode):
  99. assert channel_first == False
  100. return "Read", ReadImage(format=img_mode)
  101. @register("DetResizeForTest")
  102. def build_resize(
  103. self,
  104. limit_side_len: Union[int, None] = None,
  105. limit_type: Union[str, None] = None,
  106. **kwargs
  107. ):
  108. # TODO: align to PaddleOCR
  109. if self.model_name in (
  110. "PP-OCRv4_server_det",
  111. "PP-OCRv4_mobile_det",
  112. "PP-OCRv3_server_det",
  113. "PP-OCRv3_mobile_det",
  114. ):
  115. limit_side_len = self.limit_side_len or kwargs.get("resize_long", 960)
  116. limit_type = self.limit_type or kwargs.get("limit_type", "max")
  117. else:
  118. limit_side_len = self.limit_side_len or kwargs.get("resize_long", 736)
  119. limit_type = self.limit_type or kwargs.get("limit_type", "min")
  120. return "Resize", DetResizeForTest(
  121. limit_side_len=limit_side_len,
  122. limit_type=limit_type,
  123. input_shape=self.input_shape,
  124. **kwargs
  125. )
  126. @register("NormalizeImage")
  127. def build_normalize(
  128. self,
  129. mean=[0.485, 0.456, 0.406],
  130. std=[0.229, 0.224, 0.225],
  131. scale=1 / 255,
  132. order="",
  133. ):
  134. return "Normalize", NormalizeImage(mean=mean, std=std, scale=scale, order=order)
  135. @register("ToCHWImage")
  136. def build_to_chw(self):
  137. return "ToCHW", ToCHWImage()
  138. def build_postprocess(self, **kwargs):
  139. if kwargs.get("name") == "DBPostProcess":
  140. return DBPostProcess(
  141. thresh=self.thresh or kwargs.get("thresh", 0.3),
  142. box_thresh=self.box_thresh or kwargs.get("box_thresh", 0.6),
  143. unclip_ratio=self.unclip_ratio or kwargs.get("unclip_ratio", 2.0),
  144. max_candidates=kwargs.get("max_candidates", 1000),
  145. use_dilation=kwargs.get("use_dilation", False),
  146. score_mode=kwargs.get("score_mode", "fast"),
  147. box_type=kwargs.get("box_type", "quad"),
  148. )
  149. else:
  150. raise Exception()
  151. @register("DetLabelEncode")
  152. def foo(self, *args, **kwargs):
  153. return None, None
  154. @register("KeepKeys")
  155. def foo(self, *args, **kwargs):
  156. return None, None