|
|
@@ -1,18 +1,22 @@
|
|
|
import os
|
|
|
+import uuid
|
|
|
+import shutil
|
|
|
+import tempfile
|
|
|
+import gc
|
|
|
import fitz
|
|
|
import torch
|
|
|
import base64
|
|
|
+import filetype
|
|
|
import litserve as ls
|
|
|
-from uuid import uuid4
|
|
|
+from pathlib import Path
|
|
|
from fastapi import HTTPException
|
|
|
-from filetype import guess_extension
|
|
|
-from magic_pdf.tools.common import do_parse
|
|
|
+from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
|
|
|
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
|
|
|
|
|
|
|
|
|
class MinerUAPI(ls.LitAPI):
|
|
|
def __init__(self, output_dir='/tmp'):
|
|
|
- self.output_dir = output_dir
|
|
|
+ self.output_dir = Path(output_dir)
|
|
|
|
|
|
def setup(self, device):
|
|
|
if device.startswith('cuda'):
|
|
|
@@ -27,7 +31,7 @@ class MinerUAPI(ls.LitAPI):
|
|
|
|
|
|
def decode_request(self, request):
|
|
|
file = request['file']
|
|
|
- file = self.to_pdf(file)
|
|
|
+ file = self.cvt2pdf(file)
|
|
|
opts = request.get('kwargs', {})
|
|
|
opts.setdefault('debug_able', False)
|
|
|
opts.setdefault('parse_method', 'auto')
|
|
|
@@ -35,9 +39,12 @@ class MinerUAPI(ls.LitAPI):
|
|
|
|
|
|
def predict(self, inputs):
    """Run ``do_parse`` for one request and return its result directory.

    Args:
        inputs: Indexable pair. ``inputs[0]`` is handed to ``do_parse`` as
            the document payload and ``inputs[1]`` is expanded as keyword
            options. Presumably the value built by ``decode_request`` --
            TODO confirm against that method's return.

    Returns:
        ``self.output_dir`` joined with a freshly generated UUID string.
        NOTE(review): this assumes ``do_parse`` writes under
        ``<output_dir>/<pdf_name>``; verify against ``do_parse``.

    Raises:
        HTTPException: Status 500 with the failure message as detail,
            chained to the underlying exception.
    """
    # Pre-bind so the handler never touches an unbound local when the
    # failure happens before the per-request path has been computed.
    output_dir = None
    try:
        pdf_name = str(uuid.uuid4())
        output_dir = self.output_dir.joinpath(pdf_name)
        do_parse(self.output_dir, pdf_name, inputs[0], [], **inputs[1])
        return output_dir
    except Exception as e:
        # Best-effort removal of any partial output for this request.
        if output_dir is not None:
            shutil.rmtree(output_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Runs on both success and failure paths.
        self.clean_memory()
|
|
|
@@ -46,21 +53,34 @@ class MinerUAPI(ls.LitAPI):
|
|
|
return {'output_dir': response}
|
|
|
|
|
|
def clean_memory(self):
    """Release cached CUDA memory when CUDA is available, then run the GC."""
    cuda = torch.cuda
    if cuda.is_available():
        cuda.empty_cache()
        cuda.ipc_collect()
    # Python-level collection happens regardless of CUDA availability.
    gc.collect()
|
|
|
|
|
|
def cvt2pdf(self, file_base64):
    """Decode an uploaded file and return it as PDF bytes.

    The type is sniffed from the decoded bytes with
    ``filetype.guess_extension``:

    * ``pdf`` -- the decoded bytes are returned unchanged.
    * ``jpg`` / ``png`` -- opened with ``fitz`` and converted via
      ``convert_to_pdf``.
    * ``doc`` / ``docx`` / ``ppt`` / ``pptx`` -- written to a scratch
      directory, converted with ``convert_file_to_pdf``, and the
      resulting ``tmpfile.pdf`` is read back.

    Args:
        file_base64: Base64-encoded file content.

    Returns:
        The PDF content produced by the matching branch above.

    Raises:
        HTTPException: Status 500 for any failure, including an
            unrecognised or unsupported file type; chained to the
            underlying exception.
    """
    try:
        file_bytes = base64.b64decode(file_base64)
        file_ext = filetype.guess_extension(file_bytes)

        if file_ext == 'pdf':
            return file_bytes

        if file_ext in ('jpg', 'png'):
            with fitz.open(stream=file_bytes, filetype=file_ext) as f:
                return f.convert_to_pdf()

        if file_ext in ('doc', 'docx', 'ppt', 'pptx'):
            # Only this branch needs disk, so the scratch directory is
            # created here and its cleanup is scoped to its own lifetime.
            temp_dir = Path(tempfile.mkdtemp())
            try:
                temp_file = temp_dir.joinpath('tmpfile')
                temp_file.write_bytes(file_bytes)
                convert_file_to_pdf(temp_file, temp_dir)
                return temp_file.with_suffix('.pdf').read_bytes()
            finally:
                # Best-effort cleanup; must not mask the real error.
                shutil.rmtree(temp_dir, ignore_errors=True)

        raise Exception('Unsupported file format')
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|