# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import copy
import io
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Lock
from typing import List, Optional

import numpy as np

from ....modules.doc_vlm.model_list import MODELS
from ....utils import logging
from ....utils.deps import require_genai_client_plugin
from ....utils.device import TemporaryDeviceChanger
from ....utils.env import get_device_type
from ...common.batch_sampler import DocVLMBatchSampler
from ..base import BasePredictor
from .result import DocVLMResult


class DocVLMPredictor(BasePredictor):

    entities = MODELS

    model_group = {
        "PP-DocBee": {"PP-DocBee-2B", "PP-DocBee-7B"},
        "PP-DocBee2": {"PP-DocBee2-3B"},
        "PP-Chart2Table": {"PP-Chart2Table"},
        "PaddleOCR-VL": {"PaddleOCR-VL-0.9B"},
    }
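
    # Illustrative routing: a concrete model name selects its branch via set
    # membership, e.g. "PP-DocBee-7B" in model_group["PP-DocBee"] is True, so
    # _build and build_processor take the PP-DocBee path for that model.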

    def __init__(self, *args, **kwargs):
        """Initializes DocVLMPredictor.

        Args:
            *args: Arbitrary positional arguments passed to the superclass.
            **kwargs: Arbitrary keyword arguments passed to the superclass.
        """
        super().__init__(*args, **kwargs)
        if self._use_local_model:
            import paddle

            self.device = kwargs.get("device", None)
            self.dtype = (
                "bfloat16"
                if ("npu" in get_device_type() or paddle.amp.is_bfloat16_supported())
                and (self.device is None or "cpu" not in self.device)
                else "float32"
            )
            self.infer, self.processor = self._build(**kwargs)
            if (
                self.model_name == "PaddleOCR-VL-0.9B"
                and self.batch_sampler.batch_size > 1
            ):
                logging.warning(
                    "Currently, the PaddleOCR-VL-0.9B local model only supports a batch size of 1. The batch size will be reset to 1."
                )
                self.batch_sampler.batch_size = 1
        else:
            if self.batch_sampler.batch_size > 1:
                self._thread_pool = ThreadPoolExecutor(
                    max_workers=min(self.batch_sampler.batch_size, os.cpu_count() or 1)
                )
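
    # A minimal usage sketch (illustrative; the constructor keywords follow
    # BasePredictor and the model name and input path are assumptions):
    #
    #     predictor = DocVLMPredictor(model_name="PP-DocBee2-3B", device="gpu:0")
    #     batch = [{"image": "/path/to/page.png", "query": "Describe the table."}]
    #     print(predictor.process(batch))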

    def _build_batch_sampler(self):
        """Builds and returns a DocVLMBatchSampler instance.

        Returns:
            DocVLMBatchSampler: An instance of DocVLMBatchSampler.
        """
        return DocVLMBatchSampler(self.model_name)

    def _get_result_class(self):
        """Returns the result class, DocVLMResult.

        Returns:
            type: The DocVLMResult class.
        """
        return DocVLMResult

    def _build(self, **kwargs):
        """Builds the model and the corresponding processor based on the configuration.

        Returns:
            model: An instance of a Paddle model, either a dynamic model or a static model.
            processor: The corresponding processor for the model.
        """
        from .modeling import (
            PaddleOCRVLForConditionalGeneration,
            PPChart2TableInference,
            PPDocBee2Inference,
            PPDocBeeInference,
        )

        # build processor
        processor = self.build_processor()
        # build model
        if self.model_name in self.model_group["PP-DocBee"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PP-DocBee series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPDocBeeInference.from_pretrained(
                    self.model_dir, dtype=self.dtype
                )
        elif self.model_name in self.model_group["PP-Chart2Table"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PP-Chart2Table series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPChart2TableInference.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                    pad_token_id=processor.tokenizer.eos_token_id,
                )
        elif self.model_name in self.model_group["PP-DocBee2"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PP-DocBee2 series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PPDocBee2Inference.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                )
        elif self.model_name in self.model_group["PaddleOCR-VL"]:
            if kwargs.get("use_hpip", False):
                warnings.warn(
                    "The PaddleOCR-VL series does not support `use_hpip=True` for now."
                )
            with TemporaryDeviceChanger(self.device):
                model = PaddleOCRVLForConditionalGeneration.from_pretrained(
                    self.model_dir,
                    dtype=self.dtype,
                    convert_from_hf=True,
                )
        else:
            raise NotImplementedError(f"Model {self.model_name} is not supported.")
        return model, processor

    def process(
        self,
        data: List[dict],
        max_new_tokens: Optional[int] = None,
        skip_special_tokens: Optional[bool] = None,
        repetition_penalty: Optional[float] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """
        Processes a batch of data through preprocessing, inference, and postprocessing.

        Args:
            data (List[dict]): A batch of input data; every element must be a dict
                (e.g. {"image": "/path/to/image", "query": "some question"}).

        Returns:
            dict: A dictionary containing the raw sample information and the
                prediction result for every instance of the batch.
        """
        # TODO: Sampling settings
        # FIXME: When `skip_special_tokens` is `True`, the results from different backends may differ.

        assert all(isinstance(i, dict) for i in data)

        if self._use_local_model:
            src_data = copy.copy(data)
            # preprocess
            data = self.processor.preprocess(data)
            data = self._switch_inputs_to_device(data)

            # do infer
            generate_kwargs = {}
            if max_new_tokens is not None:
                generate_kwargs["max_new_tokens"] = max_new_tokens
            elif self.model_name in self.model_group["PaddleOCR-VL"]:
                generate_kwargs["max_new_tokens"] = 8192
            if repetition_penalty is not None:
                warnings.warn(
                    "`repetition_penalty` is currently not supported by the local model and will be ignored."
                )
            if temperature is not None:
                warnings.warn(
                    "`temperature` is currently not supported by the local model and will be ignored."
                )
            if top_p is not None:
                warnings.warn(
                    "`top_p` is currently not supported by the local model and will be ignored."
                )
            if min_pixels is not None:
                warnings.warn(
                    "`min_pixels` is currently not supported by the local model and will be ignored."
                )
            if max_pixels is not None:
                warnings.warn(
                    "`max_pixels` is currently not supported by the local model and will be ignored."
                )
            if use_cache is not None:
                generate_kwargs["use_cache"] = use_cache
            with TemporaryDeviceChanger(self.device):
                preds = self.infer.generate(
                    data,
                    **generate_kwargs,
                )

            # postprocess
            postprocess_kwargs = {}
            if skip_special_tokens is not None:
                postprocess_kwargs["skip_special_tokens"] = skip_special_tokens
            preds = self.processor.postprocess(preds, **postprocess_kwargs)
        else:
            require_genai_client_plugin()

            src_data = data
            preds = self._genai_client_process(
                data,
                max_new_tokens=max_new_tokens,
                skip_special_tokens=skip_special_tokens,
                repetition_penalty=repetition_penalty,
                temperature=temperature,
                top_p=top_p,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )

        result_dict = self._format_result_dict(preds, src_data)
        return result_dict
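
    # A hedged sketch of a `process` call (the path and query are assumptions;
    # the returned dict mirrors the input keys plus a "result" list):
    #
    #     batch = [{"image": "/path/to/doc.png", "query": "What is the title?"}]
    #     out = predictor.process(batch, max_new_tokens=512)
    #     # out == {"image": [...], "query": [...], "result": [...]}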

    def build_processor(self, **kwargs):
        from ..common.tokenizer import (
            LlamaTokenizer,
            MIXQwen2_5_Tokenizer,
            MIXQwen2Tokenizer,
            QWenTokenizer,
        )
        from ..common.tokenizer.tokenizer_utils import ChatTemplate
        from .processors import (
            GOTImageProcessor,
            PaddleOCRVLProcessor,
            PPChart2TableProcessor,
            PPDocBee2Processor,
            PPDocBeeProcessor,
            Qwen2_5_VLImageProcessor,
            Qwen2VLImageProcessor,
            SiglipImageProcessor,
        )

        if self.model_name in self.model_group["PP-DocBee"]:
            image_processor = Qwen2VLImageProcessor()
            tokenizer = MIXQwen2Tokenizer.from_pretrained(self.model_dir)
            return PPDocBeeProcessor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        elif self.model_name in self.model_group["PP-Chart2Table"]:
            image_processor = GOTImageProcessor(1024)
            tokenizer = QWenTokenizer.from_pretrained(self.model_dir)
            return PPChart2TableProcessor(
                image_processor=image_processor, tokenizer=tokenizer, dtype=self.dtype
            )
        elif self.model_name in self.model_group["PP-DocBee2"]:
            image_processor = Qwen2_5_VLImageProcessor()
            tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(self.model_dir)
            return PPDocBee2Processor(
                image_processor=image_processor, tokenizer=tokenizer
            )
        elif self.model_name in self.model_group["PaddleOCR-VL"]:
            image_processor = SiglipImageProcessor.from_pretrained(self.model_dir)
            vocab_file = str(Path(self.model_dir, "tokenizer.model"))
            tokenizer = LlamaTokenizer.from_pretrained(
                self.model_dir, vocab_file=vocab_file
            )
            # HACK: compile the chat template shipped in the model directory.
            chat_template_file = Path(self.model_dir, "chat_template.jinja")
            tokenizer.chat_template = ChatTemplate._compile_jinja_template(
                chat_template_file.read_text(encoding="utf-8")
            )
            return PaddleOCRVLProcessor(
                image_processor=image_processor,
                tokenizer=tokenizer,
            )
        else:
            raise NotImplementedError

    def close(self):
        super().close()
        if hasattr(self, "_thread_pool"):
            self._thread_pool.shutdown()

    def _format_result_dict(self, model_preds, src_data):
        if not isinstance(model_preds, list):
            model_preds = [model_preds]
        if not isinstance(src_data, list):
            src_data = [src_data]
        if len(model_preds) != len(src_data):
            raise ValueError(
                f"Model predicts {len(model_preds)} results while src data has {len(src_data)} samples."
            )

        rst_format_dict = {k: [] for k in src_data[0].keys()}
        rst_format_dict["result"] = []

        for data_sample, model_pred in zip(src_data, model_preds):
            for k in data_sample.keys():
                rst_format_dict[k].append(data_sample[k])
            rst_format_dict["result"].append(model_pred)

        return rst_format_dict
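
    # Worked example of the transposition above (values are illustrative):
    #
    #     src_data    = [{"image": "a.png", "query": "q1"},
    #                    {"image": "b.png", "query": "q2"}]
    #     model_preds = ["ans1", "ans2"]
    #     # -> {"image": ["a.png", "b.png"],
    #     #     "query": ["q1", "q2"],
    #     #     "result": ["ans1", "ans2"]}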

    def _infer_dynamic_forward_device(self, device):
        """Infers the forward device for a dynamic graph model."""
        import GPUtil

        from ....utils.device import parse_device

        if device is None:
            return None
        if "cpu" in device.lower():
            return "cpu"
        device_type, device_ids = parse_device(device)

        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
        if cuda_visible_devices is None:
            env_gpu_num = len(GPUtil.getGPUs())
            cuda_visible_devices = ",".join([str(i) for i in range(env_gpu_num)])
        env_device_ids = cuda_visible_devices.split(",")

        for env_device_id in env_device_ids:
            if not env_device_id.isdigit():
                raise ValueError(
                    f"CUDA_VISIBLE_DEVICES ID must be an integer. Invalid device ID: {env_device_id}"
                )

        if max(device_ids) >= len(env_device_ids):
            raise ValueError(
                f"Requested GPU ids {device_ids} exceed the number of visible devices {cuda_visible_devices}."
            )

        rst_global_gpu_ids = [env_device_ids[idx] for idx in device_ids]
        return device_type + ":" + ",".join(rst_global_gpu_ids)
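
    # Worked example (values are illustrative): with CUDA_VISIBLE_DEVICES="3,5",
    # parse_device("gpu:1") yields device_ids == [1]; index 1 of the visible
    # list ["3", "5"] is the global id "5", so the method returns "gpu:5".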

    def _switch_inputs_to_device(self, input_dict):
        """Switches the inputs to the specified device."""
        import paddle

        if self.device is None:
            return input_dict
        rst_dict = {
            k: (
                paddle.to_tensor(input_dict[k], place=self.device)
                if isinstance(input_dict[k], paddle.Tensor)
                else input_dict[k]
            )
            for k in input_dict
        }
        return rst_dict
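
    # Illustrative behavior: only paddle.Tensor values are re-placed; other
    # values (strings, lists, numpy arrays) pass through unchanged, e.g. with
    # device "gpu:0", {"input_ids": <CPU tensor>, "query": "q"} becomes
    # {"input_ids": <gpu:0 tensor>, "query": "q"}.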

    def crop_margin(self, img):  # expects an OpenCV image (numpy array)
        import cv2

        # Convert color images to grayscale
        if len(img.shape) == 3:
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        else:
            gray = img.copy()

        # Bring values into the 0-255 range (ensure uint8 dtype)
        if gray.dtype != np.uint8:
            gray = gray.astype(np.uint8)

        max_val = gray.max()
        min_val = gray.min()
        if max_val == min_val:
            return img

        # Normalize and binarize (consistent with the PIL-based logic)
        data = (gray - min_val) / (max_val - min_val) * 255
        data = data.astype(np.uint8)

        # Build a binary image (dark regions become white, bright regions black)
        _, binary = cv2.threshold(data, 200, 255, cv2.THRESH_BINARY_INV)

        # Find the coordinates of non-zero pixels
        coords = cv2.findNonZero(binary)
        if coords is None:  # nothing found; return the original image
            return img

        # Get the bounding box and crop the image to it
        x, y, w, h = cv2.boundingRect(coords)
        cropped = img[y : y + h, x : x + w]
        return cropped
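
    # A minimal sketch exercising crop_margin on synthetic data (assumes
    # OpenCV and numpy are available; the 10x10 dark block is arbitrary):
    #
    #     canvas = np.full((100, 100, 3), 255, dtype=np.uint8)  # white page
    #     canvas[40:50, 40:50] = 0                              # dark content
    #     cropped = predictor.crop_margin(canvas)
    #     # cropped.shape == (10, 10, 3)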

    def _genai_client_process(
        self,
        data,
        max_new_tokens,
        skip_special_tokens,
        repetition_penalty,
        temperature,
        top_p,
        min_pixels,
        max_pixels,
    ):
        lock = Lock()

        def _process(item):
            image = item["image"]
            prompt = item["query"]
            if prompt == "Formula Recognition:":
                image = self.crop_margin(image)
            if isinstance(image, str):
                if image.startswith("http://") or image.startswith("https://"):
                    image_url = image
                else:
                    from PIL import Image

                    with Image.open(image) as img:
                        img = img.convert("RGB")
                        with io.BytesIO() as buf:
                            img.save(buf, format="JPEG")
                            image_url = "data:image/jpeg;base64," + base64.b64encode(
                                buf.getvalue()
                            ).decode("ascii")
            elif isinstance(image, np.ndarray):
                import cv2
                from PIL import Image

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                img = Image.fromarray(image)
                with io.BytesIO() as buf:
                    img.save(buf, format="JPEG")
                    image_url = "data:image/jpeg;base64," + base64.b64encode(
                        buf.getvalue()
                    ).decode("ascii")
            else:
                raise TypeError(f"Unsupported image type: {type(image)}")

            if self._genai_client.backend == "fastdeploy-server":
                kwargs = {
                    "temperature": 1 if temperature is None else temperature,
                    "top_p": 0 if top_p is None else top_p,
                }
            else:
                kwargs = {
                    "temperature": 0 if temperature is None else temperature,
                }
                if top_p is not None:
                    kwargs["top_p"] = top_p
            if max_new_tokens is not None:
                kwargs["max_completion_tokens"] = max_new_tokens
            elif self.model_name in self.model_group["PaddleOCR-VL"]:
                kwargs["max_completion_tokens"] = 8192
            kwargs["extra_body"] = {}
            if skip_special_tokens is not None:
                if self._genai_client.backend in (
                    "fastdeploy-server",
                    "vllm-server",
                    "sglang-server",
                ):
                    kwargs["extra_body"]["skip_special_tokens"] = skip_special_tokens
                else:
                    raise ValueError(
                        f"{repr(self._genai_client.backend)} does not support `skip_special_tokens`."
                    )
            if repetition_penalty is not None:
                kwargs["extra_body"]["repetition_penalty"] = repetition_penalty
            if min_pixels is not None:
                if self._genai_client.backend == "vllm-server":
                    kwargs["extra_body"]["mm_processor_kwargs"] = kwargs[
                        "extra_body"
                    ].get("mm_processor_kwargs", {})
                    kwargs["extra_body"]["mm_processor_kwargs"][
                        "min_pixels"
                    ] = min_pixels
                else:
                    warnings.warn(
                        f"{repr(self._genai_client.backend)} does not support `min_pixels`."
                    )
            if max_pixels is not None:
                if self._genai_client.backend == "vllm-server":
                    kwargs["extra_body"]["mm_processor_kwargs"] = kwargs[
                        "extra_body"
                    ].get("mm_processor_kwargs", {})
                    kwargs["extra_body"]["mm_processor_kwargs"][
                        "max_pixels"
                    ] = max_pixels
                else:
                    warnings.warn(
                        f"{repr(self._genai_client.backend)} does not support `max_pixels`."
                    )
            with lock:
                future = self._genai_client.create_chat_completion(
                    [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image_url", "image_url": {"url": image_url}},
                                {"type": "text", "text": item["query"]},
                            ],
                        }
                    ],
                    return_future=True,
                    **kwargs,
                )
            return future

        if len(data) > 1:
            futures = list(self._thread_pool.map(_process, data))
        else:
            futures = [_process(data[0])]
        results = []
        for future in futures:
            result = future.result()
            results.append(result.choices[0].message.content)
        return results
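
    # A hedged sketch of the serving path (inputs and argument values are
    # assumptions; the server connection comes from the GenAI client plugin
    # required in `process`):
    #
    #     out = predictor._genai_client_process(
    #         [{"image": "page.png", "query": "OCR:"}],
    #         max_new_tokens=1024, skip_special_tokens=True,
    #         repetition_penalty=None, temperature=0.2, top_p=0.9,
    #         min_pixels=None, max_pixels=None,
    #     )
    #     # -> ["...model output text for page.png..."]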