evaluator.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from pathlib import Path
  15. from ..base import BaseEvaluator
  16. from .model_list import MODELS
  17. class FormulaRecEvaluator(BaseEvaluator):
  18. """Text Recognition Model Evaluator"""
  19. entities = MODELS
  def update_config(self):
      """Push evaluation settings into the PaddleX model config.

      Reads ``self.eval_config`` and ``self.global_config`` and forwards the
      relevant values to ``self.pdx_config``:

      * ``log_interval``: forwarded only when truthy.
      * dataset type: ``"LaTeXOCRDataSet"`` for ``LaTeX_OCR_rec``,
        ``"SimpleDataSet"`` for UniMERNet and the PP-FormulaNet family;
        any other model name leaves the dataset settings untouched.
      * label dict: an explicitly configured ``label_dict_path`` is used
        as-is; otherwise ``label_dict.txt`` in the directory of
        ``weight_path`` is used, but only if that file exists.
      * ``batch_size``: forwarded when not None, via
        ``update_batch_size_pair`` for ``LaTeX_OCR_rec`` and via
        ``update_batch_size`` for every other model (both with
        ``mode="eval"``).
      * ``delimiter``: forwarded (``mode="eval"``) when present and not None.
      """
      if self.eval_config.log_interval:
          self.pdx_config.update_log_interval(self.eval_config.log_interval)

      # The model name selects both the dataset type and the batch-size setter.
      model_name = self.global_config["model"]
      if model_name == "LaTeX_OCR_rec":
          self.pdx_config.update_dataset(
              self.global_config.dataset_dir, "LaTeXOCRDataSet"
          )
      elif model_name in (
          "UniMERNet",
          "PP-FormulaNet-L",
          "PP-FormulaNet-S",
          "PP-FormulaNet_plus-L",
          "PP-FormulaNet_plus-M",
          "PP-FormulaNet_plus-S",
      ):
          self.pdx_config.update_dataset(
              self.global_config.dataset_dir, "SimpleDataSet"
          )

      # An explicit path is trusted without checking the filesystem. The
      # fallback next to the weights is only used when the file is present,
      # so a missing default dict leaves the config's own setting in place.
      if self.eval_config.get("label_dict_path"):
          label_dict_path = self.eval_config.label_dict_path
      else:
          label_dict_path = (
              Path(self.eval_config.weight_path).parent / "label_dict.txt"
          )
          if not label_dict_path.exists():
              label_dict_path = None
      if label_dict_path is not None:
          self.pdx_config.update_label_dict_path(label_dict_path)

      if self.eval_config.batch_size is not None:
          if model_name == "LaTeX_OCR_rec":
              self.pdx_config.update_batch_size_pair(
                  self.eval_config.batch_size, mode="eval"
              )
          else:
              self.pdx_config.update_batch_size(
                  self.eval_config.batch_size, mode="eval"
              )

      if self.eval_config.get("delimiter", None) is not None:
          self.pdx_config.update_delimiter(self.eval_config.delimiter, mode="eval")
  61. def get_eval_kwargs(self) -> dict:
  62. """get key-value arguments of model evaluation function
  63. Returns:
  64. dict: the arguments of evaluation function.
  65. """
  66. return {
  67. "weight_path": self.eval_config.weight_path,
  68. "device": self.get_device(),
  69. }