# model.py
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
import os
from typing import Optional

from ....utils import logging
from ....utils.misc import abspath
from ...base.utils.arg import CLIArgument
from ...base.utils.subprocess import CompletedProcess
from ..text_rec.model import TextRecModel
  20. class TableRecModel(TextRecModel):
  21. """ Table Recognition Model """
  22. METRICS = ['acc']
  23. def predict(self,
  24. weight_path: str,
  25. input_path: str,
  26. device: str='gpu',
  27. save_dir: str=None,
  28. **kwargs) -> CompletedProcess:
  29. """predict using specified weight
  30. Args:
  31. weight_path (str): the path of model weight file used to predict.
  32. input_path (str): the path of image file to be predicted.
  33. device (str, optional): the running device. Defaults to 'gpu'.
  34. save_dir (str, optional): the directory path to save predict output. Defaults to None.
  35. Returns:
  36. CompletedProcess: the result of predicting subprocess execution.
  37. """
  38. config = self.config.copy()
  39. weight_path = abspath(weight_path)
  40. config.update_pretrained_weights(weight_path)
  41. input_path = abspath(input_path)
  42. config._update_infer_img(input_path)
  43. # TODO: Handle `device`
  44. logging.warning("`device` will not be used.")
  45. if save_dir is not None:
  46. save_dir = abspath(save_dir)
  47. else:
  48. save_dir = abspath(config.get_predict_save_dir())
  49. config._update_save_res_path(save_dir)
  50. self._assert_empty_kwargs(kwargs)
  51. with self._create_new_config_file() as config_path:
  52. config.dump(config_path)
  53. return self.runner.predict(config_path, [], device)
  54. def infer(self,
  55. model_dir: str,
  56. input_path: str,
  57. device: str='gpu',
  58. save_dir: str=None,
  59. **kwargs) -> CompletedProcess:
  60. """predict image using infernece model
  61. Args:
  62. model_dir (str): the directory path of inference model files that would use to predict.
  63. input_path (str): the path of image that would be predict.
  64. device (str, optional): the running device. Defaults to 'gpu'.
  65. save_dir (str, optional): the directory path to save output. Defaults to None.
  66. Returns:
  67. CompletedProcess: the result of infering subprocess execution.
  68. """
  69. config = self.config.copy()
  70. cli_args = []
  71. model_dir = abspath(model_dir)
  72. cli_args.append(CLIArgument('--table_model_dir', model_dir))
  73. input_path = abspath(input_path)
  74. cli_args.append(CLIArgument('--image_dir', input_path))
  75. device_type, _ = self.runner.parse_device(device)
  76. cli_args.append(CLIArgument('--use_gpu', str(device_type == 'gpu')))
  77. if save_dir is not None:
  78. save_dir = abspath(save_dir)
  79. else:
  80. # `save_dir` is None
  81. save_dir = abspath(os.path.join('output', 'infer'))
  82. cli_args.append(CLIArgument('--output', save_dir))
  83. dict_path = kwargs.pop('dict_path', None)
  84. if dict_path is not None:
  85. dict_path = abspath(dict_path)
  86. else:
  87. dict_path = config.get_label_dict_path()
  88. cli_args.append(CLIArgument('--table_char_dict_path', dict_path))
  89. model_type = config._get_model_type()
  90. cli_args.append(CLIArgument('--table_algorithm', model_type))
  91. infer_shape = config._get_infer_shape()
  92. if infer_shape is not None:
  93. cli_args.append(CLIArgument('--table_max_len', infer_shape))
  94. self._assert_empty_kwargs(kwargs)
  95. with self._create_new_config_file() as config_path:
  96. config.dump(config_path)
  97. return self.runner.infer(config_path, cli_args, device)