pipeline.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..base import BasePipeline
  15. from typing import Any, Dict, Optional
  16. from scipy.ndimage import rotate
  17. from .result import DocPreprocessorResult
  18. # TODO: update this import path later.
  19. from ...components.transforms import ReadImage
  20. class DocPreprocessorPipeline(BasePipeline):
  21. """Doc Preprocessor Pipeline"""
  22. entities = "doc_preprocessor"
  23. def __init__(self,
  24. config,
  25. device=None,
  26. pp_option=None,
  27. use_hpip: bool = False,
  28. hpi_params: Optional[Dict[str, Any]] = None):
  29. super().__init__(device=device, pp_option=pp_option,
  30. use_hpip=use_hpip, hpi_params=hpi_params)
  31. self.use_doc_orientation_classify = True
  32. if 'use_doc_orientation_classify' in config:
  33. self.use_doc_orientation_classify = config['use_doc_orientation_classify']
  34. self.use_doc_unwarping = True
  35. if 'use_doc_unwarping' in config:
  36. self.use_doc_unwarping = config['use_doc_unwarping']
  37. if self.use_doc_orientation_classify:
  38. doc_ori_classify_config = config['SubModules']["DocOrientationClassify"]
  39. self.doc_ori_classify_model = self.create_model(doc_ori_classify_config)
  40. if self.use_doc_unwarping:
  41. doc_unwarping_config = config['SubModules']["DocUnwarping"]
  42. self.doc_unwarping_model = self.create_model(doc_unwarping_config)
  43. self.img_reader = ReadImage(format="BGR")
  44. def rotate_image(self, image_array, rotate_angle):
  45. """rotate image"""
  46. assert (
  47. rotate_angle >= 0 and rotate_angle < 360
  48. ), "rotate_angle must in [0-360), but get {rotate_angle}."
  49. return rotate(image_array, rotate_angle, reshape=True)
  50. def check_input_params(self, input_params):
  51. if input_params['use_doc_orientation_classify'] and \
  52. not self.use_doc_orientation_classify:
  53. raise ValueError("The model for doc orientation classify is not initialized.")
  54. if input_params['use_doc_unwarping'] and \
  55. not self.use_doc_unwarping:
  56. raise ValueError("The model for doc unwarping is not initialized.")
  57. return
  58. def predict(self, input,
  59. use_doc_orientation_classify=True,
  60. use_doc_unwarping=False,
  61. **kwargs):
  62. if not isinstance(input, list):
  63. input_list = [input]
  64. else:
  65. input_list = input
  66. input_params = {"use_doc_orientation_classify":use_doc_orientation_classify,
  67. "use_doc_unwarping":use_doc_unwarping}
  68. self.check_input_params(input_params)
  69. img_id = 1
  70. for input in input_list:
  71. if isinstance(input, str):
  72. image_array = next(self.img_reader(input))[0]['img']
  73. else:
  74. image_array = input
  75. assert len(image_array.shape) == 3
  76. if input_params['use_doc_orientation_classify']:
  77. pred = next(self.doc_ori_classify_model(image_array))
  78. angle = int(pred["label_names"][0])
  79. rot_img = self.rotate_image(image_array, angle)
  80. else:
  81. angle = -1
  82. rot_img = image_array
  83. if input_params['use_doc_unwarping']:
  84. output_img = next(self.doc_unwarping_model(rot_img))['doctr_img']
  85. else:
  86. output_img = rot_img
  87. single_img_res = {"input_image":image_array,
  88. "input_params":input_params,
  89. "angle":angle,
  90. "rot_img":rot_img,
  91. "output_img":output_img,
  92. "img_id":img_id}
  93. img_id += 1
  94. yield DocPreprocessorResult(single_img_res)