Browse Source

Adjust projects/multi_gpu/server.py for magic_pdf-1.0.1

Hui 10 tháng trước cách đây
mục cha
commit
667d2c0d55

+ 1 - 1
projects/multi_gpu/README.md

@@ -31,7 +31,7 @@ python server.py
 ### 2. 启动客户端
 以下代码展示了客户端的使用方式,可根据需求修改配置:
 ```python
-files = ['demo/small_ocr.pdf']  # 替换为文件路径,支持 jpg/jpeg、png、pdf 文件
+files = ['demo/small_ocr.pdf']  # 替换为文件路径,支持 pdf、jpg/jpeg、png、doc、docx、ppt、pptx 文件
 n_jobs = np.clip(len(files), 1, 8)  # 设置并发线程数,此处最大为 8,可根据自身修改
 results = Parallel(n_jobs, prefer='threads', verbose=10)(
     delayed(do_parse)(p) for p in files

+ 1 - 1
projects/multi_gpu/client.py

@@ -31,7 +31,7 @@ def do_parse(file_path, url='http://127.0.0.1:8000/predict', **kwargs):
 
 
 if __name__ == '__main__':
-    files = ['small_ocr.pdf']
+    files = ['demo/small_ocr.pdf']
     n_jobs = np.clip(len(files), 1, 8)
     results = Parallel(n_jobs, prefer='threads', verbose=10)(
         delayed(do_parse)(p) for p in files

+ 33 - 13
projects/multi_gpu/server.py

@@ -1,18 +1,22 @@
 import os
+import uuid
+import shutil
+import tempfile
+import gc
 import fitz
 import torch
 import base64
+import filetype
 import litserve as ls
-from uuid import uuid4
+from pathlib import Path
 from fastapi import HTTPException
-from filetype import guess_extension
-from magic_pdf.tools.common import do_parse
+from magic_pdf.tools.cli import do_parse, convert_file_to_pdf
 from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 
 
 class MinerUAPI(ls.LitAPI):
     def __init__(self, output_dir='/tmp'):
-        self.output_dir = output_dir
+        self.output_dir = Path(output_dir)
 
     def setup(self, device):
         if device.startswith('cuda'):
@@ -27,7 +31,7 @@ class MinerUAPI(ls.LitAPI):
 
     def decode_request(self, request):
         file = request['file']
-        file = self.to_pdf(file)
+        file = self.cvt2pdf(file)
         opts = request.get('kwargs', {})
         opts.setdefault('debug_able', False)
         opts.setdefault('parse_method', 'auto')
@@ -35,9 +39,12 @@ class MinerUAPI(ls.LitAPI):
 
     def predict(self, inputs):
         try:
-            do_parse(self.output_dir, pdf_name := str(uuid4()), inputs[0], [], **inputs[1])
-            return pdf_name
+            pdf_name = str(uuid.uuid4())
+            output_dir = self.output_dir.joinpath(pdf_name)
+            do_parse(self.output_dir, pdf_name, inputs[0], [], **inputs[1])
+            return output_dir
         except Exception as e:
+            shutil.rmtree(output_dir, ignore_errors=True)
             raise HTTPException(status_code=500, detail=str(e))
         finally:
             self.clean_memory()
@@ -46,21 +53,34 @@ class MinerUAPI(ls.LitAPI):
         return {'output_dir': response}
 
     def clean_memory(self):
-        import gc
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
         gc.collect()
 
-    def to_pdf(self, file_base64):
+    def cvt2pdf(self, file_base64):
         try:
+            temp_dir = Path(tempfile.mkdtemp())
+            temp_file = temp_dir.joinpath('tmpfile')
             file_bytes = base64.b64decode(file_base64)
-            file_ext = guess_extension(file_bytes)
-            with fitz.open(stream=file_bytes, filetype=file_ext) as f:
-                if f.is_pdf: return f.tobytes()
-                return f.convert_to_pdf()
+            file_ext = filetype.guess_extension(file_bytes)
+
+            if file_ext in ['pdf', 'jpg', 'png', 'doc', 'docx', 'ppt', 'pptx']:
+                if file_ext == 'pdf':
+                    return file_bytes
+                elif file_ext in ['jpg', 'png']:
+                    with fitz.open(stream=file_bytes, filetype=file_ext) as f:
+                        return f.convert_to_pdf()
+                else:
+                    temp_file.write_bytes(file_bytes)
+                    convert_file_to_pdf(temp_file, temp_dir)
+                    return temp_file.with_suffix('.pdf').read_bytes()
+            else:
+                raise Exception('Unsupported file format')
         except Exception as e:
             raise HTTPException(status_code=500, detail=str(e))
+        finally:
+            shutil.rmtree(temp_dir, ignore_errors=True)
 
 
 if __name__ == '__main__':

BIN
projects/multi_gpu/small_ocr.pdf