
Merge branch 'opendatalab:dev' into dev

Xiaomeng Zhao 7 months ago
parent
commit
a9b37b716e
1 changed file with 12 additions and 62 deletions

+ 12 - 62
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -150,7 +150,10 @@ def doc_analyze(
             img_dict = page_data.get_image()
             images.append(img_dict['img'])
             page_wh_list.append((img_dict['width'], img_dict['height']))
-    images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
+    if lang is None or lang == 'auto':
+        images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(dataset))]
+    else:
+        images_with_extra_info = [(images[index], ocr, lang) for index in range(len(dataset))]
 
     if len(images) >= MIN_BATCH_INFERENCE_SIZE:
         batch_size = MIN_BATCH_INFERENCE_SIZE
@@ -160,7 +163,7 @@ def doc_analyze(
 
     results = []
     for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, lang, layout_model, formula_enable, table_enable)
+        _, result = may_batch_image_analyze(batch_image, sn, ocr, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)
 
     model_json = []
@@ -214,7 +217,7 @@ def batch_doc_analyze(
     batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
     results = []
     for sn, batch_image in enumerate(batch_images):
-        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, lang, layout_model, formula_enable, table_enable)
+        _, result = may_batch_image_analyze(batch_image, sn, True, show_log, layout_model, formula_enable, table_enable)
         results.extend(result)
 
     infer_results = []
@@ -237,7 +240,6 @@ def may_batch_image_analyze(
         idx: int,
         ocr: bool,
         show_log: bool = False,
-        lang=None,
         layout_model=None,
         formula_enable=None,
         table_enable=None):
@@ -248,9 +250,6 @@ def may_batch_image_analyze(
     from magic_pdf.model.batch_analyze import BatchAnalyze
 
     model_manager = ModelSingleton()
-    custom_model = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
 
     images = [image for image, _, _ in images_with_extra_info]
     batch_analyze = False
@@ -276,64 +275,15 @@ def may_batch_image_analyze(
             else:
                 batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
-            batch_analyze = True
+            # batch_analyze = True
     elif str(device).startswith('mps'):
-        batch_analyze = True
-    doc_analyze_start = time.time()
+        # batch_analyze = True
+        pass
 
-    if batch_analyze:
-        """# batch analyze
-        images = []
-        page_wh_list = []
-        for index in range(len(dataset)):
-            if start_page_id <= index <= end_page_id:
-                page_data = dataset.get_page(index)
-                img_dict = page_data.get_image()
-                images.append(img_dict['img'])
-                page_wh_list.append((img_dict['width'], img_dict['height']))
-        """
-        batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
-        results = batch_model(images_with_extra_info)
-        """
-        for index in range(len(dataset)):
-            if start_page_id <= index <= end_page_id:
-                result = analyze_result.pop(0)
-                page_width, page_height = page_wh_list.pop(0)
-            else:
-                result = []
-                page_height = 0
-                page_width = 0
-
-            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
-            page_dict = {'layout_dets': result, 'page_info': page_info}
-            model_json.append(page_dict)
-        """
-    else:
-        # single analyze
-        """
-        for index in range(len(dataset)):
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            img = img_dict['img']
-            page_width = img_dict['width']
-            page_height = img_dict['height']
-            if start_page_id <= index <= end_page_id:
-                page_start = time.time()
-                result = custom_model(img)
-                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-            else:
-                result = []
+    doc_analyze_start = time.time()
 
-            page_info = {'page_no': index, 'width': page_width, 'height': page_height}
-            page_dict = {'layout_dets': result, 'page_info': page_info}
-            model_json.append(page_dict)
-        """
-        results = []
-        for img_idx, img in enumerate(images):
-            inference_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----image index : {img_idx}, image inference total time: {round(time.time() - inference_start, 2)}-----')
-            results.append(result)
+    batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
+    results = batch_model(images_with_extra_info)
 
     gc_start = time.time()
     clean_memory(get_device())