```python
import os
import fitz  # PyMuPDF, used to convert non-PDF uploads to PDF
import torch
import base64
import litserve as ls
from uuid import uuid4
from fastapi import HTTPException
from filetype import guess_extension
from magic_pdf.tools.common import do_parse
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton


class MinerUAPI(ls.LitAPI):
    def __init__(self, output_dir='/tmp'):
        self.output_dir = output_dir

    def setup(self, device):
        # Pin this worker to its assigned GPU before any CUDA context exists.
        if device.startswith('cuda'):
            os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
            if torch.cuda.device_count() > 1:
                raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")

        # Warm up both the OCR and non-OCR model pipelines once per worker.
        model_manager = ModelSingleton()
        model_manager.get_model(True, False)
        model_manager.get_model(False, False)
        print(f'Model initialization complete on {device}!')

    def decode_request(self, request):
        # The request carries a base64-encoded file plus optional parse options.
        file = request['file']
        file = self.to_pdf(file)
        opts = request.get('kwargs', {})
        opts.setdefault('debug_able', False)
        opts.setdefault('parse_method', 'auto')
        return file, opts

    def predict(self, inputs):
        try:
            # Parse into a per-request output folder named by a fresh UUID.
            do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1])
            return pdf_name
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
        finally:
            self.clean_memory()

    def encode_response(self, response):
        return {'output_dir': response}

    def clean_memory(self):
        # Release cached GPU memory between requests.
        import gc
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        gc.collect()

    def to_pdf(self, file_base64):
        # Decode the upload, sniff its real type, and convert to PDF if needed.
        try:
            file_bytes = base64.b64decode(file_base64)
            file_ext = guess_extension(file_bytes)
            with fitz.open(stream=file_bytes, filetype=file_ext) as f:
                if f.is_pdf:
                    return f.tobytes()
                return f.convert_to_pdf()
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))


if __name__ == '__main__':
    # One worker per detected GPU; timeout=False permits long-running parses.
    server = ls.LitServer(
        MinerUAPI(output_dir='/tmp'),
        accelerator='cuda',
        devices='auto',
        workers_per_device=1,
        timeout=False
    )
    server.run(port=8000)
```
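For reference, here is a minimal client sketch. It assumes the server above is reachable on `localhost:8000` and that LitServe exposes its default `/predict` route; the `file` field carries the base64-encoded document (as read by `decode_request`), and `kwargs` forwards options such as `parse_method`. The input path is hypothetical.

```python
import base64
import requests  # any HTTP client works; requests is assumed installed

# Hypothetical input file; replace with a real PDF or office document.
with open('sample.pdf', 'rb') as f:
    payload = {
        'file': base64.b64encode(f.read()).decode('utf-8'),
        'kwargs': {'parse_method': 'auto'},
    }

# LitServe serves a /predict endpoint by default.
resp = requests.post('http://localhost:8000/predict', json=payload)
resp.raise_for_status()
print(resp.json())  # e.g. {'output_dir': '<uuid>'} naming the folder under /tmp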