server.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import os
  2. import uuid
  3. import shutil
  4. import tempfile
  5. import gc
  6. import fitz
  7. import torch
  8. import base64
  9. import filetype
  10. import litserve as ls
  11. from pathlib import Path
  12. from fastapi import HTTPException
  13. from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
  14. from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
  class MinerUAPI(ls.LitAPI):
      """LitServe API that converts an uploaded document to PDF and parses it with MinerU.

      Request payload (a mapping):
          file:   base64-encoded PDF, JPG, PNG, DOC, DOCX, PPT or PPTX content.
          kwargs: optional dict of extra keyword arguments forwarded to
                  ``do_parse``; ``debug_able`` defaults to False and
                  ``parse_method`` to ``'auto'``.

      Response payload: ``{'output_dir': '<path>'}``, the UUID-named
      directory under ``output_dir`` associated with the request's results.
      """

      def __init__(self, output_dir='/tmp'):
          """Remember the root directory for per-request results.

          Args:
              output_dir: Root directory; each request uses a fresh
                  UUID-named subdirectory beneath it.
          """
          # LitAPI may define its own __init__ (depending on the litserve
          # version); chain to it so base-class state is always set up.
          super().__init__()
          self.output_dir = Path(output_dir)

      def setup(self, device):
          """Pin this worker to ``device`` and preload the MinerU models.

          Args:
              device: Device string assigned by LitServe, e.g. ``'cuda:0'``;
                  non-CUDA devices skip the GPU pinning step.

          Raises:
              RuntimeError: If more than one CUDA device is still visible after
                  restricting ``CUDA_VISIBLE_DEVICES`` to the assigned index.
          """
          if device.startswith('cuda'):
              # Restrict this worker process to its assigned GPU index.
              os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
              # The variable only takes effect if CUDA was not touched earlier in
              # this process; several visible devices means it came too late.
              if torch.cuda.device_count() > 1:
                  raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")

          # Load both model variants up front so the first request is not slow.
          # NOTE(review): presumably the first flag selects OCR mode and the
          # second toggles logging; confirm against ModelSingleton.get_model.
          model_manager = ModelSingleton()
          model_manager.get_model(True, False)
          model_manager.get_model(False, False)
          print(f'Model initialization complete on {device}!')

      def decode_request(self, request):
          """Turn a request payload into ``(pdf_bytes, do_parse_kwargs)``.

          Args:
              request: Mapping with a base64 ``file`` field and an optional
                  ``kwargs`` dict.

          Returns:
              Tuple of the document as PDF bytes and a new dict of keyword
              arguments for ``do_parse`` with defaults filled in.

          Raises:
              HTTPException: 400 if ``file`` is missing or ``kwargs`` is not a
                  dict; conversion errors are raised by ``cvt2pdf``.
          """
          if 'file' not in request:
              raise HTTPException(status_code=400, detail="Missing required field 'file'")
          # An explicit ``"kwargs": null`` is treated like an absent field.
          opts = request.get('kwargs') or {}
          if not isinstance(opts, dict):
              raise HTTPException(status_code=400, detail="'kwargs' must be an object")
          # Copy so the defaults below never mutate the caller's payload.
          opts = dict(opts)
          opts.setdefault('debug_able', False)
          opts.setdefault('parse_method', 'auto')
          file = self.cvt2pdf(request['file'])
          return file, opts

      def predict(self, inputs):
          """Parse the PDF with MinerU and return its output directory.

          Args:
              inputs: ``(pdf_bytes, do_parse_kwargs)`` as built by
                  ``decode_request``.

          Returns:
              ``Path`` of the per-request output directory.

          Raises:
              HTTPException: 500 if parsing fails; the partial output
                  directory is removed first.
          """
          pdf_bytes, opts = inputs
          # Build the names before ``try`` so the cleanup handler can always use them.
          pdf_name = str(uuid.uuid4())
          output_dir = self.output_dir.joinpath(pdf_name)
          try:
              # NOTE(review): assumes do_parse writes its results under
              # output_dir/pdf_name, the directory returned or removed here.
              do_parse(self.output_dir, pdf_name, pdf_bytes, [], **opts)
              return output_dir
          except Exception as e:
              shutil.rmtree(output_dir, ignore_errors=True)
              raise HTTPException(status_code=500, detail=str(e)) from e
          finally:
              # Free cached GPU memory and run the GC after every request.
              self.clean_memory()

      def encode_response(self, response):
          """Wrap the output directory path in the response payload as a string."""
          return {'output_dir': str(response)}

      def clean_memory(self):
          """Release cached CUDA memory (when CUDA is available) and run the GC."""
          if torch.cuda.is_available():
              torch.cuda.empty_cache()
              torch.cuda.ipc_collect()
          gc.collect()

      def cvt2pdf(self, file_base64):
          """Decode a base64 document and return its content as PDF bytes.

          PDFs are returned unchanged. JPG/PNG images are converted with
          ``fitz``. DOC/DOCX/PPT/PPTX files go through ``convert_file_to_pdf``
          in a scratch directory that is always removed afterwards.

          Args:
              file_base64: Base64-encoded document content.

          Returns:
              The document as PDF bytes.

          Raises:
              HTTPException: 400 for undecodable input or an unsupported file
                  type; 500 if the conversion itself fails.
          """
          try:
              file_bytes = base64.b64decode(file_base64)
          except (TypeError, ValueError) as e:
              # binascii.Error (e.g. bad padding) subclasses ValueError; TypeError
              # means the field was not str/bytes. Both are client errors.
              raise HTTPException(status_code=400, detail=str(e)) from e

          file_ext = filetype.guess_extension(file_bytes)
          if file_ext not in ('pdf', 'jpg', 'png', 'doc', 'docx', 'ppt', 'pptx'):
              raise HTTPException(status_code=400, detail='Unsupported file format')
          if file_ext == 'pdf':
              return file_bytes

          try:
              if file_ext in ('jpg', 'png'):
                  with fitz.open(stream=file_bytes, filetype=file_ext) as doc:
                      return doc.convert_to_pdf()
              return self._office_to_pdf(file_bytes)
          except Exception as e:
              raise HTTPException(status_code=500, detail=str(e)) from e

      @staticmethod
      def _office_to_pdf(file_bytes):
          """Convert Office document bytes to PDF bytes via a scratch directory."""
          # Create the directory before ``try`` so ``finally`` never sees an
          # unbound name if mkdtemp itself fails.
          temp_dir = Path(tempfile.mkdtemp())
          try:
              temp_file = temp_dir.joinpath('tmpfile')
              temp_file.write_bytes(file_bytes)
              convert_file_to_pdf(temp_file, temp_dir)
              # NOTE(review): assumes the converter writes '<stem>.pdf' into temp_dir.
              return temp_file.with_suffix('.pdf').read_bytes()
          finally:
              shutil.rmtree(temp_dir, ignore_errors=True)
  if __name__ == '__main__':
      # Script entry point: serve MinerUAPI over HTTP on port 8000.
      # accelerator='cuda' with devices='auto' lets LitServe choose the GPUs,
      # and workers_per_device=1 runs a single worker per device.
      # NOTE(review): timeout=False is presumably meant to disable LitServe's
      # request timeout for long parses; confirm for the pinned version.
      api = MinerUAPI(output_dir='/tmp')
      server = ls.LitServer(
          api,
          timeout=False,
          workers_per_device=1,
          devices='auto',
          accelerator='cuda',
      )
      server.run(port=8000)