seal_recognition.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from .base import BasePipeline
  16. from .ocr import OCRPipeline
  17. from ..components import CropByBoxes
  18. from ..results import SealOCRResult
  19. from ...utils import logging
  def get_ocr_res(pipeline, input):
      """Run ``pipeline`` on the image(s) carried by ``input`` and collect the results.

      Args:
          pipeline: Callable taking an image (or a list of images) and returning
              an iterable of OCR results, e.g. the OCR sub-pipeline of
              ``SealOCRPipeline``.
          input: One of
              - a list of dicts, each holding an ``"img"`` key; all images are
                passed to ``pipeline`` together as a single list;
              - a single dict holding an ``"img"`` key;
              - anything else, which is passed to ``pipeline`` unchanged.

      Returns:
          list: Every result yielded by ``pipeline``, in order. When ``input`` is
          an empty list, ``[]`` is returned without calling ``pipeline``.
      """
      if isinstance(input, list):
          # Nothing to recognize: avoid invoking the pipeline on an empty batch.
          if not input:
              return []
          img = [item["img"] for item in input]
      elif isinstance(input, dict):
          img = input["img"]
      else:
          img = input
      return list(pipeline(img))
  class SealOCRPipeline(BasePipeline):
      """Seal Recognition Pipeline.

      Runs a layout predictor over the input, crops every detected region whose
      label is ``"seal"`` (case-insensitive), and runs an OCR sub-pipeline on
      those crops. One ``SealOCRResult`` is yielded per layout prediction.
      """

      entities = "seal_recognition"

      def __init__(
          self,
          layout_model,
          text_det_model,
          text_rec_model,
          layout_batch_size=1,
          text_det_batch_size=1,
          text_rec_batch_size=1,
          device=None,
          predictor_kwargs=None,
      ):
          """Create the layout predictor and OCR sub-pipeline, then apply batch sizes.

          Args:
              layout_model: Model used by the layout predictor.
              text_det_model: Text-detection model for the OCR sub-pipeline.
              text_rec_model: Text-recognition model for the OCR sub-pipeline.
              layout_batch_size: Batch size for the layout predictor.
              text_det_batch_size: Only values <= 1 take effect; larger values
                  just log a warning (see ``set_predictor``).
              text_rec_batch_size: Batch size for the text-recognition model.
              device: Forwarded to ``BasePipeline.__init__``.
              predictor_kwargs: Forwarded to ``BasePipeline.__init__``.
          """
          super().__init__(device, predictor_kwargs)
          self._build_predictor(
              layout_model=layout_model,
              text_det_model=text_det_model,
              text_rec_model=text_rec_model,
              layout_batch_size=layout_batch_size,
              text_det_batch_size=text_det_batch_size,
              text_rec_batch_size=text_rec_batch_size,
          )
          self.set_predictor(
              layout_batch_size=layout_batch_size,
              text_det_batch_size=text_det_batch_size,
              text_rec_batch_size=text_rec_batch_size,
          )

      def _build_predictor(
          self,
          layout_model,
          text_det_model,
          text_rec_model,
          layout_batch_size,
          text_det_batch_size,
          text_rec_batch_size,
      ):
          """Instantiate the layout predictor, the OCR sub-pipeline and the cropper.

          The ``*_batch_size`` arguments are accepted for signature compatibility
          but are not used here; batch sizes are applied by ``set_predictor``.
          """
          self.layout_predictor = self._create(model=layout_model)
          self.ocr_pipeline = self._create(
              pipeline=OCRPipeline,
              text_det_model=text_det_model,
              text_rec_model=text_rec_model,
          )
          self._crop_by_boxes = CropByBoxes()

      def set_predictor(
          self,
          layout_batch_size=None,
          text_det_batch_size=None,
          text_rec_batch_size=None,
          device=None,
      ):
          """Update batch sizes and/or device of the underlying predictors.

          Falsy arguments (``None``/``0``) leave the corresponding setting
          unchanged. ``text_det_batch_size`` is never applied; values above 1
          only produce a warning.
          """
          if text_det_batch_size and text_det_batch_size > 1:
              logging.warning(
                  f"text det model only support batch_size=1 now,the setting of text_det_batch_size={text_det_batch_size} will not using! "
              )
          if layout_batch_size:
              self.layout_predictor.set_predictor(batch_size=layout_batch_size)
          if text_rec_batch_size:
              self.ocr_pipeline.text_rec_model.set_predictor(
                  batch_size=text_rec_batch_size
              )
          if device:
              self.layout_predictor.set_predictor(device=device)
              self.ocr_pipeline.set_predictor(device=device)

      def predict(self, x, **kwargs):
          """Detect seal regions in ``x`` and run OCR on each of them.

          Args:
              x: Input forwarded unchanged to the layout predictor.
              **kwargs: Forwarded to ``set_predictor`` before inference
                  (batch sizes / device).

          Yields:
              SealOCRResult: Wraps a dict with keys ``"input_path"``,
              ``"layout_result"`` (the raw layout prediction) and
              ``"ocr_result"`` (the list returned by ``get_ocr_res`` for the
              seal crops, or ``[]`` when no seal region was found).
          """
          self.set_predictor(**kwargs)
          for layout_pred in self.layout_predictor(x):
              seal_subs = []
              # NOTE(review): ``boxes`` may be an array rather than a list, so its
              # length is tested explicitly instead of relying on truthiness.
              if len(layout_pred["boxes"]) > 0:
                  # Keep only the cropped regions labelled "seal".
                  seal_subs = [
                      sub
                      for sub in self._crop_by_boxes(layout_pred)
                      if sub["label"].lower() == "seal"
                  ]
              # Do not invoke the OCR sub-pipeline when there is nothing to read.
              all_seal_ocr_res = (
                  get_ocr_res(self.ocr_pipeline, seal_subs) if seal_subs else []
              )
              yield SealOCRResult(
                  {
                      "input_path": layout_pred["input_path"],
                      "layout_result": layout_pred,
                      "ocr_result": all_seal_ocr_res,
                  }
              )