predictor.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import copy
import io
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Lock
from typing import List, Optional

import numpy as np

from ....modules.doc_vlm.model_list import MODELS
from ....utils import logging
from ....utils.deps import require_genai_client_plugin
from ....utils.device import TemporaryDeviceChanger
from ...common.batch_sampler import DocVLMBatchSampler
from ...utils.misc import is_bfloat16_available
from ..base import BasePredictor
from .result import DocVLMResult


class DocVLMPredictor(BasePredictor):

    entities = MODELS

    model_group = {
        "PP-DocBee": {"PP-DocBee-2B", "PP-DocBee-7B"},
        "PP-DocBee2": {"PP-DocBee2-3B"},
        "PP-Chart2Table": {"PP-Chart2Table"},
        "PaddleOCR-VL": {"PaddleOCR-VL-0.9B"},
    }

    def __init__(self, *args, **kwargs):
        """Initializes DocVLMPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        if self._use_local_model:
            self.device = kwargs.get("device", None)
            self.dtype = (
                "bfloat16" if is_bfloat16_available(self.device) else "float32"
            )
            self.infer, self.processor = self._build(**kwargs)
            if (
                self.model_name == "PaddleOCR-VL-0.9B"
                and self.batch_sampler.batch_size > 1
            ):
                logging.warning(
                    "Currently, the PaddleOCR-VL-0.9B local model only supports batch size of 1. The batch size will be updated to 1."
                )
                self.batch_sampler.batch_size = 1
        else:
            if self.batch_sampler.batch_size > 1:
                self._thread_pool = ThreadPoolExecutor(
                    max_workers=min(self.batch_sampler.batch_size, os.cpu_count() or 1)
                )

    def _build_batch_sampler(self):
        """Builds and returns a DocVLMBatchSampler instance.

        Returns:
            DocVLMBatchSampler: An instance of DocVLMBatchSampler.
        """
        return DocVLMBatchSampler(self.model_name)

    def _get_result_class(self):
        """Returns the result class, DocVLMResult.

        Returns:
            type: The DocVLMResult class.
        """
        return DocVLMResult

    def _build(self, **kwargs):
        """Build the model and the corresponding processor based on the configuration.

        Returns:
            model: An instance of a Paddle model; could be either a dynamic model or a static model.
            processor: The corresponding processor for the model.
        """
        from .modeling import (
            PaddleOCRVLForConditionalGeneration,
            PPChart2TableInference,
            PPDocBee2Inference,
            PPDocBeeInference,
        )

        # build processor
        processor = self.build_processor()

        # build model
        if self.model_name in self.model_group["PP-DocBee"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PP-DocBee series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPDocBeeInference.from_pretrained(
                    self.model_dir, dtype=self.dtype
                )
        elif self.model_name in self.model_group["PP-Chart2Table"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PP-Chart2Table series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPChart2TableInference.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                    pad_token_id=processor.tokenizer.eos_token_id,
                )
        elif self.model_name in self.model_group["PP-DocBee2"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PP-DocBee2 series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPDocBee2Inference.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                )
        elif self.model_name in self.model_group["PaddleOCR-VL"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PaddleOCR-VL series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PaddleOCRVLForConditionalGeneration.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                    convert_from_hf=True,
                )
        else:
            raise NotImplementedError(f"Model {self.model_name} is not supported.")
        return model, processor

    def process(
        self,
        data: List[dict],
        max_new_tokens: Optional[int] = None,
        skip_special_tokens: Optional[bool] = None,
        repetition_penalty: Optional[float] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """
        Process a batch of data through preprocessing, inference, and postprocessing.

        Args:
            data (List[dict]): A batch of input data; each element must be a dict
                (e.g. {"image": "/path/to/image", "query": "some question"}).

        Returns:
            dict: A dictionary containing the raw sample information and prediction results for every instance of the batch.
        """
        # TODO: Sampling settings
        # FIXME: When `skip_special_tokens` is `True`, the results from different backends may differ.

        assert all(isinstance(i, dict) for i in data)

        if self._use_local_model:
            src_data = copy.copy(data)
            # preprocess
            data = self.processor.preprocess(data)
            data = self._switch_inputs_to_device(data)

            # do infer
            generate_kwargs = {}
            if max_new_tokens is not None:
                generate_kwargs["max_new_tokens"] = max_new_tokens
            elif self.model_name in self.model_group["PaddleOCR-VL"]:
                generate_kwargs["max_new_tokens"] = 8192
            if repetition_penalty is not None:
                warnings.warn(
                    "`repetition_penalty` is currently not supported by the local model and will be ignored."
                )
            if temperature is not None:
                warnings.warn(
                    "`temperature` is currently not supported by the local model and will be ignored."
                )
            if top_p is not None:
                warnings.warn(
                    "`top_p` is currently not supported by the local model and will be ignored."
                )
            if min_pixels is not None:
                warnings.warn(
                    "`min_pixels` is currently not supported by the local model and will be ignored."
                )
            if max_pixels is not None:
                warnings.warn(
                    "`max_pixels` is currently not supported by the local model and will be ignored."
                )
            if use_cache is not None:
                generate_kwargs["use_cache"] = use_cache
            with TemporaryDeviceChanger(self.device):
                preds = self.infer.generate(
                    data,
                    **generate_kwargs,
                )
            # postprocess
            postprocess_kwargs = {}
            if skip_special_tokens is not None:
                postprocess_kwargs["skip_special_tokens"] = skip_special_tokens
            preds = self.processor.postprocess(preds, **postprocess_kwargs)
        else:
            require_genai_client_plugin()

            src_data = data
            preds = self._genai_client_process(
                data,
                max_new_tokens=max_new_tokens,
                skip_special_tokens=skip_special_tokens,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )

        result_dict = self._format_result_dict(preds, src_data)
        return result_dict

    def build_processor(self, **kwargs):
        from ..common.tokenizer import (
            LlamaTokenizer,
            MIXQwen2_5_Tokenizer,
            MIXQwen2Tokenizer,
            QWenTokenizer,
        )
        from ..common.tokenizer.tokenizer_utils import ChatTemplate
        from .processors import (
            GOTImageProcessor,
            PaddleOCRVLProcessor,
            PPChart2TableProcessor,
            PPDocBee2Processor,
            PPDocBeeProcessor,
            Qwen2_5_VLImageProcessor,
            Qwen2VLImageProcessor,
            SiglipImageProcessor,
        )

        if self.model_name in self.model_group["PP-DocBee"]:
            image_processor = Qwen2VLImageProcessor()
            tokenizer = MIXQwen2Tokenizer.from_pretrained(self.model_dir)
            return PPDocBeeProcessor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        elif self.model_name in self.model_group["PP-Chart2Table"]:
            image_processor = GOTImageProcessor(1024)
            tokenizer = QWenTokenizer.from_pretrained(self.model_dir)
            return PPChart2TableProcessor(
                image_processor=image_processor, tokenizer=tokenizer, dtype=self.dtype
            )
        elif self.model_name in self.model_group["PP-DocBee2"]:
            image_processor = Qwen2_5_VLImageProcessor()
            tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(self.model_dir)
            return PPDocBee2Processor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        elif self.model_name in self.model_group["PaddleOCR-VL"]:
            image_processor = SiglipImageProcessor.from_pretrained(self.model_dir)
            vocab_file = str(Path(self.model_dir, "tokenizer.model"))
            tokenizer = LlamaTokenizer.from_pretrained(
                self.model_dir, vocab_file=vocab_file
            )
            # HACK
            chat_template_file = Path(self.model_dir, "chat_template.jinja")
            tokenizer.chat_template = ChatTemplate._compile_jinja_template(
                chat_template_file.read_text(encoding="utf-8")
            )
            return PaddleOCRVLProcessor(
                image_processor=image_processor,
                tokenizer=tokenizer,
            )
        else:
            raise NotImplementedError

    def close(self):
        super().close()
        if hasattr(self, "_thread_pool"):
            self._thread_pool.shutdown()

    def _format_result_dict(self, model_preds, src_data):
        if not isinstance(model_preds, list):
            model_preds = [model_preds]
        if not isinstance(src_data, list):
            src_data = [src_data]
        if len(model_preds) != len(src_data):
            raise ValueError(
                f"Model predicts {len(model_preds)} results while src data has {len(src_data)} samples."
            )

        rst_format_dict = {k: [] for k in src_data[0].keys()}
        rst_format_dict["result"] = []

        for data_sample, model_pred in zip(src_data, model_preds):
            for k in data_sample.keys():
                rst_format_dict[k].append(data_sample[k])
            rst_format_dict["result"].append(model_pred)

        return rst_format_dict
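
    # Illustrative sketch (not part of the original file): with
    # src_data = [{"image": "a.png", "query": "q1"}, {"image": "b.png", "query": "q2"}]
    # and model_preds = ["ans1", "ans2"], `_format_result_dict` returns a
    # column-oriented dict:
    #   {"image": ["a.png", "b.png"], "query": ["q1", "q2"], "result": ["ans1", "ans2"]}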

    def _infer_dynamic_forward_device(self, device):
        """Infer the forward device for the dynamic graph model."""
        import GPUtil

        from ....utils.device import parse_device

        if device is None:
            return None
        if "cpu" in device.lower():
            return "cpu"

        device_type, device_ids = parse_device(device)

        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices is None:
            env_gpu_num = len(GPUtil.getGPUs())
            cuda_visible_devices = ",".join([str(i) for i in range(env_gpu_num)])
        env_device_ids = cuda_visible_devices.split(",")
        for env_device_id in env_device_ids:
            if not env_device_id.isdigit():
                raise ValueError(
                    f"CUDA_VISIBLE_DEVICES ID must be an integer. Invalid device ID: {env_device_id}"
                )

        if max(device_ids) >= len(env_device_ids):
            raise ValueError(
                f"Required GPU ids {device_ids} exceed the number of visible devices {cuda_visible_devices}."
            )

        rst_global_gpu_ids = [env_device_ids[idx] for idx in device_ids]
        return device_type + ":" + ",".join(rst_global_gpu_ids)
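
    # Illustrative sketch (not part of the original file; assumes parse_device("gpu:0,2")
    # yields ("gpu", [0, 2])): with CUDA_VISIBLE_DEVICES="4,5,6", the logical ids 0 and 2
    # are remapped to the global ids "4" and "6", so the method returns "gpu:4,6".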

    def _switch_inputs_to_device(self, input_dict):
        """Switch the inputs to the specified device"""
        import paddle

        if self.device is None:
            return input_dict

        rst_dict = {
            k: (
                paddle.to_tensor(input_dict[k], place=self.device)
                if isinstance(input_dict[k], paddle.Tensor)
                else input_dict[k]
            )
            for k in input_dict
        }
        return rst_dict

    def _genai_client_process(
        self,
        data,
        max_new_tokens,
        skip_special_tokens,
        repetition_penalty,
        temperature,
        top_p,
        min_pixels,
        max_pixels,
    ):
        lock = Lock()

        def _process(item):
            image = item["image"]
            if isinstance(image, str):
                if image.startswith("http://") or image.startswith("https://"):
                    image_url = image
                else:
                    from PIL import Image

                    with Image.open(image) as img:
                        img = img.convert("RGB")
                        with io.BytesIO() as buf:
                            img.save(buf, format="JPEG")
                            image_url = "data:image/jpeg;base64," + base64.b64encode(
                                buf.getvalue()
                            ).decode("ascii")
            elif isinstance(image, np.ndarray):
                import cv2
                from PIL import Image

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(image)
                with io.BytesIO() as buf:
                    img.save(buf, format="JPEG")
                    image_url = "data:image/jpeg;base64," + base64.b64encode(
                        buf.getvalue()
                    ).decode("ascii")
            else:
                raise TypeError(f"Not supported image type: {type(image)}")

            if self._genai_client.backend == "fastdeploy-server":
                kwargs = {
                    "temperature": 1 if temperature is None else temperature,
                    "top_p": 0 if top_p is None else top_p,
                }
            else:
                kwargs = {
                    "temperature": 0 if temperature is None else temperature,
                }
                if top_p is not None:
                    kwargs["top_p"] = top_p
            if max_new_tokens is not None:
                kwargs["max_completion_tokens"] = max_new_tokens
            elif self.model_name in self.model_group["PaddleOCR-VL"]:
                kwargs["max_completion_tokens"] = 8192
            kwargs["extra_body"] = {}
            if skip_special_tokens is not None:
                if self._genai_client.backend in (
                    "fastdeploy-server",
                    "vllm-server",
                    "sglang-server",
                ):
                    kwargs["extra_body"]["skip_special_tokens"] = skip_special_tokens
                else:
                    raise ValueError("Not supported")
            if repetition_penalty is not None:
                kwargs["extra_body"]["repetition_penalty"] = repetition_penalty
            if min_pixels is not None:
                if self._genai_client.backend == "vllm-server":
                    kwargs["extra_body"]["mm_processor_kwargs"] = kwargs[
                        "extra_body"
                    ].get("mm_processor_kwargs", {})
                    kwargs["extra_body"]["mm_processor_kwargs"][
                        "min_pixels"
                    ] = min_pixels
                else:
                    warnings.warn(
                        f"{repr(self._genai_client.backend)} does not support `min_pixels`."
                    )
            if max_pixels is not None:
                if self._genai_client.backend == "vllm-server":
                    kwargs["extra_body"]["mm_processor_kwargs"] = kwargs[
                        "extra_body"
                    ].get("mm_processor_kwargs", {})
                    kwargs["extra_body"]["mm_processor_kwargs"][
                        "max_pixels"
                    ] = max_pixels
                else:
                    warnings.warn(
                        f"{repr(self._genai_client.backend)} does not support `max_pixels`."
                    )

            with lock:
                future = self._genai_client.create_chat_completion(
                    [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": image_url}},
                                {"type": "text", "text": item["query"]},
                            ],
                        }
                    ],
                    return_future=True,
                    **kwargs,
                )
            return future

        if len(data) > 1:
            futures = list(self._thread_pool.map(_process, data))
        else:
            futures = [_process(data[0])]

        results = []
        for future in futures:
            result = future.result()
            results.append(result.choices[0].message.content)
        return results
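

# A minimal usage sketch (not part of the original file). The construction step and the
# model name and paths below are placeholder assumptions; in practice predictors are
# usually obtained through PaddleX's factory or pipeline APIs rather than built by hand:
#
#     predictor = ...  # an already-built DocVLMPredictor, e.g. for "PP-DocBee2-3B"
#     batch = [{"image": "/path/to/doc_page.png", "query": "What is the invoice total?"}]
#     out = predictor.process(batch, max_new_tokens=512)
#     # `out` is column-oriented: out["image"], out["query"], and out["result"] are
#     # parallel lists, one entry per input sample.
#     print(out["result"][0])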