predictor.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Union
  15. import numpy as np
  16. from ....modules.text_detection.model_list import MODELS
  17. from ....utils.func_register import FuncRegister
  18. from ...common.batch_sampler import ImageBatchSampler
  19. from ...common.reader import ReadImage
  20. from ..base import BasePredictor
  21. from ..common import ToBatch, ToCHWImage
  22. from .processors import DBPostProcess, DetResizeForTest, NormalizeImage
  23. from .result import TextDetResult
  24. class TextDetPredictor(BasePredictor):
  25. entities = MODELS
  26. _FUNC_MAP = {}
  27. register = FuncRegister(_FUNC_MAP)
  28. def __init__(
  29. self,
  30. limit_side_len: Union[int, None] = None,
  31. limit_type: Union[str, None] = None,
  32. thresh: Union[float, None] = None,
  33. box_thresh: Union[float, None] = None,
  34. unclip_ratio: Union[float, None] = None,
  35. input_shape=None,
  36. max_side_limit: int = 4000,
  37. *args,
  38. **kwargs
  39. ):
  40. super().__init__(*args, **kwargs)
  41. self.limit_side_len = limit_side_len
  42. self.limit_type = limit_type
  43. self.thresh = thresh
  44. self.box_thresh = box_thresh
  45. self.unclip_ratio = unclip_ratio
  46. self.input_shape = input_shape
  47. self.max_side_limit = max_side_limit
  48. self.pre_tfs, self.infer, self.post_op = self._build()
  49. def _build_batch_sampler(self):
  50. return ImageBatchSampler()
  51. def _get_result_class(self):
  52. return TextDetResult
  53. def _build(self):
  54. pre_tfs = {"Read": ReadImage(format="RGB")}
  55. for cfg in self.config["PreProcess"]["transform_ops"]:
  56. tf_key = list(cfg.keys())[0]
  57. func = self._FUNC_MAP[tf_key]
  58. args = cfg.get(tf_key, {})
  59. name, op = func(self, **args) if args else func(self)
  60. if op:
  61. pre_tfs[name] = op
  62. pre_tfs["ToBatch"] = ToBatch()
  63. infer = self.create_static_infer()
  64. post_op = self.build_postprocess(**self.config["PostProcess"])
  65. return pre_tfs, infer, post_op
  66. def process(
  67. self,
  68. batch_data: List[Union[str, np.ndarray]],
  69. limit_side_len: Union[int, None] = None,
  70. limit_type: Union[str, None] = None,
  71. thresh: Union[float, None] = None,
  72. box_thresh: Union[float, None] = None,
  73. unclip_ratio: Union[float, None] = None,
  74. max_side_limit: Union[int, None] = None,
  75. ):
  76. batch_raw_imgs = self.pre_tfs["Read"](imgs=batch_data.instances)
  77. batch_imgs, batch_shapes = self.pre_tfs["Resize"](
  78. imgs=batch_raw_imgs,
  79. limit_side_len=limit_side_len or self.limit_side_len,
  80. limit_type=limit_type or self.limit_type,
  81. max_side_limit=(
  82. max_side_limit if max_side_limit is not None else self.max_side_limit
  83. ),
  84. )
  85. batch_imgs = self.pre_tfs["Normalize"](imgs=batch_imgs)
  86. batch_imgs = self.pre_tfs["ToCHW"](imgs=batch_imgs)
  87. x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
  88. batch_preds = self.infer(x=x)
  89. polys, scores = self.post_op(
  90. batch_preds,
  91. batch_shapes,
  92. thresh=thresh or self.thresh,
  93. box_thresh=box_thresh or self.box_thresh,
  94. unclip_ratio=unclip_ratio or self.unclip_ratio,
  95. )
  96. return {
  97. "input_path": batch_data.input_paths,
  98. "page_index": batch_data.page_indexes,
  99. "input_img": batch_raw_imgs,
  100. "dt_polys": polys,
  101. "dt_scores": scores,
  102. }
  103. @register("DecodeImage")
  104. def build_readimg(self, channel_first, img_mode):
  105. assert channel_first == False
  106. return "Read", ReadImage(format=img_mode)
  107. @register("DetResizeForTest")
  108. def build_resize(
  109. self,
  110. limit_side_len: Union[int, None] = None,
  111. limit_type: Union[str, None] = None,
  112. **kwargs
  113. ):
  114. # TODO: align to PaddleOCR
  115. if self.model_name in (
  116. "PP-OCRv5_server_det",
  117. "PP-OCRv5_mobile_det",
  118. "PP-OCRv4_server_det",
  119. "PP-OCRv4_mobile_det",
  120. "PP-OCRv3_server_det",
  121. "PP-OCRv3_mobile_det",
  122. ):
  123. limit_side_len = self.limit_side_len or kwargs.get("resize_long", 960)
  124. limit_type = self.limit_type or kwargs.get("limit_type", "max")
  125. else:
  126. limit_side_len = self.limit_side_len or kwargs.get("resize_long", 736)
  127. limit_type = self.limit_type or kwargs.get("limit_type", "min")
  128. return "Resize", DetResizeForTest(
  129. limit_side_len=limit_side_len,
  130. limit_type=limit_type,
  131. input_shape=self.input_shape,
  132. **kwargs
  133. )
  134. @register("NormalizeImage")
  135. def build_normalize(
  136. self,
  137. mean=[0.485, 0.456, 0.406],
  138. std=[0.229, 0.224, 0.225],
  139. scale=1 / 255,
  140. order="",
  141. ):
  142. return "Normalize", NormalizeImage(mean=mean, std=std, scale=scale, order=order)
  143. @register("ToCHWImage")
  144. def build_to_chw(self):
  145. return "ToCHW", ToCHWImage()
  146. def build_postprocess(self, **kwargs):
  147. if kwargs.get("name") == "DBPostProcess":
  148. return DBPostProcess(
  149. thresh=self.thresh or kwargs.get("thresh", 0.3),
  150. box_thresh=self.box_thresh or kwargs.get("box_thresh", 0.6),
  151. unclip_ratio=self.unclip_ratio or kwargs.get("unclip_ratio", 2.0),
  152. max_candidates=kwargs.get("max_candidates", 1000),
  153. use_dilation=kwargs.get("use_dilation", False),
  154. score_mode=kwargs.get("score_mode", "fast"),
  155. box_type=kwargs.get("box_type", "quad"),
  156. )
  157. else:
  158. raise Exception()
  159. @register("DetLabelEncode")
  160. def foo(self, *args, **kwargs):
  161. return None, None
  162. @register("KeepKeys")
  163. def foo(self, *args, **kwargs):
  164. return None, None