# server.py (2.3 KB)
import base64
import gc
import os
from uuid import uuid4

import fitz
import litserve as ls
import torch
from fastapi import HTTPException
from filetype import guess_extension

from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.tools.common import do_parse
  11. class MinerUAPI(ls.LitAPI):
  12. def __init__(self, output_dir='/tmp'):
  13. self.output_dir = output_dir
  14. def setup(self, device):
  15. if device.startswith('cuda'):
  16. os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
  17. if torch.cuda.device_count() > 1:
  18. raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")
  19. model_manager = ModelSingleton()
  20. model_manager.get_model(True, False)
  21. model_manager.get_model(False, False)
  22. print(f'Model initialization complete on {device}!')
  23. def decode_request(self, request):
  24. file = request['file']
  25. file = self.to_pdf(file)
  26. opts = request.get('kwargs', {})
  27. opts.setdefault('debug_able', False)
  28. opts.setdefault('parse_method', 'auto')
  29. return file, opts
  30. def predict(self, inputs):
  31. try:
  32. do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1])
  33. return pdf_name
  34. except Exception as e:
  35. raise HTTPException(status_code=500, detail=str(e))
  36. finally:
  37. self.clean_memory()
  38. def encode_response(self, response):
  39. return {'output_dir': response}
  40. def clean_memory(self):
  41. import gc
  42. if torch.cuda.is_available():
  43. torch.cuda.empty_cache()
  44. torch.cuda.ipc_collect()
  45. gc.collect()
  46. def to_pdf(self, file_base64):
  47. try:
  48. file_bytes = base64.b64decode(file_base64)
  49. file_ext = guess_extension(file_bytes)
  50. with fitz.open(stream=file_bytes, filetype=file_ext) as f:
  51. if f.is_pdf: return f.tobytes()
  52. return f.convert_to_pdf()
  53. except Exception as e:
  54. raise HTTPException(status_code=500, detail=str(e))
  55. if __name__ == '__main__':
  56. server = ls.LitServer(
  57. MinerUAPI(output_dir='/tmp'),
  58. accelerator='cuda',
  59. devices='auto',
  60. workers_per_device=1,
  61. timeout=False
  62. )
  63. server.run(port=8000)