Sfoglia il codice sorgente

Merge pull request #2078 from opendatalab/dev

feat(model): add tqdm progress bar to model prediction loops
Xiaomeng Zhao 7 mesi fa
parent
commit
6a30b5bd71

+ 1 - 1
README.md

@@ -307,7 +307,7 @@ You can modify certain configurations in this file to enable or disable features
     },
     "table-config": {
         "model": "rapid_table", 
-        "sub_model": "slanet_plus",  // When the model is "rapid_table", you can choose a sub_model. The options are "slanet_plus" and "unitable"
+        "sub_model": "slanet_plus",
         "enable": true, // The table recognition feature is enabled by default. If you need to disable it, please change the value here to "false".
         "max_time": 400
     }

+ 2 - 2
README_zh-CN.md

@@ -310,8 +310,8 @@ pip install -U "magic-pdf[full]" -i https://mirrors.aliyun.com/pypi/simple
         "enable": true  // 公式识别功能默认是开启的,如果需要关闭请修改此处的值为"false"
     },
     "table-config": {
-        "model": "rapid_table",  
-        "sub_model": "slanet_plus",  // 当model为"rapid_table"时,可以自选sub_model,可选项为"slanet_plus"和"unitable"
+        "model": "rapid_table",
+        "sub_model": "slanet_plus",
         "enable": true, // 表格识别功能默认是开启的,如果需要关闭请修改此处的值为"false"
         "max_time": 400
     }

+ 2 - 1
docker/ascend_npu/requirements.txt

@@ -16,4 +16,5 @@ doclayout-yolo==0.0.2b1
 ftfy
 openai
 pydantic>=2.7.2,<2.11
-transformers>=4.49.0,<5.0.0
+transformers>=4.49.0,<5.0.0
+tqdm>=4.67.1

+ 2 - 1
docker/china/requirements.txt

@@ -16,4 +16,5 @@ doclayout-yolo==0.0.2b1
 ftfy
 openai
 pydantic>=2.7.2,<2.11
-transformers>=4.49.0,<5.0.0
+transformers>=4.49.0,<5.0.0
+tqdm>=4.67.1

+ 2 - 1
docker/global/requirements.txt

@@ -16,4 +16,5 @@ doclayout-yolo==0.0.2b1
 ftfy
 openai
 pydantic>=2.7.2,<2.11
-transformers>=4.49.0,<5.0.0
+transformers>=4.49.0,<5.0.0
+tqdm>=4.67.1

+ 94 - 82
magic_pdf/model/batch_analyze.py

@@ -1,8 +1,7 @@
 import time
-
 import cv2
-import torch
 from loguru import logger
+from tqdm import tqdm
 
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
@@ -52,9 +51,9 @@ class BatchAnalyze:
                 layout_images, YOLO_LAYOUT_BASE_BATCH_SIZE
             )
 
-        logger.info(
-            f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
-        )
+        # logger.info(
+        #     f'layout time: {round(time.time() - layout_start_time, 2)}, image num: {len(images)}'
+        # )
 
         if self.model.apply_formula:
             # 公式检测
@@ -63,9 +62,9 @@ class BatchAnalyze:
                 # images, self.batch_ratio * MFD_BASE_BATCH_SIZE
                 images, MFD_BASE_BATCH_SIZE
             )
-            logger.info(
-                f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
-            )
+            # logger.info(
+            #     f'mfd time: {round(time.time() - mfd_start_time, 2)}, image num: {len(images)}'
+            # )
 
             # 公式识别
             mfr_start_time = time.time()
@@ -78,104 +77,117 @@ class BatchAnalyze:
             for image_index in range(len(images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
                 mfr_count += len(images_formula_list[image_index])
-            logger.info(
-                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
-            )
+            # logger.info(
+            #     f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
+            # )
 
         # 清理显存
-        clean_vram(self.model.device, vram_threshold=8)
+        # clean_vram(self.model.device, vram_threshold=8)
 
-        det_time = 0
-        det_count = 0
-        table_time = 0
-        table_count = 0
-        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
+        ocr_res_list_all_page = []
+        table_res_list_all_page = []
         for index in range(len(images)):
             _, ocr_enable, _lang = images_with_extra_info[index]
-            self.model = self.model_manager.get_model(ocr_enable, self.show_log, _lang, self.layout_model, self.formula_enable, self.table_enable)
             layout_res = images_layout_res[index]
             np_array_img = images[index]
 
             ocr_res_list, table_res_list, single_page_mfdetrec_res = (
                 get_res_list_from_layout_res(layout_res)
             )
-            # ocr识别
-            det_start = time.time()
+
+            ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
+                                          'lang':_lang,
+                                          'ocr_enable':ocr_enable,
+                                          'np_array_img':np_array_img,
+                                          'single_page_mfdetrec_res':single_page_mfdetrec_res,
+                                          'layout_res':layout_res,
+                                          })
+
+            for table_res in table_res_list:
+                table_img, _ = crop_img(table_res, np_array_img)
+                table_res_list_all_page.append({'table_res':table_res,
+                                                'lang':_lang,
+                                                'table_img':table_img,
+                                              })
+
+        # 文本框检测
+        det_start = time.time()
+        det_count = 0
+        # for ocr_res_list_dict in ocr_res_list_all_page:
+        for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
             # Process each area that requires OCR processing
-            for res in ocr_res_list:
+            _lang = ocr_res_list_dict['lang']
+            # Get OCR results for this language's images
+            atom_model_manager = AtomModelSingleton()
+            ocr_model = atom_model_manager.get_atom_model(
+                atom_model_name='ocr',
+                ocr_show_log=False,
+                det_db_box_thresh=0.3,
+                lang=_lang
+            )
+            for res in ocr_res_list_dict['ocr_res_list']:
                 new_image, useful_list = crop_img(
-                    res, np_array_img, crop_paste_x=50, crop_paste_y=50
+                    res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
                 )
                 adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-                    single_page_mfdetrec_res, useful_list
+                    ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
                 )
 
-                # OCR recognition
+                # OCR-det
                 new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
-
-                # if ocr_enable:
-                #     ocr_res = self.model.ocr_model.ocr(
-                #         new_image, mfd_res=adjusted_mfdetrec_res
-                #     )[0]
-                # else:
-                ocr_res = self.model.ocr_model.ocr(
+                ocr_res = ocr_model.ocr(
                     new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                 )[0]
 
                 # Integration results
                 if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
-                    layout_res.extend(ocr_result_list)
-            det_time += time.time() - det_start
-            det_count += len(ocr_res_list)
-
-            # 表格识别 table recognition
-            if self.model.apply_table:
-                table_start = time.time()
-                for res in table_res_list:
-                    new_image, _ = crop_img(res, np_array_img)
-                    single_table_start_time = time.time()
-                    html_code = None
-                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
-                        with torch.no_grad():
-                            table_result = self.model.table_model.predict(
-                                new_image, 'html'
-                            )
-                            if len(table_result) > 0:
-                                html_code = table_result[0]
-                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
-                        html_code = self.model.table_model.img2html(new_image)
-                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
-                        html_code, table_cell_bboxes, logic_points, elapse = (
-                            self.model.table_model.predict(new_image)
-                        )
-                    run_time = time.time() - single_table_start_time
-                    if run_time > self.model.table_max_time:
-                        logger.warning(
-                            f'table recognition processing exceeds max time {self.model.table_max_time}s'
-                        )
-                    # 判断是否返回正常
-                    if html_code:
-                        expected_ending = html_code.strip().endswith(
-                            '</html>'
-                        ) or html_code.strip().endswith('</table>')
-                        if expected_ending:
-                            res['html'] = html_code
-                        else:
-                            logger.warning(
-                                'table recognition processing fails, not found expected HTML table end'
-                            )
-                    else:
-                        logger.warning(
-                            'table recognition processing fails, not get html return'
-                        )
-                table_time += time.time() - table_start
-                table_count += len(table_res_list)
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
+                    ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+            det_count += len(ocr_res_list_dict['ocr_res_list'])
+        # logger.info(f'ocr-det time: {round(time.time()-det_start, 2)}, image num: {det_count}')
 
 
-        logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
+        # 表格识别 table recognition
         if self.model.apply_table:
-            logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
+            table_start = time.time()
+            table_count = 0
+            # for table_res_list_dict in table_res_list_all_page:
+            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+                _lang = table_res_dict['lang']
+                atom_model_manager = AtomModelSingleton()
+                ocr_engine = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.5,
+                    det_db_unclip_ratio=1.6,
+                    lang=_lang
+                )
+                table_model = atom_model_manager.get_atom_model(
+                    atom_model_name='table',
+                    table_model_name='rapid_table',
+                    table_model_path='',
+                    table_max_time=400,
+                    device='cpu',
+                    ocr_engine=ocr_engine,
+                    table_sub_model_name='slanet_plus'
+                )
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
+                # 判断是否返回正常
+                if html_code:
+                    expected_ending = html_code.strip().endswith(
+                        '</html>'
+                    ) or html_code.strip().endswith('</table>')
+                    if expected_ending:
+                        table_res_dict['table_res']['html'] = html_code
+                    else:
+                        logger.warning(
+                            'table recognition processing fails, not found expected HTML table end'
+                        )
+                else:
+                    logger.warning(
+                        'table recognition processing fails, not get html return'
+                    )
+            # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')
 
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language
@@ -219,7 +231,7 @@ class BatchAnalyze:
                         det_db_box_thresh=0.3,
                         lang=lang
                     )
-                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
 
                     # Verify we have matching counts
                     assert len(ocr_res_list) == len(
@@ -234,7 +246,7 @@ class BatchAnalyze:
                     total_processed += len(img_crop_list)
 
             rec_time += time.time() - rec_start
-            logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
+            # logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
 
 
 

+ 15 - 19
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -188,7 +188,7 @@ def batch_doc_analyze(
     formula_enable=None,
     table_enable=None,
 ):
-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     batch_size = MIN_BATCH_INFERENCE_SIZE
     images = []
     page_wh_list = []
@@ -245,8 +245,7 @@ def may_batch_image_analyze(
 
     model_manager = ModelSingleton()
 
-    images = [image for image, _, _ in images_with_extra_info]
-    batch_analyze = False
+    # images = [image for image, _, _ in images_with_extra_info]
     batch_ratio = 1
     device = get_device()
 
@@ -269,25 +268,22 @@ def may_batch_image_analyze(
             else:
                 batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
-            # batch_analyze = True
-    elif str(device).startswith('mps'):
-        # batch_analyze = True
-        pass
 
-    doc_analyze_start = time.time()
+
+    # doc_analyze_start = time.time()
 
     batch_model = BatchAnalyze(model_manager, batch_ratio, show_log, layout_model, formula_enable, table_enable)
     results = batch_model(images_with_extra_info)
 
-    gc_start = time.time()
+    # gc_start = time.time()
     clean_memory(get_device())
-    gc_time = round(time.time() - gc_start, 2)
-    logger.info(f'gc time: {gc_time}')
-
-    doc_analyze_time = round(time.time() - doc_analyze_start, 2)
-    doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
-    logger.info(
-        f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
-        f' speed: {doc_analyze_speed} pages/second'
-    )
-    return (idx, results)
+    # gc_time = round(time.time() - gc_start, 2)
+    # logger.debug(f'gc time: {gc_time}')
+
+    # doc_analyze_time = round(time.time() - doc_analyze_start, 2)
+    # doc_analyze_speed = round(len(images) / doc_analyze_time, 2)
+    # logger.debug(
+    #     f'doc analyze time: {round(time.time() - doc_analyze_start, 2)},'
+    #     f' speed: {doc_analyze_speed} pages/second'
+    # )
+    return idx, results

+ 3 - 1
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -1,4 +1,5 @@
 from doclayout_yolo import YOLOv10
+from tqdm import tqdm
 
 
 class DocLayoutYOLOModel(object):
@@ -31,7 +32,8 @@ class DocLayoutYOLOModel(object):
 
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_layout_res = []
-        for index in range(0, len(images), batch_size):
+        # for index in range(0, len(images), batch_size):
+        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
             doclayout_yolo_res = [
                 image_res.cpu()
                 for image_res in self.model.predict(

+ 3 - 1
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -1,3 +1,4 @@
+from tqdm import tqdm
 from ultralytics import YOLO
 
 
@@ -14,7 +15,8 @@ class YOLOv8MFDModel(object):
 
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_mfd_res = []
-        for index in range(0, len(images), batch_size):
+        # for index in range(0, len(images), batch_size):
+        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
             mfd_res = [
                 image_res.cpu()
                 for image_res in self.mfd_model.predict(

+ 14 - 6
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -1,5 +1,6 @@
 import torch
 from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
 
 
 class MathDataset(Dataset):
@@ -107,12 +108,19 @@ class UnimernetModel(object):
 
         # Process batches and store results
         mfr_res = []
-        for mf_img in dataloader:
-            mf_img = mf_img.to(dtype=self.model.dtype)
-            mf_img = mf_img.to(self.device)
-            with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["fixed_str"])
+        # for mf_img in dataloader:
+
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img})
+                mfr_res.extend(output["fixed_str"])
+
+                # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)
 
         # Restore original order
         unsorted_results = [""] * len(mfr_res)

+ 3 - 1
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -72,6 +72,7 @@ class PytorchPaddleOCR(TextSystem):
         kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
         kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
+        # kwargs['rec_batch_num'] = 8
 
         kwargs['device'] = get_device()
 
@@ -86,6 +87,7 @@ class PytorchPaddleOCR(TextSystem):
             det=True,
             rec=True,
             mfd_res=None,
+            tqdm_enable=False,
             ):
         assert isinstance(img, (np.ndarray, list, str, bytes))
         if isinstance(img, list) and det == True:
@@ -129,7 +131,7 @@ class PytorchPaddleOCR(TextSystem):
                     if not isinstance(img, list):
                         img = preprocess_image(img)
                         img = [img]
-                    rec_res, elapse = self.text_recognizer(img)
+                    rec_res, elapse = self.text_recognizer(img, tqdm_enable=tqdm_enable)
                     # logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
                     ocr_res.append(rec_res)
                 return ocr_res

+ 133 - 122
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py

@@ -4,6 +4,8 @@ import numpy as np
 import math
 import time
 import torch
+from tqdm import tqdm
+
 from ...pytorchocr.base_ocr_v20 import BaseOCRV20
 from . import pytorchocr_utility as utility
 from ...pytorchocr.postprocess import build_post_process
@@ -286,7 +288,7 @@ class TextRecognizer(BaseOCRV20):
 
         return img
 
-    def __call__(self, img_list):
+    def __call__(self, img_list, tqdm_enable=False):
         img_num = len(img_list)
         # Calculate the aspect ratio of all text bars
         width_list = []
@@ -299,131 +301,140 @@ class TextRecognizer(BaseOCRV20):
         rec_res = [['', 0.0]] * img_num
         batch_num = self.rec_batch_num
         elapse = 0
-        for beg_img_no in range(0, img_num, batch_num):
-            end_img_no = min(img_num, beg_img_no + batch_num)
-            norm_img_batch = []
-            max_wh_ratio = 0
-            for ino in range(beg_img_no, end_img_no):
-                # h, w = img_list[ino].shape[0:2]
-                h, w = img_list[indices[ino]].shape[0:2]
-                wh_ratio = w * 1.0 / h
-                max_wh_ratio = max(max_wh_ratio, wh_ratio)
-            for ino in range(beg_img_no, end_img_no):
-                if self.rec_algorithm == "SAR":
-                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
-                        img_list[indices[ino]], self.rec_image_shape)
-                    norm_img = norm_img[np.newaxis, :]
-                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
-                    valid_ratios = []
-                    valid_ratios.append(valid_ratio)
-                    norm_img_batch.append(norm_img)
-
-                elif self.rec_algorithm == "SVTR":
-                    norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
-                                                         self.rec_image_shape)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-                elif self.rec_algorithm == "SRN":
-                    norm_img = self.process_image_srn(img_list[indices[ino]],
-                                                      self.rec_image_shape, 8,
-                                                      self.max_text_length)
-                    encoder_word_pos_list = []
-                    gsrm_word_pos_list = []
-                    gsrm_slf_attn_bias1_list = []
-                    gsrm_slf_attn_bias2_list = []
-                    encoder_word_pos_list.append(norm_img[1])
-                    gsrm_word_pos_list.append(norm_img[2])
-                    gsrm_slf_attn_bias1_list.append(norm_img[3])
-                    gsrm_slf_attn_bias2_list.append(norm_img[4])
-                    norm_img_batch.append(norm_img[0])
+        # for beg_img_no in range(0, img_num, batch_num):
+        with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+            index = 0
+            for beg_img_no in range(0, img_num, batch_num):
+                end_img_no = min(img_num, beg_img_no + batch_num)
+                norm_img_batch = []
+                max_wh_ratio = 0
+                for ino in range(beg_img_no, end_img_no):
+                    # h, w = img_list[ino].shape[0:2]
+                    h, w = img_list[indices[ino]].shape[0:2]
+                    wh_ratio = w * 1.0 / h
+                    max_wh_ratio = max(max_wh_ratio, wh_ratio)
+                for ino in range(beg_img_no, end_img_no):
+                    if self.rec_algorithm == "SAR":
+                        norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+                            img_list[indices[ino]], self.rec_image_shape)
+                        norm_img = norm_img[np.newaxis, :]
+                        valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                        valid_ratios = []
+                        valid_ratios.append(valid_ratio)
+                        norm_img_batch.append(norm_img)
+
+                    elif self.rec_algorithm == "SVTR":
+                        norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+                                                             self.rec_image_shape)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                    elif self.rec_algorithm == "SRN":
+                        norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                          self.rec_image_shape, 8,
+                                                          self.max_text_length)
+                        encoder_word_pos_list = []
+                        gsrm_word_pos_list = []
+                        gsrm_slf_attn_bias1_list = []
+                        gsrm_slf_attn_bias2_list = []
+                        encoder_word_pos_list.append(norm_img[1])
+                        gsrm_word_pos_list.append(norm_img[2])
+                        gsrm_slf_attn_bias1_list.append(norm_img[3])
+                        gsrm_slf_attn_bias2_list.append(norm_img[4])
+                        norm_img_batch.append(norm_img[0])
+                    elif self.rec_algorithm == "CAN":
+                        norm_img = self.norm_img_can(img_list[indices[ino]],
+                                                     max_wh_ratio)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                        norm_image_mask = np.ones(norm_img.shape, dtype='float32')
+                        word_label = np.ones([1, 36], dtype='int64')
+                        norm_img_mask_batch = []
+                        word_label_list = []
+                        norm_img_mask_batch.append(norm_image_mask)
+                        word_label_list.append(word_label)
+                    else:
+                        norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                        max_wh_ratio)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                norm_img_batch = np.concatenate(norm_img_batch)
+                norm_img_batch = norm_img_batch.copy()
+
+                if self.rec_algorithm == "SRN":
+                    starttime = time.time()
+                    encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                    gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                    gsrm_slf_attn_bias1_list = np.concatenate(
+                        gsrm_slf_attn_bias1_list)
+                    gsrm_slf_attn_bias2_list = np.concatenate(
+                        gsrm_slf_attn_bias2_list)
+
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
+                        gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
+                        gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
+                        gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
+
+                        inp = inp.to(self.device)
+                        encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
+                        gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
+                        gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
+                        gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
+
+                        backbone_out = self.net.backbone(inp) # backbone_feat
+                        prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
+                    # preds = {"predict": prob_out[2]}
+                    preds = {"predict": prob_out["predict"]}
+
+                elif self.rec_algorithm == "SAR":
+                    starttime = time.time()
+                    # valid_ratios = np.concatenate(valid_ratios)
+                    # inputs = [
+                    #     norm_img_batch,
+                    #     valid_ratios,
+                    # ]
+
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        inp = inp.to(self.device)
+                        preds = self.net(inp)
+
                 elif self.rec_algorithm == "CAN":
-                    norm_img = self.norm_img_can(img_list[indices[ino]],
-                                                 max_wh_ratio)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-                    norm_image_mask = np.ones(norm_img.shape, dtype='float32')
-                    word_label = np.ones([1, 36], dtype='int64')
-                    norm_img_mask_batch = []
-                    word_label_list = []
-                    norm_img_mask_batch.append(norm_image_mask)
-                    word_label_list.append(word_label)
-                else:
-                    norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                    max_wh_ratio)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-            norm_img_batch = np.concatenate(norm_img_batch)
-            norm_img_batch = norm_img_batch.copy()
-
-            if self.rec_algorithm == "SRN":
-                starttime = time.time()
-                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
-                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
-                gsrm_slf_attn_bias1_list = np.concatenate(
-                    gsrm_slf_attn_bias1_list)
-                gsrm_slf_attn_bias2_list = np.concatenate(
-                    gsrm_slf_attn_bias2_list)
-
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
-                    gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
-                    gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
-                    gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
-
-                    inp = inp.to(self.device)
-                    encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
-                    gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
-                    gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
-                    gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
-
-                    backbone_out = self.net.backbone(inp) # backbone_feat
-                    prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
-                # preds = {"predict": prob_out[2]}
-                preds = {"predict": prob_out["predict"]}
-
-            elif self.rec_algorithm == "SAR":
-                starttime = time.time()
-                # valid_ratios = np.concatenate(valid_ratios)
-                # inputs = [
-                #     norm_img_batch,
-                #     valid_ratios,
-                # ]
-
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    inp = inp.to(self.device)
-                    preds = self.net(inp)
-
-            elif self.rec_algorithm == "CAN":
-                starttime = time.time()
-                norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
-                word_label_list = np.concatenate(word_label_list)
-                inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
-
-                inp = [torch.from_numpy(e_i) for e_i in inputs]
-                inp = [e_i.to(self.device) for e_i in inp]
-                with torch.no_grad():
-                    outputs = self.net(inp)
-                    outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
-
-                preds = outputs
+                    starttime = time.time()
+                    norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
+                    word_label_list = np.concatenate(word_label_list)
+                    inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
 
-            else:
-                starttime = time.time()
+                    inp = [torch.from_numpy(e_i) for e_i in inputs]
+                    inp = [e_i.to(self.device) for e_i in inp]
+                    with torch.no_grad():
+                        outputs = self.net(inp)
+                        outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
 
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    inp = inp.to(self.device)
-                    prob_out = self.net(inp)
+                    preds = outputs
 
-                if isinstance(prob_out, list):
-                    preds = [v.cpu().numpy() for v in prob_out]
                 else:
-                    preds = prob_out.cpu().numpy()
+                    starttime = time.time()
+
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        inp = inp.to(self.device)
+                        prob_out = self.net(inp)
+
+                    if isinstance(prob_out, list):
+                        preds = [v.cpu().numpy() for v in prob_out]
+                    else:
+                        preds = prob_out.cpu().numpy()
+
+                rec_result = self.postprocess_op(preds)
+                for rno in range(len(rec_result)):
+                    rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+                elapse += time.time() - starttime
+
+                # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
+                current_batch_size = min(batch_num, img_num - index * batch_num)
+                index += 1
+                pbar.update(current_batch_size)
 
-            rec_result = self.postprocess_op(preds)
-            for rno in range(len(rec_result)):
-                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
-            elapse += time.time() - starttime
         return rec_res, elapse

+ 1 - 1
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -9,7 +9,7 @@ from magic_pdf.libs.config_reader import get_device
 
 
 class RapidTableModel(object):
-    def __init__(self, ocr_engine, table_sub_model_name):
+    def __init__(self, ocr_engine, table_sub_model_name='slanet_plus'):
         sub_model_list = [model.value for model in ModelType]
         if table_sub_model_name is None:
             input_args = RapidTableInput()

+ 16 - 14
magic_pdf/pdf_parse_union_core_v2.py

@@ -12,6 +12,7 @@ import fitz
 import torch
 import numpy as np
 from loguru import logger
+from tqdm import tqdm
 
 from magic_pdf.config.enums import SupportedPdfParseMethod
 from magic_pdf.config.ocr_content_type import BlockType, ContentType
@@ -932,17 +933,18 @@ def pdf_parse_union(
         logger.warning('end_page_id is out of range, use pdf_docs length')
         end_page_id = len(dataset) - 1
 
-    """初始化启动时间"""
-    start_time = time.time()
+    # """初始化启动时间"""
+    # start_time = time.time()
 
-    for page_id, page in enumerate(dataset):
-        """debug时输出每页解析的耗时."""
-        if debug_mode:
-            time_now = time.time()
-            logger.info(
-                f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
-            )
-            start_time = time_now
+    # for page_id, page in enumerate(dataset):
+    for page_id, page in tqdm(enumerate(dataset), total=len(dataset), desc="Processing pages"):
+        # """debug时输出每页解析的耗时."""
+        # if debug_mode:
+            # time_now = time.time()
+            # logger.info(
+            #     f'page_id: {page_id}, last_page_cost_time: {round(time.time() - start_time, 2)}'
+            # )
+            # start_time = time_now
 
         """解析pdf中的每一页"""
         if start_page_id <= page_id <= end_page_id:
@@ -987,8 +989,8 @@ def pdf_parse_union(
             det_db_box_thresh=0.3,
             lang=lang
         )
-        rec_start = time.time()
-        ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+        # rec_start = time.time()
+        ocr_res_list = ocr_model.ocr(img_crop_list, det=False, tqdm_enable=True)[0]
         # Verify we have matching counts
         assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
         # Process OCR results for this language
@@ -996,8 +998,8 @@ def pdf_parse_union(
             ocr_text, ocr_score = ocr_res_list[index]
             span['content'] = ocr_text
             span['score'] = float(round(ocr_score, 2))
-        rec_time = time.time() - rec_start
-        logger.info(f'ocr-dynamic-rec time: {round(rec_time, 2)}, total images processed: {len(img_crop_list)}')
+        # rec_time = time.time() - rec_start
+        # logger.info(f'ocr-dynamic-rec time: {round(rec_time, 2)}, total images processed: {len(img_crop_list)}')
 
 
     """分段"""

+ 1 - 0
requirements.txt

@@ -11,4 +11,5 @@ torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
 torchvision
 transformers>=4.49.0,<5.0.0
 pdfminer.six==20231228
+tqdm>=4.67.1
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.