Sfoglia il codice sorgente

feat(gradio_app): implement dynamic concurrency limit based on VRAM

- Add get_concurrency_limit function to calculate concurrency limit based on VRAM
- Extract the VRAM lookup from clean_vram into a new get_vram helper for better clarity
- Apply concurrency limit to the to_markdown function in the Gradio app
myhloli 11 mesi fa
parent
commit
b1fe9d4f60
2 ha cambiato i file con 23 aggiunte e 6 eliminazioni
  1. 11 5
      magic_pdf/model/sub_modules/model_utils.py
  2. 12 1
      projects/gradio_app/app.py

+ 11 - 5
magic_pdf/model/sub_modules/model_utils.py

@@ -42,10 +42,16 @@ def get_res_list_from_layout_res(layout_res):
 
 
 def clean_vram(device, vram_threshold=8):
+    total_memory = get_vram(device)
+    if total_memory <= vram_threshold:
+        gc_start = time.time()
+        clean_memory()
+        gc_time = round(time.time() - gc_start, 2)
+        logger.info(f"gc time: {gc_time}")
+
+
+def get_vram(device):  # return the device's total VRAM in GB, or 0 when no CUDA device applies
     if torch.cuda.is_available() and device != 'cpu':
         total_memory = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)  # convert bytes to GB
-        if total_memory <= vram_threshold:
-            gc_start = time.time()
-            clean_memory()
-            gc_time = round(time.time() - gc_start, 2)
-            logger.info(f"gc time: {gc_time}")
+        return total_memory
+    return 0  # fallback: CUDA is unavailable or device == 'cpu'

+ 12 - 1
projects/gradio_app/app.py

@@ -14,7 +14,9 @@ from gradio_pdf import PDF
 from loguru import logger
 
 from magic_pdf.data.data_reader_writer import FileBasedDataReader
+from magic_pdf.libs.config_reader import get_device
 from magic_pdf.libs.hash_utils import compute_sha256
+from magic_pdf.model.sub_modules.model_utils import get_vram
 from magic_pdf.tools.common import do_parse, prepare_env
 
 
@@ -183,6 +185,15 @@ def to_pdf(file_path):
             return tmp_file_path
 
 
+def get_concurrency_limit(vram_threshold=7.5):  # concurrency limit: one slot per vram_threshold GB of VRAM, never less than 1
+    vram = get_vram(device = get_device())  # total VRAM (GB) of the configured device; 0 when get_vram finds no CUDA device
+    concurrency_limit = int(vram // vram_threshold)  # whole multiples of vram_threshold that fit; raises ZeroDivisionError if vram_threshold is 0
+    if concurrency_limit < 1:  # clamp so at least one request can always run (e.g. CPU or small GPUs)
+        concurrency_limit = 1
+    # logger.info(f'concurrency_limit: {concurrency_limit}')
+    return concurrency_limit
+
+
 if __name__ == '__main__':
     with gr.Blocks() as demo:
         gr.HTML(header)
@@ -219,7 +230,7 @@ if __name__ == '__main__':
                         md_text = gr.TextArea(lines=45, show_copy_button=True)
         file.upload(fn=to_pdf, inputs=file, outputs=pdf_show)
         change_bu.click(fn=to_markdown, inputs=[pdf_show, max_pages, is_ocr, layout_mode, formula_enable, table_enable, language],
-                        outputs=[md, md_text, output_file, pdf_show])
+                        outputs=[md, md_text, output_file, pdf_show], concurrency_limit=get_concurrency_limit())
         clear_bu.add([file, md, pdf_show, md_text, output_file, is_ocr, table_enable, language])
 
     demo.launch(server_name='0.0.0.0')