predictor.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import os
  16. import warnings
  17. from typing import List
  18. from ....modules.doc_vlm.model_list import MODELS
  19. from ....utils.device import TemporaryDeviceChanger
  20. from ....utils.env import get_device_type
  21. from ...common.batch_sampler import DocVLMBatchSampler
  22. from ..base import BasePredictor
  23. from .result import DocVLMResult
  24. class DocVLMPredictor(BasePredictor):
  25. entities = MODELS
  26. model_group = {
  27. "PP-DocBee": {"PP-DocBee-2B", "PP-DocBee-7B"},
  28. "PP-DocBee2": {"PP-DocBee2-3B"},
  29. "PP-Chart2Table": {"PP-Chart2Table"},
  30. }
  31. def __init__(self, *args, **kwargs):
  32. """Initializes DocVLMPredictor.
  33. Args:
  34. *args: Arbitrary positional arguments passed to the superclass.
  35. **kwargs: Arbitrary keyword arguments passed to the superclass.
  36. """
  37. import paddle
  38. super().__init__(*args, **kwargs)
  39. self.device = kwargs.get("device", None)
  40. self.dtype = (
  41. "bfloat16"
  42. if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
  43. and (self.device is not None and "cpu" not in self.device)
  44. else "float32"
  45. )
  46. self.infer, self.processor = self._build(**kwargs)
  47. def _build_batch_sampler(self):
  48. """Builds and returns an DocVLMBatchSampler instance.
  49. Returns:
  50. DocVLMBatchSampler: An instance of DocVLMBatchSampler.
  51. """
  52. return DocVLMBatchSampler(self.model_name)
  53. def _get_result_class(self):
  54. """Returns the result class, DocVLMResult.
  55. Returns:
  56. type: The DocVLMResult class.
  57. """
  58. return DocVLMResult
  59. def _build(self, **kwargs):
  60. """Build the model, and correspounding processor on the configuration.
  61. Returns:
  62. model: An instance of Paddle model, could be either a dynamic model or a static model.
  63. processor: The correspounding processor for the model.
  64. """
  65. from .modeling import (
  66. PPChart2TableInference,
  67. PPDocBee2Inference,
  68. PPDocBeeInference,
  69. )
  70. # build processor
  71. processor = self.build_processor()
  72. # build model
  73. if self.model_name in self.model_group["PP-DocBee"]:
  74. if kwargs.get("use_hpip", False):
  75. warnings.warn(
  76. "The PP-DocBee series does not support `use_hpip=True` for now."
  77. )
  78. with TemporaryDeviceChanger(self.device):
  79. model = PPDocBeeInference.from_pretrained(
  80. self.model_dir, dtype=self.dtype
  81. )
  82. elif self.model_name in self.model_group["PP-Chart2Table"]:
  83. if kwargs.get("use_hpip", False):
  84. warnings.warn(
  85. "The PP-Chart2Table series does not support `use_hpip=True` for now."
  86. )
  87. with TemporaryDeviceChanger(self.device):
  88. model = PPChart2TableInference.from_pretrained(
  89. self.model_dir,
  90. dtype=self.dtype,
  91. pad_token_id=processor.tokenizer.eos_token_id,
  92. )
  93. elif self.model_name in self.model_group["PP-DocBee2"]:
  94. if kwargs.get("use_hpip", False):
  95. warnings.warn(
  96. "The PP-Chart2Table series does not support `use_hpip=True` for now."
  97. )
  98. with TemporaryDeviceChanger(self.device):
  99. model = PPDocBee2Inference.from_pretrained(
  100. self.model_dir,
  101. dtype=self.dtype,
  102. )
  103. else:
  104. raise NotImplementedError(f"Model {self.model_name} is not supported.")
  105. return model, processor
  106. def process(self, data: List[dict], **kwargs):
  107. """
  108. Process a batch of data through the preprocessing, inference, and postprocessing.
  109. Args:
  110. data (List[dict]): A batch of input data, must be a dict (e.g. {"image": /path/to/image, "query": some question}).
  111. kwargs (Optional[dict]): Arbitrary keyword arguments passed to model.generate.
  112. Returns:
  113. dict: A dictionary containing the raw sample information and prediction results for every instance of the batch.
  114. """
  115. assert all(isinstance(i, dict) for i in data)
  116. src_data = copy.copy(data)
  117. # preprocess
  118. data = self.processor.preprocess(data)
  119. data = self._switch_inputs_to_device(data)
  120. # do infer
  121. with TemporaryDeviceChanger(self.device):
  122. preds = self.infer.generate(data, **kwargs)
  123. # postprocess
  124. preds = self.processor.postprocess(preds)
  125. result_dict = self._format_result_dict(preds, src_data)
  126. return result_dict
  127. def build_processor(self, **kwargs):
  128. from ..common.tokenizer import (
  129. MIXQwen2_5_Tokenizer,
  130. MIXQwen2Tokenizer,
  131. QWenTokenizer,
  132. )
  133. from .processors import (
  134. GOTImageProcessor,
  135. PPChart2TableProcessor,
  136. PPDocBee2Processor,
  137. PPDocBeeProcessor,
  138. Qwen2_5_VLImageProcessor,
  139. Qwen2VLImageProcessor,
  140. )
  141. if self.model_name in self.model_group["PP-DocBee"]:
  142. image_processor = Qwen2VLImageProcessor()
  143. tokenizer = MIXQwen2Tokenizer.from_pretrained(self.model_dir)
  144. return PPDocBeeProcessor(
  145. image_processor=image_processor, tokenizer=tokenizer
  146. )
  147. elif self.model_name in self.model_group["PP-Chart2Table"]:
  148. image_processor = GOTImageProcessor(1024)
  149. tokenizer = QWenTokenizer.from_pretrained(self.model_dir)
  150. return PPChart2TableProcessor(
  151. image_processor=image_processor, tokenizer=tokenizer, dtype=self.dtype
  152. )
  153. elif self.model_name in self.model_group["PP-DocBee2"]:
  154. image_processor = Qwen2_5_VLImageProcessor()
  155. tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(self.model_dir)
  156. return PPDocBee2Processor(
  157. image_processor=image_processor, tokenizer=tokenizer
  158. )
  159. else:
  160. raise NotImplementedError
  161. def _format_result_dict(self, model_preds, src_data):
  162. if not isinstance(model_preds, list):
  163. model_preds = [model_preds]
  164. if not isinstance(src_data, list):
  165. src_data = [src_data]
  166. if len(model_preds) != len(src_data):
  167. raise ValueError(
  168. f"Model predicts {len(model_preds)} results while src data has {len(src_data)} samples."
  169. )
  170. rst_format_dict = {k: [] for k in src_data[0].keys()}
  171. rst_format_dict["result"] = []
  172. for data_sample, model_pred in zip(src_data, model_preds):
  173. for k in data_sample.keys():
  174. rst_format_dict[k].append(data_sample[k])
  175. rst_format_dict["result"].append(model_pred)
  176. return rst_format_dict
  177. def _infer_dynamic_forward_device(self, device):
  178. """infer the forward device for dynamic graph model"""
  179. import GPUtil
  180. from ....utils.device import parse_device
  181. if device is None:
  182. return None
  183. if "cpu" in device.lower():
  184. return "cpu"
  185. device_type, device_ids = parse_device(device)
  186. cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
  187. if cuda_visible_devices is None:
  188. env_gpu_num = len(GPUtil.getGPUs())
  189. cuda_visible_devices = ",".join([str(i) for i in range(env_gpu_num)])
  190. env_device_ids = cuda_visible_devices.split(",")
  191. for env_device_id in env_device_ids:
  192. if not env_device_id.isdigit():
  193. raise ValueError(
  194. f"CUDA_VISIBLE_DEVICES ID must be an integer. Invalid device ID: {env_device_id}"
  195. )
  196. if max(device_ids) >= len(env_device_ids):
  197. raise ValueError(
  198. f"Required gpu ids {device_ids} even larger than the number of visible devices {cuda_visible_devices}."
  199. )
  200. rst_global_gpu_ids = [env_device_ids[idx] for idx in device_ids]
  201. return device_type + ":" + ",".join(rst_global_gpu_ids)
  202. def _switch_inputs_to_device(self, input_dict):
  203. """Switch the input to the specified device"""
  204. import paddle
  205. if self.device is None:
  206. return input_dict
  207. rst_dict = {
  208. k: (
  209. paddle.to_tensor(input_dict[k], place=self.device)
  210. if isinstance(input_dict[k], paddle.Tensor)
  211. else input_dict[k]
  212. )
  213. for k in input_dict
  214. }
  215. return rst_dict