predictor.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import copy
  15. import os
  16. from typing import List
  17. from ....modules.doc_vlm.model_list import MODELS
  18. from ....utils.device import TemporaryDeviceChanger
  19. from ....utils.env import get_device_type
  20. from ...common.batch_sampler import DocVLMBatchSampler
  21. from ..base import BasePredictor
  22. from .result import DocVLMResult
  class DocVLMPredictor(BasePredictor):
      """Predictor for document vision-language models (PP-DocBee, PP-Chart2Table).

      Builds a model-family-specific processor and inference model, then runs
      preprocess -> generate -> postprocess over batches of dict samples such as
      ``{"image": "/path/to/image", "query": "some question"}``.
      """

      entities = MODELS

      def __init__(self, *args, **kwargs):
          """Initializes DocVLMPredictor.

          Args:
              *args: Arbitrary positional arguments passed to the superclass.
              **kwargs: Arbitrary keyword arguments passed to the superclass.
                  ``device`` is additionally stored on the instance, and the
                  whole mapping is forwarded to ``_build`` (which reads
                  ``use_hpip``).
          """
          import paddle

          super().__init__(*args, **kwargs)

          self.device = kwargs.get("device", None)
          # Use bfloat16 on NPU or when Paddle reports bf16 support; otherwise
          # fall back to float32.
          self.dtype = (
              "bfloat16"
              if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
              else "float32"
          )

          self.infer, self.processor = self._build(**kwargs)

      def _build_batch_sampler(self):
          """Builds and returns a DocVLMBatchSampler instance.

          Returns:
              DocVLMBatchSampler: A sampler created for ``self.model_name``.
          """
          return DocVLMBatchSampler(self.model_name)

      def _get_result_class(self):
          """Returns the result class, DocVLMResult.

          Returns:
              type: The DocVLMResult class.
          """
          return DocVLMResult

      def _build(self, **kwargs):
          """Build the inference model and its corresponding processor.

          Args:
              **kwargs: Keyword arguments received by ``__init__``; only
                  ``use_hpip`` is inspected here.

          Returns:
              tuple: ``(model, processor)``. ``model`` is a ``PPDocBeeInference``
              or ``PPChart2TableInference`` loaded from ``self.model_dir`` with
              ``self.dtype``. ``processor`` is the object returned by
              ``build_processor``.

          Raises:
              NotImplementedError: If ``self.model_name`` matches no supported
                  model family.
              ValueError: If ``use_hpip=True`` is requested; no supported family
                  handles it yet.
          """
          from .modeling import PPChart2TableInference, PPDocBeeInference

          # Resolve the model family once; "PP-DocBee" is checked first, as before.
          if "PP-DocBee" in self.model_name:
              family = "PP-DocBee"
          elif "PP-Chart2Table" in self.model_name:
              family = "PP-Chart2Table"
          else:
              raise NotImplementedError(f"Model {self.model_name} is not supported.")

          # Reject unsupported options before loading the tokenizer and weights,
          # so a bad configuration fails fast.
          if kwargs.get("use_hpip", False):
              raise ValueError(
                  f"{family} series do not support `use_hpip=True` for now."
              )

          processor = self.build_processor()

          with TemporaryDeviceChanger(self.device):
              if family == "PP-DocBee":
                  model = PPDocBeeInference.from_pretrained(
                      self.model_dir, dtype=self.dtype
                  )
              else:
                  model = PPChart2TableInference.from_pretrained(
                      self.model_dir,
                      dtype=self.dtype,
                      # The pad token id is taken from the tokenizer's EOS id.
                      pad_token_id=processor.tokenizer.eos_token_id,
                  )
          return model, processor

      def process(self, data: List[dict], **kwargs):
          """Run preprocessing, inference and postprocessing on a batch.

          Args:
              data (List[dict]): A batch of input samples. Each sample must be a
                  dict (e.g. ``{"image": /path/to/image, "query": some question}``).
              **kwargs: Arbitrary keyword arguments forwarded to
                  ``self.infer.generate``.

          Returns:
              dict: A column-oriented mapping. It holds one list per input key
              with the raw sample values, plus a ``"result"`` list containing the
              postprocessed prediction for every sample of the batch.

          Raises:
              TypeError: If any element of ``data`` is not a dict.
              ValueError: If the number of predictions differs from the number
                  of samples.
          """
          # Explicit check rather than `assert`, which is stripped under `python -O`.
          invalid = [type(item).__name__ for item in data if not isinstance(item, dict)]
          if invalid:
              raise TypeError(
                  f"Every input sample must be a dict, got: {', '.join(invalid)}."
              )
          # Copy each sample dict, not just the list. If the processor adds or
          # reassigns keys in place, those changes cannot leak into the raw
          # sample info reported back.
          src_data = [copy.copy(sample) for sample in data]

          # preprocess
          data = self.processor.preprocess(data)
          data = self._switch_inputs_to_device(data)

          # do infer
          with TemporaryDeviceChanger(self.device):
              preds = self.infer.generate(data, **kwargs)

          # postprocess
          preds = self.processor.postprocess(preds)

          return self._format_result_dict(preds, src_data)

      def build_processor(self, **kwargs):
          """Build the processor (image processor + tokenizer) for ``self.model_name``.

          Args:
              **kwargs: Accepted for interface compatibility; currently unused.

          Returns:
              PPDocBeeProcessor | PPChart2TableProcessor: A processor whose
              tokenizer is loaded from ``self.model_dir``.

          Raises:
              NotImplementedError: If ``self.model_name`` is not supported.
          """
          from ..common.tokenizer import MIXQwen2Tokenizer, QWenTokenizer
          from .processors import (
              GOTImageProcessor,
              PPChart2TableProcessor,
              PPDocBeeProcessor,
              Qwen2VLImageProcessor,
          )

          if "PP-DocBee" in self.model_name:
              image_processor = Qwen2VLImageProcessor()
              tokenizer = MIXQwen2Tokenizer.from_pretrained(self.model_dir)
              return PPDocBeeProcessor(
                  image_processor=image_processor, tokenizer=tokenizer
              )
          if "PP-Chart2Table" in self.model_name:
              # NOTE(review): 1024 is passed to GOTImageProcessor as-is; presumably
              # the target image size — confirm against GOTImageProcessor.
              image_processor = GOTImageProcessor(1024)
              tokenizer = QWenTokenizer.from_pretrained(self.model_dir)
              return PPChart2TableProcessor(
                  image_processor=image_processor, tokenizer=tokenizer, dtype=self.dtype
              )
          raise NotImplementedError(f"Model {self.model_name} is not supported.")

      def _format_result_dict(self, model_preds, src_data):
          """Merge raw samples and predictions into one column-oriented dict.

          Args:
              model_preds: A prediction, or a list of predictions (one per sample).
              src_data: A sample dict, or a list of sample dicts.

          Returns:
              dict: ``{key: [value per sample], ..., "result": [pred per sample]}``.
              Keys are the union of all sample keys, in first-seen order. A
              sample lacking a key contributes ``None`` for it. An empty batch
              yields ``{"result": []}``. A sample key named ``"result"`` is
              overridden by the predictions.

          Raises:
              ValueError: If the number of predictions and samples differ.
          """
          if not isinstance(model_preds, list):
              model_preds = [model_preds]
          if not isinstance(src_data, list):
              src_data = [src_data]
          if len(model_preds) != len(src_data):
              raise ValueError(
                  f"Model predicts {len(model_preds)} results while src data has {len(src_data)} samples."
              )

          # Union of keys across the whole batch (dicts keep insertion order).
          # Heterogeneous samples therefore neither raise KeyError nor produce
          # misaligned columns, and an empty batch needs no src_data[0].
          keys = dict.fromkeys(key for sample in src_data for key in sample)
          rst_format_dict = {
              key: [sample.get(key) for sample in src_data] for key in keys
          }
          rst_format_dict["result"] = list(model_preds)
          return rst_format_dict

      def _infer_dynamic_forward_device(self, device):
          """Infer the forward device for a dynamic-graph model.

          Translates the requested GPU ids, treated as indices into the visible
          devices, into the ids listed in ``CUDA_VISIBLE_DEVICES``.

          Args:
              device (str | None): ``None``, a CPU device string, or a device
                  spec understood by ``parse_device``.

          Returns:
              str | None: ``None`` if ``device`` is ``None``. ``"cpu"`` if the
              string contains "cpu" (case-insensitive). Otherwise
              ``"<device_type>:<id>[,<id>...]"`` built from the visible-device ids.

          Raises:
              ValueError: If a ``CUDA_VISIBLE_DEVICES`` entry is not an integer,
                  or a requested id is outside the visible devices.
          """
          from ....utils.device import parse_device

          if device is None:
              return None
          if "cpu" in device.lower():
              return "cpu"

          device_type, device_ids = parse_device(device)
          # NOTE(review): assumes parse_device yields a non-empty sequence of int
          # ids for non-CPU devices (max() below needs it) — confirm. Also note
          # CUDA_VISIBLE_DEVICES is consulted regardless of device_type.

          cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
          if cuda_visible_devices is None:
              # Imported lazily: GPUtil is only needed to enumerate GPUs when the
              # env var is unset, so the None/CPU paths do not depend on it.
              import GPUtil

              env_gpu_num = len(GPUtil.getGPUs())
              cuda_visible_devices = ",".join(str(i) for i in range(env_gpu_num))

          # Tolerate surrounding whitespace ("0, 1") and an empty value (no GPUs).
          env_device_ids = [
              env_id.strip()
              for env_id in cuda_visible_devices.split(",")
              if env_id.strip()
          ]
          for env_device_id in env_device_ids:
              if not env_device_id.isdigit():
                  raise ValueError(
                      f"CUDA_VISIBLE_DEVICES ID must be an integer. Invalid device ID: {env_device_id}"
                  )

          if max(device_ids) >= len(env_device_ids):
              raise ValueError(
                  f"Required gpu ids {device_ids} even larger than the number of visible devices {cuda_visible_devices}."
              )

          rst_global_gpu_ids = [env_device_ids[idx] for idx in device_ids]
          return device_type + ":" + ",".join(rst_global_gpu_ids)

      def _switch_inputs_to_device(self, input_dict):
          """Copy every ``paddle.Tensor`` value of ``input_dict`` onto ``self.device``.

          Non-tensor values are passed through unchanged. When no device is set,
          ``input_dict`` itself is returned.
          """
          import paddle

          if self.device is None:
              return input_dict
          rst_dict = {
              k: (
                  paddle.to_tensor(input_dict[k], place=self.device)
                  if isinstance(input_dict[k], paddle.Tensor)
                  else input_dict[k]
              )
              for k in input_dict
          }
          return rst_dict