# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional

from ..base import BasePipeline
from ..components import SortQuadBoxes, CropByPolys
from .result import OCRResult

# TODO: this import path needs to be updated later.
from ...components.transforms import ReadImage


class OCRPipeline(BasePipeline):
    """OCR Pipeline.

    Runs a text-detection sub-model on each input image, crops the detected
    regions out of the image, and runs a text-recognition sub-model on every
    crop. One ``OCRResult`` is yielded per input image.
    """

    entities = "OCR"

    def __init__(
        self,
        config,
        device=None,
        pp_option=None,
        use_hpip: bool = False,
        hpi_params: Optional[Dict[str, Any]] = None,
    ):
        """Build the detection/recognition sub-models from ``config``.

        Args:
            config: Pipeline configuration mapping. Must contain
                ``config["SubModules"]["TextDetection"]``,
                ``config["SubModules"]["TextRecognition"]`` and
                ``config["text_type"]`` (``"common"`` or ``"seal"``).
            device: Forwarded to ``BasePipeline``.
            pp_option: Forwarded to ``BasePipeline``.
            use_hpip: Forwarded to ``BasePipeline``.
            hpi_params: Forwarded to ``BasePipeline``.

        Raises:
            ValueError: If ``config["text_type"]`` is not ``"common"`` or
                ``"seal"``.
        """
        super().__init__(
            device=device,
            pp_option=pp_option,
            use_hpip=use_hpip,
            hpi_params=hpi_params,
        )

        sub_modules = config["SubModules"]
        self.text_det_model = self.create_model(sub_modules["TextDetection"])
        self.text_rec_model = self.create_model(sub_modules["TextRecognition"])

        self.text_type = config["text_type"]

        self._sort_quad_boxes = SortQuadBoxes()

        # "common" text is cropped with quadrilateral boxes, "seal" text with
        # general polygons.
        if self.text_type == "common":
            self._crop_by_polys = CropByPolys(det_box_type="quad")
        elif self.text_type == "seal":
            self._crop_by_polys = CropByPolys(det_box_type="poly")
        else:
            raise ValueError("Unsupported text type {}".format(self.text_type))

        self.img_reader = ReadImage(format="BGR")

    def predict(self, input, **kwargs):
        """Run OCR on one input or a list of inputs.

        Args:
            input: A single input or a ``list`` of inputs. A ``str`` item is
                loaded with ``ReadImage(format="BGR")``; any other item is
                used directly as the image array.
            **kwargs: Accepted for interface compatibility; currently unused.

        Yields:
            OCRResult: One result per input, in input order. Each wraps a dict
            with keys ``input_img``, ``dt_polys``, ``img_id`` (1-based),
            ``text_type``, ``rec_text`` and ``rec_score``.

        Raises:
            ValueError: If an image array is not 3-dimensional.
        """
        input_list = input if isinstance(input, list) else [input]

        for img_id, item in enumerate(input_list, start=1):
            if isinstance(item, str):
                image_array = next(self.img_reader(item))[0]["img"]
            else:
                image_array = item

            # Explicit check instead of `assert`, which is stripped under -O.
            if len(image_array.shape) != 3:
                raise ValueError(
                    "Expected a 3-dimensional image array, got shape "
                    "{}".format(image_array.shape)
                )

            det_res = next(self.text_det_model(image_array))
            dt_polys = det_res["dt_polys"]

            # TODO: confirm the filtering thresholds used by the detection
            # and recognition modules.
            if self.text_type == "common":
                dt_polys = self._sort_quad_boxes(dt_polys)

            single_img_res = {
                "input_img": image_array,
                "dt_polys": dt_polys,
                "img_id": img_id,
                "text_type": self.text_type,
            }
            single_img_res["rec_text"] = []
            single_img_res["rec_score"] = []

            # `len(...) > 0` rather than truthiness, since dt_polys may be an
            # array type whose truth value is ambiguous (unconfirmed here).
            if len(dt_polys) > 0:
                all_subs_of_img = list(self._crop_by_polys(image_array, dt_polys))
                # TODO: update in the future. Each crop's image is copied into
                # the "input" key, presumably because the recognition model
                # reads its input from that key -- confirm.
                for sub_img in all_subs_of_img:
                    sub_img["input"] = sub_img["img"]
                for rec_res in self.text_rec_model(all_subs_of_img):
                    single_img_res["rec_text"].append(rec_res["rec_text"])
                    single_img_res["rec_score"].append(rec_res["rec_score"])

            yield OCRResult(single_img_res)