# server.py (3.2 KB)

  1. import os
  2. import uuid
  3. import shutil
  4. import tempfile
  5. import gc
  6. import fitz
  7. import torch
  8. import base64
  9. import filetype
  10. import litserve as ls
  11. from pathlib import Path
  12. from fastapi import HTTPException
  13. class MinerUAPI(ls.LitAPI):
  14. def __init__(self, output_dir='/tmp'):
  15. self.output_dir = Path(output_dir)
  16. def setup(self, device):
  17. if device.startswith('cuda'):
  18. os.environ['CUDA_VISIBLE_DEVICES'] = device.split(':')[-1]
  19. if torch.cuda.device_count() > 1:
  20. raise RuntimeError("Remove any CUDA actions before setting 'CUDA_VISIBLE_DEVICES'.")
  21. from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
  22. from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
  23. self.do_parse = do_parse
  24. self.convert_file_to_pdf = convert_file_to_pdf
  25. model_manager = ModelSingleton()
  26. model_manager.get_model(True, False)
  27. model_manager.get_model(False, False)
  28. print(f'Model initialization complete on {device}!')
  29. def decode_request(self, request):
  30. file = request['file']
  31. file = self.cvt2pdf(file)
  32. opts = request.get('kwargs', {})
  33. opts.setdefault('debug_able', False)
  34. opts.setdefault('parse_method', 'auto')
  35. return file, opts
  36. def predict(self, inputs):
  37. try:
  38. pdf_name = str(uuid.uuid4())
  39. output_dir = self.output_dir.joinpath(pdf_name)
  40. self.do_parse(self.output_dir, pdf_name, inputs[0], [], **inputs[1])
  41. return output_dir
  42. except Exception as e:
  43. shutil.rmtree(output_dir, ignore_errors=True)
  44. raise HTTPException(status_code=500, detail=str(e))
  45. finally:
  46. self.clean_memory()
  47. def encode_response(self, response):
  48. return {'output_dir': response}
  49. def clean_memory(self):
  50. if torch.cuda.is_available():
  51. torch.cuda.empty_cache()
  52. torch.cuda.ipc_collect()
  53. gc.collect()
  54. def cvt2pdf(self, file_base64):
  55. try:
  56. temp_dir = Path(tempfile.mkdtemp())
  57. temp_file = temp_dir.joinpath('tmpfile')
  58. file_bytes = base64.b64decode(file_base64)
  59. file_ext = filetype.guess_extension(file_bytes)
  60. if file_ext in ['pdf', 'jpg', 'png', 'doc', 'docx', 'ppt', 'pptx']:
  61. if file_ext == 'pdf':
  62. return file_bytes
  63. elif file_ext in ['jpg', 'png']:
  64. with fitz.open(stream=file_bytes, filetype=file_ext) as f:
  65. return f.convert_to_pdf()
  66. else:
  67. temp_file.write_bytes(file_bytes)
  68. self.convert_file_to_pdf(temp_file, temp_dir)
  69. return temp_file.with_suffix('.pdf').read_bytes()
  70. else:
  71. raise Exception('Unsupported file format')
  72. except Exception as e:
  73. raise HTTPException(status_code=500, detail=str(e))
  74. finally:
  75. shutil.rmtree(temp_dir, ignore_errors=True)
  76. if __name__ == '__main__':
  77. server = ls.LitServer(
  78. MinerUAPI(output_dir='/tmp'),
  79. accelerator='cuda',
  80. devices='auto',
  81. workers_per_device=1,
  82. timeout=False
  83. )
  84. server.run(port=8000)