pipeline.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import queue
  15. import threading
  16. import time
  17. from itertools import chain
  18. from typing import Any, Dict, Optional, Tuple, Union
  19. import numpy as np
  20. from PIL import Image
  21. from ....utils import logging
  22. from ....utils.deps import pipeline_requires_extra
  23. from ...common.batch_sampler import ImageBatchSampler
  24. from ...common.reader import ReadImage
  25. from ...utils.benchmark import benchmark
  26. from ...utils.hpi import HPIConfig
  27. from ...utils.pp_option import PaddlePredictorOption
  28. from .._parallel import AutoParallelImageSimpleInferencePipeline
  29. from ..base import BasePipeline
  30. from ..components import CropByBoxes
  31. from ..layout_parsing.utils import gather_imgs
  32. from .result import PPOCRVLBlock, PPOCRVLResult
  33. from .uilts import (
  34. convert_otsl_to_html,
  35. filter_overlap_boxes,
  36. merge_blocks,
  37. tokenize_figure_of_table,
  38. truncate_repetitive_content,
  39. untokenize_figure_of_table,
  40. )
  # Layout labels whose blocks are kept as cropped images instead of being
  # sent to the VL recognition model.
  IMAGE_LABELS = [
      "image",
      "header_image",
      "footer_image",
      "seal",
  ]
  @benchmark.time_methods
  class _PPOCRVLPipeline(BasePipeline):
      """PaddleOCR-VL document parsing pipeline.

      Optionally preprocesses each page (orientation classification / unwarping),
      detects layout blocks with PP-DocLayoutV2, and sends every non-image block to
      a vision-language recognition model with a task-specific prompt
      ("OCR:", "Table Recognition:", "Chart Recognition:", "Formula Recognition:").
      """

      def __init__(
          self,
          config: Dict,
          device: Optional[str] = None,
          pp_option: Optional[PaddlePredictorOption] = None,
          use_hpip: bool = False,
          hpi_config: Optional[Union[Dict[str, Any], HPIConfig]] = None,
      ) -> None:
          """
          Initializes the class with given configurations and options.

          Args:
              config (Dict): Configuration dictionary containing various settings.
              device (str, optional): Device to run the predictions on. Defaults to None.
              pp_option (PaddlePredictorOption, optional): PaddlePredictor options. Defaults to None.
              use_hpip (bool, optional): Whether to use the high-performance
                  inference plugin (HPIP) by default. Defaults to False.
              hpi_config (Optional[Union[Dict[str, Any], HPIConfig]], optional):
                  The default high-performance inference configuration dictionary.
                  Defaults to None.

          Raises:
              AssertionError: If layout detection is enabled and the configured
                  layout model is not "PP-DocLayoutV2".
          """
          super().__init__(
              device=device, pp_option=pp_option, use_hpip=use_hpip, hpi_config=hpi_config
          )

          self.use_doc_preprocessor = config.get("use_doc_preprocessor", True)
          if self.use_doc_preprocessor:
              doc_preprocessor_config = config.get("SubPipelines", {}).get(
                  "DocPreprocessor",
                  {
                      "pipeline_config_error": "config error for doc_preprocessor_pipeline!"
                  },
              )
              self.doc_preprocessor_pipeline = self.create_pipeline(
                  doc_preprocessor_config
              )

          self.use_layout_detection = config.get("use_layout_detection", True)
          if self.use_layout_detection:
              layout_det_config = config.get("SubModules", {}).get(
                  "LayoutDetection",
                  {"model_config_error": "config error for layout_det_model!"},
              )
              model_name = layout_det_config.get("model_name", None)
              assert (
                  model_name is not None and model_name == "PP-DocLayoutV2"
              ), "model_name must be PP-DocLayoutV2"
              # Forward only the layout options that are explicitly configured
              # (a value of None means "not set").
              layout_kwargs = {}
              for option in (
                  "threshold",
                  "layout_nms",
                  "layout_unclip_ratio",
                  "layout_merge_bboxes_mode",
              ):
                  value = layout_det_config.get(option, None)
                  if value is not None:
                      layout_kwargs[option] = value
              self.layout_det_model = self.create_model(
                  layout_det_config, **layout_kwargs
              )

          self.use_chart_recognition = config.get("use_chart_recognition", True)

          vl_rec_config = config.get("SubModules", {}).get(
              "VLRecognition",
              {"model_config_error": "config error for vl_rec_model!"},
          )
          self.vl_rec_model = self.create_model(vl_rec_config)
          self.format_block_content = config.get("format_block_content", False)

          self.batch_sampler = ImageBatchSampler(batch_size=config.get("batch_size", 1))
          self.img_reader = ReadImage(format="BGR")
          self.crop_by_boxes = CropByBoxes()

          self.use_queues = config.get("use_queues", False)

      def close(self):
          """Close the VL recognition model (the only sub-model closed here)."""
          self.vl_rec_model.close()

      def get_model_settings(
          self,
          use_doc_orientation_classify: Union[bool, None],
          use_doc_unwarping: Union[bool, None],
          use_layout_detection: Union[bool, None],
          use_chart_recognition: Union[bool, None],
          format_block_content: Union[bool, None],
      ) -> dict:
          """
          Resolve per-call settings, falling back to the pipeline defaults for None.

          Args:
              use_doc_orientation_classify (Union[bool, None]): Request document
                  orientation classification.
              use_doc_unwarping (Union[bool, None]): Request document unwarping.
              use_layout_detection (Union[bool, None]): Whether to run layout
                  detection. None uses the pipeline setting.
              use_chart_recognition (Union[bool, None]): Whether chart blocks are
                  recognized by the VL model. None uses the pipeline setting.
              format_block_content (Union[bool, None]): Stored in the settings as-is.
                  None uses the pipeline setting.

          Returns:
              dict: Keys ``use_doc_preprocessor``, ``use_layout_detection``,
              ``use_chart_recognition`` and ``format_block_content``. Document
              preprocessing follows the pipeline default only when both
              preprocessing flags are None; otherwise it is enabled iff at least
              one of them is True.
          """
          if use_doc_orientation_classify is None and use_doc_unwarping is None:
              use_doc_preprocessor = self.use_doc_preprocessor
          else:
              use_doc_preprocessor = (
                  use_doc_orientation_classify is True or use_doc_unwarping is True
              )

          if use_layout_detection is None:
              use_layout_detection = self.use_layout_detection

          if use_chart_recognition is None:
              use_chart_recognition = self.use_chart_recognition

          if format_block_content is None:
              format_block_content = self.format_block_content

          return dict(
              use_doc_preprocessor=use_doc_preprocessor,
              use_layout_detection=use_layout_detection,
              use_chart_recognition=use_chart_recognition,
              format_block_content=format_block_content,
          )

      def check_model_settings_valid(self, input_params: dict) -> bool:
          """
          Check that every sub-model required by the resolved settings was initialized.

          Args:
              input_params (Dict): Settings as returned by ``get_model_settings``.

          Returns:
              bool: True if all required models are initialized, False otherwise
              (an error is logged in that case).
          """
          if input_params["use_doc_preprocessor"] and not self.use_doc_preprocessor:
              logging.error(
                  "Set use_doc_preprocessor, but the models for doc preprocessor are not initialized.",
              )
              return False

          # The layout model is only created in __init__ when layout detection
          # is enabled in the config.
          if input_params["use_layout_detection"] and not self.use_layout_detection:
              logging.error(
                  "Set use_layout_detection, but the layout detection model is not initialized.",
              )
              return False

          return True

      def get_layout_parsing_results(
          self,
          images,
          layout_det_results,
          imgs_in_doc,
          use_chart_recognition=False,
          vlm_kwargs=None,
      ):
          """
          Crop layout blocks, recognize the non-image ones with the VL model, and
          assemble per-page block lists.

          Args:
              images: Preprocessed page images, one per page (assumed BGR arrays:
                  image blocks are converted with ``cv2.COLOR_BGR2RGB``).
              layout_det_results: Layout detection results, one per page; each
                  must provide a ``"boxes"`` entry.
              imgs_in_doc: Per-page lists of embedded images, used to tokenize
                  figures inside tables.
              use_chart_recognition (bool): If False, ``"chart"`` blocks are
                  treated like image blocks instead of being sent to the VL model.
              vlm_kwargs (dict, optional): Extra keyword arguments for the VL
                  model's ``predict``; they override ``use_cache=True`` and
                  ``max_new_tokens=4096``.

          Returns:
              tuple: ``(parsing_res_lists, table_res_lists, imgs_in_doc)``.
              ``parsing_res_lists`` holds one list of ``PPOCRVLBlock`` per page,
              ``table_res_lists`` holds one (currently always empty) list per
              page, and ``imgs_in_doc`` is returned unchanged.
          """
          blocks = []
          block_imgs = []
          text_prompts = []
          vlm_block_ids = []
          figure_token_maps = []
          # Figures already embedded into a table's recognition output, kept
          # per page so an identical box path on another page is not dropped.
          drop_figures_per_img = []
          image_labels = (
              IMAGE_LABELS if use_chart_recognition else IMAGE_LABELS + ["chart"]
          )

          for i, (image, layout_det_res, imgs_in_doc_for_img) in enumerate(
              zip(images, layout_det_results, imgs_in_doc)
          ):
              layout_det_res = filter_overlap_boxes(layout_det_res)
              boxes = layout_det_res["boxes"]
              blocks_for_img = self.crop_by_boxes(image, boxes)
              blocks_for_img = merge_blocks(
                  blocks_for_img, non_merge_labels=image_labels + ["table"]
              )
              blocks.append(blocks_for_img)
              drop_figures_for_img = set()
              for j, block in enumerate(blocks_for_img):
                  block_img = block["img"]
                  block_label = block["label"]
                  if block_label not in image_labels and block_img is not None:
                      figure_token_map = {}
                      text_prompt = "OCR:"
                      drop_figures = []
                      if block_label == "table":
                          text_prompt = "Table Recognition:"
                          block_img, figure_token_map, drop_figures = (
                              tokenize_figure_of_table(
                                  block_img, block["box"], imgs_in_doc_for_img
                              )
                          )
                      elif block_label == "chart" and use_chart_recognition:
                          text_prompt = "Chart Recognition:"
                      elif "formula" in block_label and block_label != "formula_number":
                          text_prompt = "Formula Recognition:"
                      block_imgs.append(block_img)
                      text_prompts.append(text_prompt)
                      figure_token_maps.append(figure_token_map)
                      vlm_block_ids.append((i, j))
                      drop_figures_for_img.update(drop_figures)
              drop_figures_per_img.append(drop_figures_for_img)

          kwargs = {
              "use_cache": True,
              "max_new_tokens": 4096,
              **(vlm_kwargs or {}),
          }
          # Pages consisting solely of image blocks need no VL call at all.
          vl_rec_results = (
              list(
                  self.vl_rec_model.predict(
                      [
                          {
                              "image": block_img,
                              "query": text_prompt,
                          }
                          for block_img, text_prompt in zip(block_imgs, text_prompts)
                      ],
                      skip_special_tokens=True,
                      **kwargs,
                  )
              )
              if block_imgs
              else []
          )

          parsing_res_lists = []
          table_res_lists = []
          curr_vlm_block_idx = 0
          for i, blocks_for_img in enumerate(blocks):
              parsing_res_list = []
              table_res_list = []
              drop_figures_for_img = drop_figures_per_img[i]
              for j, block in enumerate(blocks_for_img):
                  block_img = block["img"]
                  block_bbox = block["box"]
                  block_label = block["label"]
                  block_content = ""
                  # VL results are in the same order as the blocks were queued
                  # above, so a single cursor walks both in lockstep.
                  if curr_vlm_block_idx < len(vlm_block_ids) and vlm_block_ids[
                      curr_vlm_block_idx
                  ] == (i, j):
                      vl_rec_result = vl_rec_results[curr_vlm_block_idx]
                      figure_token_map = figure_token_maps[curr_vlm_block_idx]
                      block_img4vl = block_imgs[curr_vlm_block_idx]
                      curr_vlm_block_idx += 1
                      vl_rec_result["image"] = block_img4vl
                      result_str = vl_rec_result.get("result", "")
                      if result_str is None:
                          result_str = ""
                      result_str = truncate_repetitive_content(result_str)
                      # Normalize LaTeX \( \) / \[ \] delimiters to $ / $$.
                      if ("\\(" in result_str and "\\)" in result_str) or (
                          "\\[" in result_str and "\\]" in result_str
                      ):
                          result_str = result_str.replace("$", "")

                          result_str = (
                              result_str.replace("\\(", " $ ")
                              .replace("\\)", " $ ")
                              .replace("\\[", " $$ ")
                              .replace("\\]", " $$ ")
                          )
                          if block_label == "formula_number":
                              result_str = result_str.replace("$", "")
                      if block_label == "table":
                          html_str = convert_otsl_to_html(result_str)
                          if html_str != "":
                              result_str = html_str
                          result_str = untokenize_figure_of_table(
                              result_str, figure_token_map
                          )

                      block_content = result_str

                  block_info = PPOCRVLBlock(
                      label=block_label,
                      bbox=block_bbox,
                      content=block_content,
                  )
                  if block_label in image_labels and block_img is not None:
                      x_min, y_min, x_max, y_max = list(map(int, block_bbox))
                      img_path = f"imgs/img_in_{block_label}_box_{x_min}_{y_min}_{x_max}_{y_max}.jpg"
                      if img_path not in drop_figures_for_img:
                          import cv2

                          block_img = cv2.cvtColor(block_img, cv2.COLOR_BGR2RGB)
                          block_info.image = {
                              "path": img_path,
                              "img": Image.fromarray(block_img),
                          }
                      else:
                          # Already rendered inside a table; skip the block.
                          continue

                  parsing_res_list.append(block_info)
              parsing_res_lists.append(parsing_res_list)
              table_res_lists.append(table_res_list)

          return parsing_res_lists, table_res_lists, imgs_in_doc

      def predict(
          self,
          input: Union[str, list[str], np.ndarray, list[np.ndarray]],
          use_doc_orientation_classify: Union[bool, None] = False,
          use_doc_unwarping: Union[bool, None] = False,
          use_layout_detection: Union[bool, None] = None,
          use_chart_recognition: Union[bool, None] = None,
          layout_threshold: Optional[Union[float, dict]] = None,
          layout_nms: Optional[bool] = None,
          layout_unclip_ratio: Optional[Union[float, Tuple[float, float], dict]] = None,
          layout_merge_bboxes_mode: Optional[str] = None,
          use_queues: Optional[bool] = None,
          prompt_label: Optional[Union[str, None]] = None,
          format_block_content: Union[bool, None] = None,
          repetition_penalty: Optional[float] = None,
          temperature: Optional[float] = None,
          top_p: Optional[float] = None,
          min_pixels: Optional[int] = None,
          max_pixels: Optional[int] = None,
          **kwargs,
      ) -> PPOCRVLResult:
          """
          Predicts the layout parsing result for the given input.

          This is a generator: it yields one ``PPOCRVLResult`` per input page, or a
          single ``{"error": ...}`` dict (and nothing else) when the requested
          settings need a sub-model that was not initialized.

          Args:
              input (Union[str, list[str], np.ndarray, list[np.ndarray]]): Input image path, list of image paths,
                  numpy array of an image, or list of numpy arrays.
              use_doc_orientation_classify (Optional[bool]): Whether to use document orientation classification.
              use_doc_unwarping (Optional[bool]): Whether to use document unwarping.
              use_layout_detection (Optional[bool]): Whether to run layout detection. When disabled,
                  every page is treated as one block labelled ``prompt_label``. None uses the
                  pipeline setting.
              use_chart_recognition (Optional[bool]): Whether chart blocks are recognized by the VL
                  model instead of being kept as images. None uses the pipeline setting.
              layout_threshold (Optional[Union[float, dict]]): The threshold value to filter out
                  low-confidence predictions. Default is None.
              layout_nms (bool, optional): Whether to use layout-aware NMS. Defaults to None.
              layout_unclip_ratio (Optional[Union[float, Tuple[float, float], dict]], optional): The ratio
                  of unclipping the bounding box. Defaults to None.
                  If it's a single number, then both width and height are used.
                  If it's a tuple of two numbers, then they are used separately for width and height respectively.
                  If it's None, then no unclipping will be performed.
              layout_merge_bboxes_mode (Optional[str], optional): The mode for merging bounding boxes. Defaults to None.
              use_queues (Optional[bool]): Run batch loading, CV processing and VL recognition in
                  separate threads connected by bounded queues. None uses the pipeline setting.
              prompt_label (Optional[str]): Block label used when layout detection is disabled; one
                  of "ocr", "formula", "table", "chart" (case-insensitive). Defaults to "ocr".
              format_block_content (Optional[bool]): Recorded in each result's ``model_settings``.
                  None uses the pipeline setting.
              repetition_penalty (Optional[float]): Forwarded to the VL model's ``predict``.
              temperature (Optional[float]): Forwarded to the VL model's ``predict``.
              top_p (Optional[float]): Forwarded to the VL model's ``predict``.
              min_pixels (Optional[int]): Forwarded to the VL model's ``predict``.
              max_pixels (Optional[int]): Forwarded to the VL model's ``predict``.
              **kwargs (Any): Accepted but currently unused by this method.

          Returns:
              PPOCRVLResult: The predicted layout parsing results, yielded one page at a time.

          Raises:
              AssertionError: If layout detection is disabled and ``prompt_label`` is invalid.
              RuntimeError: In queue mode, if a background worker fails (the original
                  exception is chained as the cause).
          """
          model_settings = self.get_model_settings(
              use_doc_orientation_classify,
              use_doc_unwarping,
              use_layout_detection,
              use_chart_recognition,
              format_block_content,
          )

          if not self.check_model_settings_valid(model_settings):
              yield {"error": "the input params for model settings are invalid!"}
              # Stop here: continuing would call sub-models that were never created.
              return

          if use_queues is None:
              use_queues = self.use_queues

          if not model_settings["use_layout_detection"]:
              prompt_label = prompt_label if prompt_label else "ocr"
              if prompt_label.lower() == "chart":
                  model_settings["use_chart_recognition"] = True
              assert prompt_label.lower() in [
                  "ocr",
                  "formula",
                  "table",
                  "chart",
              ], f"Layout detection is disabled (use_layout_detection=False). 'prompt_label' must be one of ['ocr', 'formula', 'table', 'chart'], but got '{prompt_label}'."

          def _process_cv(batch_data, new_batch_size=None):
              # Read, preprocess and layout-detect ``batch_data`` in sub-batches of
              # ``new_batch_size`` pages, yielding one tuple of per-page lists each.
              if not new_batch_size:
                  new_batch_size = len(batch_data)

              for idx in range(0, len(batch_data), new_batch_size):
                  instances = batch_data.instances[idx : idx + new_batch_size]
                  input_paths = batch_data.input_paths[idx : idx + new_batch_size]
                  page_indexes = batch_data.page_indexes[idx : idx + new_batch_size]

                  image_arrays = self.img_reader(instances)

                  if model_settings["use_doc_preprocessor"]:
                      doc_preprocessor_results = list(
                          self.doc_preprocessor_pipeline(
                              image_arrays,
                              use_doc_orientation_classify=use_doc_orientation_classify,
                              use_doc_unwarping=use_doc_unwarping,
                          )
                      )
                  else:
                      doc_preprocessor_results = [
                          {"output_img": arr} for arr in image_arrays
                      ]

                  doc_preprocessor_images = [
                      item["output_img"] for item in doc_preprocessor_results
                  ]

                  if model_settings["use_layout_detection"]:
                      layout_det_results = list(
                          self.layout_det_model(
                              doc_preprocessor_images,
                              threshold=layout_threshold,
                              layout_nms=layout_nms,
                              layout_unclip_ratio=layout_unclip_ratio,
                              layout_merge_bboxes_mode=layout_merge_bboxes_mode,
                          )
                      )
                      imgs_in_doc = [
                          gather_imgs(doc_pp_img, layout_det_res["boxes"])
                          for doc_pp_img, layout_det_res in zip(
                              doc_preprocessor_images, layout_det_results
                          )
                      ]
                  else:
                      # Without layout detection, each page becomes a single
                      # full-page box labelled with ``prompt_label``.
                      layout_det_results = []
                      for doc_preprocessor_image in doc_preprocessor_images:
                          layout_det_results.append(
                              {
                                  "input_path": None,
                                  "page_index": None,
                                  "boxes": [
                                      {
                                          "cls_id": 0,
                                          "label": prompt_label.lower(),
                                          "score": 1,
                                          "coordinate": [
                                              0,
                                              0,
                                              doc_preprocessor_image.shape[1],
                                              doc_preprocessor_image.shape[0],
                                          ],
                                      }
                                  ],
                              }
                          )
                      imgs_in_doc = [[] for _ in layout_det_results]

                  yield input_paths, page_indexes, doc_preprocessor_images, doc_preprocessor_results, layout_det_results, imgs_in_doc

          def _process_vlm(results_cv):
              # Run VL recognition on one (possibly merged) CV result tuple and
              # yield one PPOCRVLResult per page.
              (
                  input_paths,
                  page_indexes,
                  doc_preprocessor_images,
                  doc_preprocessor_results,
                  layout_det_results,
                  imgs_in_doc,
              ) = results_cv

              parsing_res_lists, table_res_lists, imgs_in_doc = (
                  self.get_layout_parsing_results(
                      doc_preprocessor_images,
                      layout_det_results,
                      imgs_in_doc,
                      model_settings["use_chart_recognition"],
                      {
                          "repetition_penalty": repetition_penalty,
                          "temperature": temperature,
                          "top_p": top_p,
                          "min_pixels": min_pixels,
                          "max_pixels": max_pixels,
                      },
                  )
              )

              for (
                  input_path,
                  page_index,
                  doc_preprocessor_image,
                  doc_preprocessor_res,
                  layout_det_res,
                  table_res_list,
                  parsing_res_list,
                  imgs_in_doc_for_img,
              ) in zip(
                  input_paths,
                  page_indexes,
                  doc_preprocessor_images,
                  doc_preprocessor_results,
                  layout_det_results,
                  table_res_lists,
                  parsing_res_lists,
                  imgs_in_doc,
              ):
                  single_img_res = {
                      "input_path": input_path,
                      "page_index": page_index,
                      "doc_preprocessor_res": doc_preprocessor_res,
                      "layout_det_res": layout_det_res,
                      "table_res_list": table_res_list,
                      "parsing_res_list": parsing_res_list,
                      "imgs_in_doc": imgs_in_doc_for_img,
                      "model_settings": model_settings,
                  }
                  yield PPOCRVLResult(single_img_res)

          if use_queues:
              max_num_batches_in_process = 64
              queue_input = queue.Queue(maxsize=max_num_batches_in_process)
              queue_cv = queue.Queue(maxsize=max_num_batches_in_process)
              queue_vlm = queue.Queue(
                  maxsize=self.batch_sampler.batch_size * max_num_batches_in_process
              )
              event_shutdown = threading.Event()
              event_data_loading_done = threading.Event()
              event_cv_processing_done = threading.Event()
              event_vlm_processing_done = threading.Event()

              def _put(q, item):
                  # Put ``item`` on bounded queue ``q``, giving up once shutdown is
                  # requested. A plain blocking ``put`` on a full queue would never
                  # observe ``event_shutdown`` and would hang this non-daemon
                  # thread after the consumer stops. Returns True if enqueued.
                  while not event_shutdown.is_set():
                      try:
                          q.put(item, timeout=0.5)
                      except queue.Full:
                          continue
                      return True
                  return False

              def _worker_input(input_):
                  all_batch_data = self.batch_sampler(input_)
                  while not event_shutdown.is_set():
                      try:
                          batch_data = next(all_batch_data)
                      except StopIteration:
                          break
                      except Exception as e:
                          _put(queue_input, (False, "input", e))
                          break
                      else:
                          _put(queue_input, (True, batch_data))
                  event_data_loading_done.set()

              def _worker_cv():
                  while not event_shutdown.is_set():
                      try:
                          item = queue_input.get(timeout=0.5)
                      except queue.Empty:
                          # The loader enqueues its last batch *before* setting the
                          # done flag, so "done and empty" means nothing is left;
                          # "done" alone could race with that final put.
                          if event_data_loading_done.is_set() and queue_input.empty():
                              event_cv_processing_done.set()
                              break
                          continue
                      if not item[0]:
                          _put(queue_cv, item)
                          break
                      try:
                          for results_cv in _process_cv(
                              item[1],
                              (
                                  self.layout_det_model.batch_sampler.batch_size
                                  if model_settings["use_layout_detection"]
                                  else None
                              ),
                          ):
                              if not _put(queue_cv, (True, results_cv)):
                                  return
                      except Exception as e:
                          _put(queue_cv, (False, "cv", e))
                          break

              def _worker_vlm():
                  MAX_QUEUE_DELAY_SECS = 0.5
                  MAX_NUM_BOXES = self.vl_rec_model.batch_sampler.batch_size
                  while not event_shutdown.is_set():
                      # Collect CV results for up to MAX_QUEUE_DELAY_SECS, or until
                      # enough layout boxes are gathered to fill a VL batch.
                      results_cv_list = []
                      start_time = time.monotonic()
                      should_break = False
                      num_boxes = 0
                      while True:
                          remaining_time = MAX_QUEUE_DELAY_SECS - (
                              time.monotonic() - start_time
                          )
                          if remaining_time <= 0:
                              break
                          try:
                              item = queue_cv.get(timeout=remaining_time)
                          except queue.Empty:
                              break
                          if not item[0]:
                              _put(queue_vlm, item)
                              should_break = True
                              break
                          results_cv_list.append(item[1])
                          for res in results_cv_list[-1][4]:
                              num_boxes += len(res["boxes"])
                          if num_boxes >= MAX_NUM_BOXES:
                              break
                      if should_break:
                          break
                      if not results_cv_list:
                          if event_cv_processing_done.is_set() and queue_cv.empty():
                              event_vlm_processing_done.set()
                              break
                          continue

                      # Concatenate each field (paths, page indexes, images, ...)
                      # across the collected CV results into one larger VL batch.
                      merged_results_cv = [
                          list(chain.from_iterable(lists))
                          for lists in zip(*results_cv_list)
                      ]

                      try:
                          for result_vlm in _process_vlm(merged_results_cv):
                              if not _put(queue_vlm, (True, result_vlm)):
                                  return
                      except Exception as e:
                          _put(queue_vlm, (False, "vlm", e))
                          break

              thread_input = threading.Thread(
                  target=_worker_input, args=(input,), daemon=False
              )
              thread_input.start()
              thread_cv = threading.Thread(target=_worker_cv, daemon=False)
              thread_cv.start()
              thread_vlm = threading.Thread(target=_worker_vlm, daemon=False)
              thread_vlm.start()

          try:
              if use_queues:
                  # On a timeout just loop: the loop condition alone decides
                  # termination, so no result left in the queue is skipped.
                  while not (event_vlm_processing_done.is_set() and queue_vlm.empty()):
                      try:
                          item = queue_vlm.get(timeout=0.5)
                      except queue.Empty:
                          continue
                      if not item[0]:
                          raise RuntimeError(
                              f"Exception from the '{item[1]}' worker: {item[2]}"
                          ) from item[2]
                      yield item[1]
              else:
                  for batch_data in self.batch_sampler(input):
                      results_cv_list = list(_process_cv(batch_data))
                      assert len(results_cv_list) == 1, len(results_cv_list)
                      results_cv = results_cv_list[0]
                      for res in _process_vlm(results_cv):
                          yield res
          finally:
              if use_queues:
                  event_shutdown.set()
                  thread_input.join(timeout=5)
                  if thread_input.is_alive():
                      logging.warning("Input worker did not terminate in time")
                  thread_cv.join(timeout=5)
                  if thread_cv.is_alive():
                      logging.warning("CV worker did not terminate in time")
                  thread_vlm.join(timeout=5)
                  if thread_vlm.is_alive():
                      logging.warning("VLM worker did not terminate in time")
  @pipeline_requires_extra("ocr")
  class PPOCRVLPipeline(AutoParallelImageSimpleInferencePipeline):
      """Public PaddleOCR-VL pipeline wrapper.

      Names the pipeline, points the base class at ``_PPOCRVLPipeline`` and
      reports the configured batch size. How those are used is defined by
      ``AutoParallelImageSimpleInferencePipeline``.
      """

      entities = "PaddleOCR-VL"

      @property
      def _pipeline_cls(self):
          """The concrete pipeline class handed to the base class."""
          pipeline_cls = _PPOCRVLPipeline
          return pipeline_cls

      def _get_batch_size(self, config):
          """Return ``config["batch_size"]``, or 1 when the key is absent."""
          batch_size = config.get("batch_size", 1)
          return batch_size