predictor.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import numpy as np
  16. from pathlib import Path
  17. from ...base import BasePredictor
  18. from ...base.predictor.transforms import image_common
  19. from .keys import WarpKeys as K
  20. from . import transforms as T
  21. from ..model_list import MODELS
  22. class WarpPredictor(BasePredictor):
  23. """Clssification Predictor"""
  24. entities = MODELS
  25. @classmethod
  26. def get_input_keys(cls):
  27. """get input keys"""
  28. return [[K.IMAGE], [K.IM_PATH]]
  29. @classmethod
  30. def get_output_keys(cls):
  31. """get output keys"""
  32. return [K.DOCTR_IMG]
  33. def _run(self, batch_input):
  34. """run"""
  35. input_dict = {}
  36. input_dict[K.IMAGE] = np.stack(
  37. [data[K.IMAGE] for data in batch_input], axis=0
  38. ).astype(dtype=np.float32, copy=False)
  39. input_ = [input_dict[K.IMAGE]]
  40. outputs = self._predictor.predict(input_)
  41. Warp_outs = outputs[0]
  42. # In-place update
  43. pred = batch_input
  44. for dict_, Warp_out in zip(pred, Warp_outs):
  45. dict_[K.DOCTR_IMG] = Warp_out
  46. return pred
  47. def _get_pre_transforms_from_config(self):
  48. """get preprocess transforms"""
  49. pre_transforms = [
  50. image_common.ReadImage(format='RGB'),
  51. image_common.Normalize(scale=1./255, mean=0.0, std=1.0),
  52. image_common.ToCHWImage()
  53. ]
  54. return pre_transforms
  55. def _get_post_transforms_from_config(self):
  56. """get postprocess transforms"""
  57. post_transforms = [
  58. T.DocTrPostProcess(scale=255.),
  59. T.SaveDocTrResults(self.output)
  60. ] # yapf: disable
  61. return post_transforms