Browse Source

Merge pull request #2536 from seedclaimer/dev

支持batch-ocr-det,速度约提升3倍(200页pdf在3090上)
Xiaomeng Zhao 5 months ago
parent
commit
763fbc6097

+ 150 - 38
magic_pdf/model/batch_analyze.py

@@ -2,6 +2,8 @@ import time
 import cv2
 from loguru import logger
 from tqdm import tqdm
+from collections import defaultdict
+import numpy as np
 
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
@@ -16,27 +18,28 @@ MFR_BASE_BATCH_SIZE = 16
 
 
 class BatchAnalyze:
-    def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable):
+    def __init__(self, model_manager, batch_ratio: int, show_log, layout_model, formula_enable, table_enable, enable_ocr_det_batch=True):
         self.model_manager = model_manager
         self.batch_ratio = batch_ratio
         self.show_log = show_log
         self.layout_model = layout_model
         self.formula_enable = formula_enable
         self.table_enable = table_enable
+        self.enable_ocr_det_batch = enable_ocr_det_batch
 
     def __call__(self, images_with_extra_info: list) -> list:
         if len(images_with_extra_info) == 0:
             return []
-    
+
         images_layout_res = []
         layout_start_time = time.time()
         self.model = self.model_manager.get_model(
             ocr=True,
             show_log=self.show_log,
-            lang = None,
-            layout_model = self.layout_model,
-            formula_enable = self.formula_enable,
-            table_enable = self.table_enable,
+            lang=None,
+            layout_model=self.layout_model,
+            formula_enable=self.formula_enable,
+            table_enable=self.table_enable,
         )
 
         images = [image for image, _, _ in images_with_extra_info]
@@ -101,43 +104,152 @@ class BatchAnalyze:
                 get_res_list_from_layout_res(layout_res)
             )
 
-            ocr_res_list_all_page.append({'ocr_res_list':ocr_res_list,
-                                          'lang':_lang,
-                                          'ocr_enable':ocr_enable,
-                                          'np_array_img':np_array_img,
-                                          'single_page_mfdetrec_res':single_page_mfdetrec_res,
-                                          'layout_res':layout_res,
-                                          })
+            ocr_res_list_all_page.append({
+                'ocr_res_list': ocr_res_list,
+                'lang': _lang,
+                'ocr_enable': ocr_enable,
+                'np_array_img': np_array_img,
+                'single_page_mfdetrec_res': single_page_mfdetrec_res,
+                'layout_res': layout_res,
+            })
 
             for table_res in table_res_list:
                 table_img, _ = crop_img(table_res, np_array_img)
-                table_res_list_all_page.append({'table_res':table_res,
-                                                'lang':_lang,
-                                                'table_img':table_img,
-                                              })
-
-        # 文本框检测
-        det_start = time.time()
-        det_count = 0
-        # for ocr_res_list_dict in ocr_res_list_all_page:
-        for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
-            # Process each area that requires OCR processing
-            _lang = ocr_res_list_dict['lang']
-            # Get OCR results for this language's images
-            atom_model_manager = AtomModelSingleton()
-            ocr_model = atom_model_manager.get_atom_model(
-                atom_model_name='ocr',
-                ocr_show_log=False,
-                det_db_box_thresh=0.3,
-                lang=_lang
-            )
-            for res in ocr_res_list_dict['ocr_res_list']:
-                new_image, useful_list = crop_img(
-                    res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
+                table_res_list_all_page.append({
+                    'table_res': table_res,
+                    'lang': _lang,
+                    'table_img': table_img,
+                })
+
+        # OCR检测处理
+        if self.enable_ocr_det_batch:
+            # 批处理模式 - 按语言和分辨率分组
+            # 收集所有需要OCR检测的裁剪图像
+            all_cropped_images_info = []
+
+            for ocr_res_list_dict in ocr_res_list_all_page:
+                _lang = ocr_res_list_dict['lang']
+
+                for res in ocr_res_list_dict['ocr_res_list']:
+                    new_image, useful_list = crop_img(
+                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
+                    )
+                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
+                    )
+
+                    # BGR转换
+                    new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
+
+                    all_cropped_images_info.append((
+                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
+                    ))
+
+            # 按语言分组
+            lang_groups = defaultdict(list)
+            for crop_info in all_cropped_images_info:
+                lang = crop_info[5]
+                lang_groups[lang].append(crop_info)
+
+            # 对每种语言按分辨率分组并批处理
+            for lang, lang_crop_list in lang_groups.items():
+                if not lang_crop_list:
+                    continue
+
+                # logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
+
+                # 获取OCR模型
+                atom_model_manager = AtomModelSingleton()
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang=lang
                 )
-                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-                    ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
+
+                # 按分辨率分组并同时完成padding
+                resolution_groups = defaultdict(list)
+                for crop_info in lang_crop_list:
+                    cropped_img = crop_info[0]
+                    h, w = cropped_img.shape[:2]
+                    # 使用更大的分组容差,减少分组数量
+                    # 将尺寸标准化到32的倍数
+                    normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
+                    normalized_w = ((w + 32) // 32) * 32
+                    group_key = (normalized_h, normalized_w)
+                    resolution_groups[group_key].append(crop_info)
+
+                # 对每个分辨率组进行批处理
+                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
+                    raw_images = [crop_info[0] for crop_info in group_crops]
+
+                    # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
+                    max_h = max(img.shape[0] for img in raw_images)
+                    max_w = max(img.shape[1] for img in raw_images)
+                    target_h = ((max_h + 32 - 1) // 32) * 32
+                    target_w = ((max_w + 32 - 1) // 32) * 32
+
+                    # 对所有图像进行padding到统一尺寸
+                    batch_images = []
+                    for img in raw_images:
+                        h, w = img.shape[:2]
+                        # 创建目标尺寸的白色背景
+                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                        # 将原图像粘贴到左上角
+                        padded_img[:h, :w] = img
+                        batch_images.append(padded_img)
+
+                    # 批处理检测
+                    batch_size = min(len(batch_images), self.batch_ratio * 16)  # 增加批处理大小
+                    # logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
+                    batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
+
+                    # 处理批处理结果
+                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
+                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
+
+                        if dt_boxes is not None:
+                            # 构造OCR结果格式 - 每个box应该是4个点的列表
+                            ocr_res = [box.tolist() for box in dt_boxes]
+
+                            if ocr_res:
+                                ocr_result_list = get_ocr_result_list(
+                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
+                                )
+
+                                if res["category_id"] == 3:
+                                    # ocr_result_list中所有bbox的面积之和
+                                    ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
+                                    # 求ocr_res_area和res的面积的比值
+                                    res_area = get_coords_and_area(res)[4]
+                                    if res_area > 0:
+                                        ratio = ocr_res_area / res_area
+                                        if ratio > 0.25:
+                                            res["category_id"] = 1
+                                        else:
+                                            continue
+
+                                ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+        else:
+            # 原始单张处理模式
+            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
+                # Process each area that requires OCR processing
+                _lang = ocr_res_list_dict['lang']
+                # Get OCR results for this language's images
+                atom_model_manager = AtomModelSingleton()
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang=_lang
                 )
+                for res in ocr_res_list_dict['ocr_res_list']:
+                    new_image, useful_list = crop_img(
+                        res, ocr_res_list_dict['np_array_img'], crop_paste_x=50, crop_paste_y=50
+                    )
+                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
+                    )
 
                 # OCR-det
                 new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

+ 122 - 0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_det.py

@@ -117,6 +117,128 @@ class TextDetector(BaseOCRV20):
         self.net.eval()
         self.net.to(self.device)
 
+    def _batch_process_same_size(self, img_list):
+        """
+            对相同尺寸的图像进行批处理
+
+            Args:
+                img_list: 相同尺寸的图像列表
+
+            Returns:
+                batch_results: 批处理结果列表
+                total_elapse: 总耗时
+            """
+        starttime = time.time()
+
+        # 预处理所有图像
+        batch_data = []
+        batch_shapes = []
+        ori_imgs = []
+
+        for img in img_list:
+            ori_im = img.copy()
+            ori_imgs.append(ori_im)
+
+            data = {'image': img}
+            data = transform(data, self.preprocess_op)
+            if data is None:
+                # 如果预处理失败,返回空结果
+                return [(None, 0) for _ in img_list], 0
+
+            img_processed, shape_list = data
+            batch_data.append(img_processed)
+            batch_shapes.append(shape_list)
+
+        # 堆叠成批处理张量
+        try:
+            batch_tensor = np.stack(batch_data, axis=0)
+            batch_shapes = np.stack(batch_shapes, axis=0)
+        except Exception as e:
+            # 如果堆叠失败,回退到逐个处理
+            batch_results = []
+            for img in img_list:
+                dt_boxes, elapse = self.__call__(img)
+                batch_results.append((dt_boxes, elapse))
+            return batch_results, time.time() - starttime
+
+        # 批处理推理
+        with torch.no_grad():
+            inp = torch.from_numpy(batch_tensor)
+            inp = inp.to(self.device)
+            outputs = self.net(inp)
+
+        # 处理输出
+        preds = {}
+        if self.det_algorithm == "EAST":
+            preds['f_geo'] = outputs['f_geo'].cpu().numpy()
+            preds['f_score'] = outputs['f_score'].cpu().numpy()
+        elif self.det_algorithm == 'SAST':
+            preds['f_border'] = outputs['f_border'].cpu().numpy()
+            preds['f_score'] = outputs['f_score'].cpu().numpy()
+            preds['f_tco'] = outputs['f_tco'].cpu().numpy()
+            preds['f_tvo'] = outputs['f_tvo'].cpu().numpy()
+        elif self.det_algorithm in ['DB', 'PSE', 'DB++']:
+            preds['maps'] = outputs['maps'].cpu().numpy()
+        elif self.det_algorithm == 'FCE':
+            for i, (k, output) in enumerate(outputs.items()):
+                preds['level_{}'.format(i)] = output.cpu().numpy()
+        else:
+            raise NotImplementedError
+
+        # 后处理每个图像的结果
+        batch_results = []
+        total_elapse = time.time() - starttime
+
+        for i in range(len(img_list)):
+            # 提取单个图像的预测结果
+            single_preds = {}
+            for key, value in preds.items():
+                if isinstance(value, np.ndarray):
+                    single_preds[key] = value[i:i + 1]  # 保持批次维度
+                else:
+                    single_preds[key] = value
+
+            # 后处理
+            post_result = self.postprocess_op(single_preds, batch_shapes[i:i + 1])
+            dt_boxes = post_result[0]['points']
+
+            # 过滤和裁剪检测框
+            if (self.det_algorithm == "SAST" and
+                self.det_sast_polygon) or (self.det_algorithm in ["PSE", "FCE"] and
+                                           self.postprocess_op.box_type == 'poly'):
+                dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[i].shape)
+            else:
+                dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[i].shape)
+
+            batch_results.append((dt_boxes, total_elapse / len(img_list)))
+
+        return batch_results, total_elapse
+
+    def batch_predict(self, img_list, max_batch_size=8):
+        """
+        批处理预测方法,支持多张图像同时检测
+
+        Args:
+            img_list: 图像列表
+            max_batch_size: 最大批处理大小
+
+        Returns:
+            batch_results: 批处理结果列表,每个元素为(dt_boxes, elapse)
+        """
+        if not img_list:
+            return []
+
+        batch_results = []
+
+        # 分批处理
+        for i in range(0, len(img_list), max_batch_size):
+            batch_imgs = img_list[i:i + max_batch_size]
+            # assert尺寸一致
+            batch_dt_boxes, batch_elapse = self._batch_process_same_size(batch_imgs)
+            batch_results.extend(batch_dt_boxes)
+
+        return batch_results
+
     def order_points_clockwise(self, pts):
         """
         reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py