seal_recognition.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from .base import BasePipeline
  16. from .ocr import OCRPipeline
  17. from ..components import CropByBoxes, ReadImage
  18. from ..results import SealOCRResult, OCRResult
  19. from ...utils import logging
  20. def get_ocr_res(pipeline, input):
  21. """get ocr res"""
  22. ocr_res_list = []
  23. if isinstance(input, list):
  24. img = [im["img"] for im in input]
  25. elif isinstance(input, dict):
  26. img = input["img"]
  27. else:
  28. img = input
  29. for ocr_res in pipeline(img):
  30. ocr_res_list.append(ocr_res)
  31. return ocr_res_list
  32. class SealOCRPipeline(BasePipeline):
  33. """Seal Recognition Pipeline"""
  34. entities = "seal_recognition"
  35. def __init__(
  36. self,
  37. layout_model,
  38. text_det_model,
  39. text_rec_model,
  40. layout_batch_size=1,
  41. text_det_batch_size=1,
  42. text_rec_batch_size=1,
  43. device=None,
  44. predictor_kwargs=None,
  45. ):
  46. super().__init__(device, predictor_kwargs)
  47. self._build_predictor(
  48. layout_model=layout_model,
  49. text_det_model=text_det_model,
  50. text_rec_model=text_rec_model,
  51. layout_batch_size=layout_batch_size,
  52. text_det_batch_size=text_det_batch_size,
  53. text_rec_batch_size=text_rec_batch_size,
  54. )
  55. self.set_predictor(
  56. layout_batch_size=layout_batch_size,
  57. text_det_batch_size=text_det_batch_size,
  58. text_rec_batch_size=text_rec_batch_size,
  59. )
  60. self.img_reader = ReadImage(format="BGR")
  61. def _build_predictor(
  62. self,
  63. layout_model,
  64. text_det_model,
  65. text_rec_model,
  66. layout_batch_size,
  67. text_det_batch_size,
  68. text_rec_batch_size,
  69. ):
  70. self.layout_predictor = self._create(model=layout_model)
  71. self.ocr_pipeline = self._create(
  72. pipeline=OCRPipeline,
  73. text_det_model=text_det_model,
  74. text_rec_model=text_rec_model,
  75. )
  76. self._crop_by_boxes = CropByBoxes()
  77. def set_predictor(
  78. self,
  79. layout_batch_size=None,
  80. text_det_batch_size=None,
  81. text_rec_batch_size=None,
  82. device=None,
  83. ):
  84. if text_det_batch_size and text_det_batch_size > 1:
  85. logging.warning(
  86. f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
  87. )
  88. if layout_batch_size:
  89. self.layout_predictor.set_predictor(batch_size=layout_batch_size)
  90. if text_rec_batch_size:
  91. self.ocr_pipeline.text_rec_model.set_predictor(
  92. batch_size=text_rec_batch_size
  93. )
  94. if device:
  95. self.layout_predictor.set_predictor(device=device)
  96. self.ocr_pipeline.set_predictor(device=device)
  97. def predict(self, inputs, **kwargs):
  98. self.set_predictor(**kwargs)
  99. img_info_list = list(self.img_reader(inputs))[0]
  100. img_list = [img_info["img"] for img_info in img_info_list]
  101. for page_id, layout_pred in enumerate(self.layout_predictor(img_list)):
  102. single_img_res = {
  103. "input_path": "",
  104. "layout_result": {},
  105. "ocr_result": {},
  106. }
  107. # update layout result
  108. single_img_res["input_path"] = layout_pred["input_path"]
  109. single_img_res["layout_result"] = layout_pred
  110. seal_subs = []
  111. if len(layout_pred["boxes"]) > 0:
  112. subs_of_img = list(self._crop_by_boxes(layout_pred))
  113. # get cropped images with label "seal"
  114. for sub in subs_of_img:
  115. box = sub["box"]
  116. if sub["label"].lower() == "seal":
  117. seal_subs.append(sub)
  118. all_seal_ocr_res = get_ocr_res(self.ocr_pipeline, seal_subs)
  119. seal_res = {
  120. "dt_polys": [],
  121. "dt_scores": [],
  122. "rec_text": [],
  123. "rec_score": [],
  124. }
  125. for sub, seal_ocr_res in zip(seal_subs, all_seal_ocr_res):
  126. if len(seal_ocr_res["dt_polys"]) > 0:
  127. box = sub["box"]
  128. ori_bbox_list = [
  129. dt + np.array(box[:2]).astype(np.int32)
  130. for dt in seal_ocr_res["dt_polys"]
  131. ]
  132. seal_res["dt_polys"].extend(ori_bbox_list)
  133. seal_res["dt_scores"].extend(seal_ocr_res["dt_scores"])
  134. seal_res["rec_text"].extend(seal_ocr_res["rec_text"])
  135. seal_res["rec_score"].extend(seal_ocr_res["rec_score"])
  136. seal_res["input_path"] = single_img_res["input_path"]
  137. single_img_res["src_file_name"] = inputs
  138. single_img_res["ocr_result"] = OCRResult(seal_res)
  139. single_img_res["page_id"] = page_id
  140. yield SealOCRResult(single_img_res)