import os
import time
from typing import List, Tuple

from PIL import Image
from loguru import logger

from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device
from ...utils.enum_class import ImageType
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf
from ...utils.model_utils import get_vram, clean_memory

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU for unsupported ops
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # keep albumentations from checking for updates


class ModelSingleton:
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        lang=None,
        formula_enable=None,
        table_enable=None,
    ):
        key = (lang, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(
                lang=lang,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[key]
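
# Example (hypothetical usage): one model is cached per
# (lang, formula_enable, table_enable) key, so repeated calls with the same
# arguments reuse the already-initialized model:
#     manager = ModelSingleton()
#     m1 = manager.get_model(lang='en', formula_enable=True, table_enable=True)
#     m2 = manager.get_model(lang='en', formula_enable=True, table_enable=True)
#     assert m1 is m2  # same cached instance, no re-initialization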


def custom_model_init(
    lang=None,
    formula_enable=True,
    table_enable=True,
):
    model_init_start = time.time()
    # Read model-dir and device from the configuration file
    device = get_device()
    formula_config = {"enable": formula_enable}
    table_config = {"enable": table_enable}
    model_input = {
        'device': device,
        'table_config': table_config,
        'formula_config': formula_config,
        'lang': lang,
    }
    custom_model = MineruPipelineModel(**model_input)
    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost:.2f}s')
    return custom_model


def doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method: str = 'auto',
    formula_enable=True,
    table_enable=True,
):
    """
    Increasing MIN_BATCH_INFERENCE_SIZE can improve throughput, but a larger
    value also consumes more memory. It can be set via the
    MINERU_MIN_BATCH_INFERENCE_SIZE environment variable; the default is 384.
    """
    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 384))

    # Collect information for all pages
    all_pages_info = []  # stores tuples of (pdf_idx, page_idx, img_pil, ocr_enable, lang)
    all_image_lists = []
    all_pdf_docs = []
    ocr_enabled_list = []
    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # Decide whether OCR is needed for this document
        _ocr_enable = False
        if parse_method == 'auto':
            if classify(pdf_bytes) == 'ocr':
                _ocr_enable = True
        elif parse_method == 'ocr':
            _ocr_enable = True
        ocr_enabled_list.append(_ocr_enable)
        _lang = lang_list[pdf_idx]

        # Collect the pages of each document
        images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
        all_image_lists.append(images_list)
        all_pdf_docs.append(pdf_doc)
        for page_idx, img_dict in enumerate(images_list):
            all_pages_info.append((
                pdf_idx, page_idx,
                img_dict['img_pil'], _ocr_enable, _lang,
            ))

    # Split the pages into batches
    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
    batch_size = min_batch_inference_size
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]
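    # e.g. with the default batch size of 384, a combined workload of 900 pages
    # across all input PDFs is split into batches of 384, 384 and 132 pages.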

    # Run batched inference
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
        )
        batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
        results.extend(batch_results)

    # Group the per-page results back by input PDF
    infer_results = [[] for _ in range(len(pdf_bytes_list))]
    for i, page_info in enumerate(all_pages_info):
        pdf_idx, page_idx, pil_img, _, _ = page_info
        result = results[i]
        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        page_dict = {'layout_dets': result, 'page_info': page_info_dict}
        infer_results[pdf_idx].append(page_dict)

    return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list
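
# Example (hypothetical usage): analyze two PDFs in one batched pass; results
# come back grouped per input document:
#     infer_results, image_lists, pdf_docs, langs, ocr_flags = doc_analyze(
#         [pdf_a_bytes, pdf_b_bytes],
#         lang_list=['en', 'ch'],
#         parse_method='auto',
#     )
#     first_page = infer_results[0][0]  # {'layout_dets': [...], 'page_info': {...}}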


def batch_image_analyze(
    images_with_extra_info: List[Tuple[Image.Image, bool, str]],
    formula_enable=True,
    table_enable=True,
):
    from .batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()

    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        try:
            import torch_npu
            if torch_npu.npu.is_available():
                torch_npu.npu.set_compile_mode(jit_compile=False)
        except Exception as e:
            raise RuntimeError(
                "NPU is selected as device, but torch_npu is not available. "
                "Please ensure that the torch_npu package is installed correctly."
            ) from e

    if str(device).startswith('npu') or str(device).startswith('cuda'):
        vram = get_vram(device)
        if vram is not None:
            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            # Default batch_ratio when VRAM can't be determined
            batch_ratio = 1
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
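
    # Example (hypothetical override): capping the VRAM the heuristic sees via
    # MINERU_VIRTUAL_VRAM_SIZE forces a smaller batch_ratio on a shared GPU:
    #     MINERU_VIRTUAL_VRAM_SIZE=8 python your_script.py  # -> batch_ratio = 4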

    # Check the torch version: OCR detection batching is disabled on
    # torch >= 2.8.0 and on MPS devices
    import torch
    from packaging import version
    if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
        enable_ocr_det_batch = False
    else:
        enable_ocr_det_batch = True

    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
    results = batch_model(images_with_extra_info)

    clean_memory(get_device())

    return results
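
# Example (hypothetical direct call): run layout analysis on a single rendered
# page without going through doc_analyze. Each entry is (image, ocr_enable, lang):
#     from PIL import Image
#     page = Image.open('page_0.png')
#     layout_dets_per_page = batch_image_analyze([(page, True, 'en')])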