predictor.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import os
from typing import List

from ....modules.doc_vlm.model_list import MODELS
from ....utils.device import TemporaryDeviceChanger
from ....utils.env import get_device_type
from ...common.batch_sampler import DocVLMBatchSampler
from ..base import BasePredictor
from .result import DocVLMResult


class DocVLMPredictor(BasePredictor):

    entities = MODELS

    model_group = {
        "PP-DocBee": {"PP-DocBee-2B", "PP-DocBee-7B"},
        "PP-DocBee2": {"PP-DocBee2-3B"},
        "PP-Chart2Table": {"PP-Chart2Table"},
    }

    def __init__(self, *args, **kwargs):
        """Initializes DocVLMPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        import paddle

        super().__init__(*args, **kwargs)
        self.device = kwargs.get("device", None)
        # Prefer bfloat16 when the device supports it (an NPU, or a GPU with
        # bfloat16 support) and inference is not running on CPU.
        self.dtype = (
            "bfloat16"
            if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
            and (self.device is not None and "cpu" not in self.device)
            else "float32"
        )
        self.infer, self.processor = self._build(**kwargs)

    def _build_batch_sampler(self):
        """Builds and returns a DocVLMBatchSampler instance.

        Returns:
            DocVLMBatchSampler: An instance of DocVLMBatchSampler.
        """
        return DocVLMBatchSampler(self.model_name)

    def _get_result_class(self):
        """Returns the result class, DocVLMResult.

        Returns:
            type: The DocVLMResult class.
        """
        return DocVLMResult

    def _build(self, **kwargs):
        """Builds the model and the corresponding processor based on the configuration.

        Returns:
            model: An instance of a Paddle model, either dynamic or static.
            processor: The corresponding processor for the model.
        """
        from .modeling import (
            PPChart2TableInference,
            PPDocBee2Inference,
            PPDocBeeInference,
        )

        # build processor
        processor = self.build_processor()
        # build model
        if self.model_name in self.model_group["PP-DocBee"]:
            if kwargs.get("use_hpip", False):
                raise ValueError(
                    "PP-DocBee series do not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPDocBeeInference.from_pretrained(
                    self.model_dir, dtype=self.dtype
                )
        elif self.model_name in self.model_group["PP-Chart2Table"]:
            if kwargs.get("use_hpip", False):
                raise ValueError(
                    "PP-Chart2Table series do not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPChart2TableInference.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                    pad_token_id=processor.tokenizer.eos_token_id,
                )
        elif self.model_name in self.model_group["PP-DocBee2"]:
            if kwargs.get("use_hpip", False):
                raise ValueError(
                    "PP-DocBee2 series do not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPDocBee2Inference.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                )
        else:
            raise NotImplementedError(f"Model {self.model_name} is not supported.")
        return model, processor

    def process(self, data: List[dict], **kwargs):
        """
        Processes a batch of data through preprocessing, inference, and postprocessing.

        Args:
            data (List[dict]): A batch of input data; every element must be a dict
                (e.g. {"image": "/path/to/image", "query": "some question"}).
            kwargs (Optional[dict]): Arbitrary keyword arguments passed to model.generate.

        Returns:
            dict: A dictionary containing the raw sample information and the prediction
                result for every instance in the batch.
        """
        assert all(isinstance(i, dict) for i in data)
        src_data = copy.copy(data)
        # preprocess
        data = self.processor.preprocess(data)
        data = self._switch_inputs_to_device(data)
        # do infer
        with TemporaryDeviceChanger(self.device):
            preds = self.infer.generate(data, **kwargs)
        # postprocess
        preds = self.processor.postprocess(preds)
        result_dict = self._format_result_dict(preds, src_data)
        return result_dict
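
    # Note: extra keyword arguments are forwarded verbatim to `self.infer.generate`,
    # e.g. `predictor.process(batch, max_new_tokens=512)`; the kwarg name here is
    # an assumption about what the underlying `generate` accepts.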

    def build_processor(self, **kwargs):
        """Builds the processor (image processor plus tokenizer) matching the model."""
        from ..common.tokenizer import (
            MIXQwen2_5_Tokenizer,
            MIXQwen2Tokenizer,
            QWenTokenizer,
        )
        from .processors import (
            GOTImageProcessor,
            PPChart2TableProcessor,
            PPDocBee2Processor,
            PPDocBeeProcessor,
            Qwen2_5_VLImageProcessor,
            Qwen2VLImageProcessor,
        )

        if self.model_name in self.model_group["PP-DocBee"]:
            image_processor = Qwen2VLImageProcessor()
            tokenizer = MIXQwen2Tokenizer.from_pretrained(self.model_dir)
            return PPDocBeeProcessor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        elif self.model_name in self.model_group["PP-Chart2Table"]:
            image_processor = GOTImageProcessor(1024)
            tokenizer = QWenTokenizer.from_pretrained(self.model_dir)
            return PPChart2TableProcessor(
                image_processor=image_processor, tokenizer=tokenizer, dtype=self.dtype
            )
        elif self.model_name in self.model_group["PP-DocBee2"]:
            image_processor = Qwen2_5_VLImageProcessor()
            tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(self.model_dir)
            return PPDocBee2Processor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        else:
            raise NotImplementedError(f"Model {self.model_name} is not supported.")

    def _format_result_dict(self, model_preds, src_data):
        """Zips source samples and model predictions into one column-wise dict."""
        if not isinstance(model_preds, list):
            model_preds = [model_preds]
        if not isinstance(src_data, list):
            src_data = [src_data]
        if len(model_preds) != len(src_data):
            raise ValueError(
                f"The model predicted {len(model_preds)} results while the source data has {len(src_data)} samples."
            )

        rst_format_dict = {k: [] for k in src_data[0].keys()}
        rst_format_dict["result"] = []
        for data_sample, model_pred in zip(src_data, model_preds):
            for k in data_sample.keys():
                rst_format_dict[k].append(data_sample[k])
            rst_format_dict["result"].append(model_pred)
        return rst_format_dict
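
    # Example (illustrative): with src_data = [{"image": "a.png", "query": "q1"},
    # {"image": "b.png", "query": "q2"}] and model_preds = ["ans1", "ans2"],
    # `_format_result_dict` returns
    # {"image": ["a.png", "b.png"], "query": ["q1", "q2"], "result": ["ans1", "ans2"]}.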

    def _infer_dynamic_forward_device(self, device):
        """Infers the forward device for a dynamic graph model."""
        import GPUtil

        from ....utils.device import parse_device

        if device is None:
            return None
        if "cpu" in device.lower():
            return "cpu"
        device_type, device_ids = parse_device(device)
        # Map the requested logical device ids onto the globally visible GPUs.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices is None:
            env_gpu_num = len(GPUtil.getGPUs())
            cuda_visible_devices = ",".join([str(i) for i in range(env_gpu_num)])
        env_device_ids = cuda_visible_devices.split(",")
        for env_device_id in env_device_ids:
            if not env_device_id.isdigit():
                raise ValueError(
                    f"CUDA_VISIBLE_DEVICES ID must be an integer. Invalid device ID: {env_device_id}"
                )
        if max(device_ids) >= len(env_device_ids):
            raise ValueError(
                f"The requested GPU ids {device_ids} exceed the number of visible devices ({cuda_visible_devices})."
            )
        rst_global_gpu_ids = [env_device_ids[idx] for idx in device_ids]
        return device_type + ":" + ",".join(rst_global_gpu_ids)
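
    # Example (illustrative): with CUDA_VISIBLE_DEVICES="2,3" and device="gpu:1",
    # parse_device yields ("gpu", [1]), so the method maps the logical id 1 to
    # the global id "3" and returns "gpu:3".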

    def _switch_inputs_to_device(self, input_dict):
        """Switches the inputs to the specified device."""
        import paddle

        if self.device is None:
            return input_dict
        # Copy only tensor values onto the target device; leave everything else as is.
        rst_dict = {
            k: (
                paddle.to_tensor(input_dict[k], place=self.device)
                if isinstance(input_dict[k], paddle.Tensor)
                else input_dict[k]
            )
            for k in input_dict
        }
        return rst_dict
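

# A minimal usage sketch (hypothetical model name, paths, and constructor
# arguments; the exact signature is inherited from BasePredictor and is an
# assumption here):
#
#     predictor = DocVLMPredictor(
#         model_name="PP-DocBee-2B",      # any entry of MODELS
#         model_dir="/path/to/weights",
#         device="gpu:0",
#     )
#     batch = [{"image": "/path/to/page.png", "query": "What is the table about?"}]
#     results = predictor.process(batch)
#     print(results["result"][0])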