classifier.h 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #pragma once
  15. #include "ultra_infer/ultra_infer_model.h"
  16. #include "ultra_infer/utils/unique_ptr.h"
  17. #include "ultra_infer/vision/common/processors/transform.h"
  18. #include "ultra_infer/vision/common/result.h"
  19. #include "ultra_infer/vision/ocr/ppocr/cls_postprocessor.h"
  20. #include "ultra_infer/vision/ocr/ppocr/cls_preprocessor.h"
  21. #include "ultra_infer/vision/ocr/ppocr/utils/ocr_postprocess_op.h"
  22. namespace ultra_infer {
  23. namespace vision {
  24. /** \brief All OCR series model APIs are defined inside this namespace
  25. *
  26. */
  27. namespace ocr {
  28. /*! @brief Classifier object is used to load the classification model provided
  29. * by PaddleOCR.
  30. */
  31. class ULTRAINFER_DECL Classifier : public UltraInferModel {
  32. public:
  33. Classifier();
  34. /** \brief Set path of model file, and the configuration of runtime
  35. *
  36. * \param[in] model_file Path of model file, e.g
  37. * ./ch_ppocr_mobile_v2.0_cls_infer/model.pdmodel. \param[in] params_file Path
  38. * of parameter file, e.g ./ch_ppocr_mobile_v2.0_cls_infer/model.pdiparams, if
  39. * the model format is ONNX, this parameter will be ignored. \param[in]
  40. * custom_option RuntimeOption for inference, the default will use cpu, and
  41. * choose the backend defined in `valid_cpu_backends`. \param[in] model_format
  42. * Model format of the loaded model, default is Paddle format.
  43. */
  44. Classifier(const std::string &model_file, const std::string &params_file = "",
  45. const RuntimeOption &custom_option = RuntimeOption(),
  46. const ModelFormat &model_format = ModelFormat::PADDLE);
  47. /** \brief Clone a new Classifier with less memory usage when multiple
  48. * instances of the same model are created
  49. *
  50. * \return new Classifier* type unique pointer
  51. */
  52. virtual std::unique_ptr<Classifier> Clone() const;
  53. /// Get model's name
  54. std::string ModelName() const { return "ppocr/ocr_cls"; }
  55. /** \brief Predict the input image and get OCR classification model
  56. * cls_result.
  57. *
  58. * \param[in] img The input image data, comes from cv::imread(), is a 3-D
  59. * array with layout HWC, BGR format. \param[in] cls_label The label result of
  60. * cls model will be written in to this param. \param[in] cls_score The score
  61. * result of cls model will be written in to this param. \return true if the
  62. * prediction is succeeded, otherwise false.
  63. */
  64. virtual bool Predict(const cv::Mat &img, int32_t *cls_label,
  65. float *cls_score);
  66. /** \brief Predict the input image and get OCR recognition model result.
  67. *
  68. * \param[in] img The input image data, comes from cv::imread(), is a 3-D
  69. * array with layout HWC, BGR format. \param[in] ocr_result The output of OCR
  70. * recognition model result will be written to this structure. \return true if
  71. * the prediction is succeeded, otherwise false.
  72. */
  73. virtual bool Predict(const cv::Mat &img, vision::OCRResult *ocr_result);
  74. /** \brief BatchPredict the input image and get OCR classification model
  75. * result.
  76. *
  77. * \param[in] img The input image data, comes from cv::imread(), is a 3-D
  78. * array with layout HWC, BGR format. \param[in] ocr_result The output of OCR
  79. * classification model result will be written to this structure. \return true
  80. * if the prediction is succeeded, otherwise false.
  81. */
  82. virtual bool BatchPredict(const std::vector<cv::Mat> &images,
  83. vision::OCRResult *ocr_result);
  84. /** \brief BatchPredict the input image and get OCR classification model
  85. * cls_result.
  86. *
  87. * \param[in] images The list of input image data, comes from cv::imread(), is
  88. * a 3-D array with layout HWC, BGR format. \param[in] cls_labels The label
  89. * results of cls model will be written in to this vector. \param[in]
  90. * cls_scores The score results of cls model will be written in to this
  91. * vector. \return true if the prediction is succeeded, otherwise false.
  92. */
  93. virtual bool BatchPredict(const std::vector<cv::Mat> &images,
  94. std::vector<int32_t> *cls_labels,
  95. std::vector<float> *cls_scores);
  96. virtual bool BatchPredict(const std::vector<cv::Mat> &images,
  97. std::vector<int32_t> *cls_labels,
  98. std::vector<float> *cls_scores, size_t start_index,
  99. size_t end_index);
  100. /// Get preprocessor reference of ClassifierPreprocessor
  101. virtual ClassifierPreprocessor &GetPreprocessor() { return preprocessor_; }
  102. /// Get postprocessor reference of ClassifierPostprocessor
  103. virtual ClassifierPostprocessor &GetPostprocessor() { return postprocessor_; }
  104. private:
  105. bool Initialize();
  106. ClassifierPreprocessor preprocessor_;
  107. ClassifierPostprocessor postprocessor_;
  108. };
  109. } // namespace ocr
  110. } // namespace vision
  111. } // namespace ultra_infer