Эх сурвалжийг харах

refactor: optimize OCR batch processing and enhance image cropping logic

myhloli 5 сар өмнө
parent
commit
73f8503514

+ 136 - 136
mineru/backend/pipeline/batch_analyze.py

@@ -92,144 +92,100 @@ class BatchAnalyze:
                                                 'table_img':table_img,
                                               })
 
-                # OCR检测处理
-                if self.enable_ocr_det_batch:
-                    # 批处理模式 - 按语言和分辨率分组
-                    # 收集所有需要OCR检测的裁剪图像
-                    all_cropped_images_info = []
-
-                    for ocr_res_list_dict in ocr_res_list_all_page:
-                        _lang = ocr_res_list_dict['lang']
-
-                        for res in ocr_res_list_dict['ocr_res_list']:
-                            new_image, useful_list = crop_img(
-                                res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
-                            )
-                            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-                                ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
-                            )
-
-                            # BGR转换
-                            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-
-                            all_cropped_images_info.append((
-                                new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
-                            ))
-
-                    # 按语言分组
-                    lang_groups = defaultdict(list)
-                    for crop_info in all_cropped_images_info:
-                        lang = crop_info[5]
-                        lang_groups[lang].append(crop_info)
-
-                    # 对每种语言按分辨率分组并批处理
-                    for lang, lang_crop_list in lang_groups.items():
-                        if not lang_crop_list:
-                            continue
-
-                        # logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
-
-                        # 获取OCR模型
-                        ocr_model = atom_model_manager.get_atom_model(
-                            atom_model_name='ocr',
-                            ocr_show_log=False,
-                            det_db_box_thresh=0.3,
-                            lang=lang
-                        )
+        # OCR检测处理
+        if self.enable_ocr_det_batch:
+            # 批处理模式 - 按语言和分辨率分组
+            # 收集所有需要OCR检测的裁剪图像
+            all_cropped_images_info = []
+
+            for ocr_res_list_dict in ocr_res_list_all_page:
+                _lang = ocr_res_list_dict['lang']
+
+                for res in ocr_res_list_dict['ocr_res_list']:
+                    new_image, useful_list = crop_img(
+                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
+                    )
+                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
+                    )
+
+                    # BGR转换
+                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+
+                    all_cropped_images_info.append((
+                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang
+                    ))
+
+            # 按语言分组
+            lang_groups = defaultdict(list)
+            for crop_info in all_cropped_images_info:
+                lang = crop_info[5]
+                lang_groups[lang].append(crop_info)
+
+            # 对每种语言按分辨率分组并批处理
+            for lang, lang_crop_list in lang_groups.items():
+                if not lang_crop_list:
+                    continue
+
+                logger.info(f"Processing OCR detection for language {lang} with {len(lang_crop_list)} images")
+
+                # 获取OCR模型
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang=lang
+                )
+
+                # 按分辨率分组并同时完成padding
+                resolution_groups = defaultdict(list)
+                for crop_info in lang_crop_list:
+                    cropped_img = crop_info[0]
+                    h, w = cropped_img.shape[:2]
+                    # 使用更大的分组容差,减少分组数量
+                    # 将尺寸标准化到32的倍数
+                    normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
+                    normalized_w = ((w + 32) // 32) * 32
+                    group_key = (normalized_h, normalized_w)
+                    resolution_groups[group_key].append(crop_info)
+
+                # 对每个分辨率组进行批处理
+                for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
+                    raw_images = [crop_info[0] for crop_info in group_crops]
+
+                    # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
+                    max_h = max(img.shape[0] for img in raw_images)
+                    max_w = max(img.shape[1] for img in raw_images)
+                    target_h = ((max_h + 32 - 1) // 32) * 32
+                    target_w = ((max_w + 32 - 1) // 32) * 32
+
+                    # 对所有图像进行padding到统一尺寸
+                    batch_images = []
+                    for img in raw_images:
+                        h, w = img.shape[:2]
+                        # 创建目标尺寸的白色背景
+                        padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
+                        # 将原图像粘贴到左上角
+                        padded_img[:h, :w] = img
+                        batch_images.append(padded_img)
+
+                    # 批处理检测
+                    batch_size = min(len(batch_images), self.batch_ratio * 16)  # 增加批处理大小
+                    logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
+                    batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
+
+                    # 处理批处理结果
+                    for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
+                        new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
+
+                        if dt_boxes is not None:
+                            # 构造OCR结果格式 - 每个box应该是4个点的列表
+                            ocr_res = [box.tolist() for box in dt_boxes]
 
-                        # 按分辨率分组并同时完成padding
-                        resolution_groups = defaultdict(list)
-                        for crop_info in lang_crop_list:
-                            cropped_img = crop_info[0]
-                            h, w = cropped_img.shape[:2]
-                            # 使用更大的分组容差,减少分组数量
-                            # 将尺寸标准化到32的倍数
-                            normalized_h = ((h + 32) // 32) * 32  # 向上取整到32的倍数
-                            normalized_w = ((w + 32) // 32) * 32
-                            group_key = (normalized_h, normalized_w)
-                            resolution_groups[group_key].append(crop_info)
-
-                        # 对每个分辨率组进行批处理
-                        for group_key, group_crops in tqdm(resolution_groups.items(), desc=f"OCR-det {lang}"):
-                            raw_images = [crop_info[0] for crop_info in group_crops]
-
-                            # 计算目标尺寸(组内最大尺寸,向上取整到32的倍数)
-                            max_h = max(img.shape[0] for img in raw_images)
-                            max_w = max(img.shape[1] for img in raw_images)
-                            target_h = ((max_h + 32 - 1) // 32) * 32
-                            target_w = ((max_w + 32 - 1) // 32) * 32
-
-                            # 对所有图像进行padding到统一尺寸
-                            batch_images = []
-                            for img in raw_images:
-                                h, w = img.shape[:2]
-                                # 创建目标尺寸的白色背景
-                                padded_img = np.ones((target_h, target_w, 3), dtype=np.uint8) * 255
-                                # 将原图像粘贴到左上角
-                                padded_img[:h, :w] = img
-                                batch_images.append(padded_img)
-
-                            # 批处理检测
-                            batch_size = min(len(batch_images), self.batch_ratio * 16)  # 增加批处理大小
-                            # logger.debug(f"OCR-det batch: {batch_size} images, target size: {target_h}x{target_w}")
-                            batch_results = ocr_model.text_detector.batch_predict(batch_images, batch_size)
-
-                            # 处理批处理结果
-                            for i, (crop_info, (dt_boxes, elapse)) in enumerate(zip(group_crops, batch_results)):
-                                new_image, useful_list, ocr_res_list_dict, res, adjusted_mfdetrec_res, _lang = crop_info
-
-                                if dt_boxes is not None:
-                                    # 构造OCR结果格式 - 每个box应该是4个点的列表
-                                    ocr_res = [box.tolist() for box in dt_boxes]
-
-                                    if ocr_res:
-                                        ocr_result_list = get_ocr_result_list(
-                                            ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
-                                        )
-
-                                        if res["category_id"] == 3:
-                                            # ocr_result_list中所有bbox的面积之和
-                                            ocr_res_area = sum(
-                                                get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
-                                            # 求ocr_res_area和res的面积的比值
-                                            res_area = get_coords_and_area(res)[4]
-                                            if res_area > 0:
-                                                ratio = ocr_res_area / res_area
-                                                if ratio > 0.25:
-                                                    res["category_id"] = 1
-                                                else:
-                                                    continue
-
-                                        ocr_res_list_dict['layout_res'].extend(ocr_result_list)
-                else:
-                    # 原始单张处理模式
-                    for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
-                        # Process each area that requires OCR processing
-                        _lang = ocr_res_list_dict['lang']
-                        # Get OCR results for this language's images
-                        ocr_model = atom_model_manager.get_atom_model(
-                            atom_model_name='ocr',
-                            ocr_show_log=False,
-                            det_db_box_thresh=0.3,
-                            lang=_lang
-                        )
-                        for res in ocr_res_list_dict['ocr_res_list']:
-                            new_image, useful_list = crop_img(
-                                res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
-                            )
-                            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
-                                ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
-                            )
-                            # OCR-det
-                            new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-                            ocr_res = ocr_model.ocr(
-                                new_image, mfd_res=adjusted_mfdetrec_res, rec=False
-                            )[0]
-
-                            # Integration results
                             if ocr_res:
-                                ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
-                                                                      new_image, _lang)
+                                ocr_result_list = get_ocr_result_list(
+                                    ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang
+                                )
 
                                 if res["category_id"] == 3:
                                     # ocr_result_list中所有bbox的面积之和
@@ -245,6 +201,50 @@ class BatchAnalyze:
                                             continue
 
                                 ocr_res_list_dict['layout_res'].extend(ocr_result_list)
+        else:
+            # 原始单张处理模式
+            for ocr_res_list_dict in tqdm(ocr_res_list_all_page, desc="OCR-det Predict"):
+                # Process each area that requires OCR processing
+                _lang = ocr_res_list_dict['lang']
+                # Get OCR results for this language's images
+                ocr_model = atom_model_manager.get_atom_model(
+                    atom_model_name='ocr',
+                    ocr_show_log=False,
+                    det_db_box_thresh=0.3,
+                    lang=_lang
+                )
+                for res in ocr_res_list_dict['ocr_res_list']:
+                    new_image, useful_list = crop_img(
+                        res, ocr_res_list_dict['pil_img'], crop_paste_x=50, crop_paste_y=50
+                    )
+                    adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                        ocr_res_list_dict['single_page_mfdetrec_res'], useful_list
+                    )
+                    # OCR-det
+                    new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+                    ocr_res = ocr_model.ocr(
+                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                    )[0]
+
+                    # Integration results
+                    if ocr_res:
+                        ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'],
+                                                              new_image, _lang)
+
+                        if res["category_id"] == 3:
+                            # ocr_result_list中所有bbox的面积之和
+                            ocr_res_area = sum(
+                                get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
+                            # 求ocr_res_area和res的面积的比值
+                            res_area = get_coords_and_area(res)[4]
+                            if res_area > 0:
+                                ratio = ocr_res_area / res_area
+                                if ratio > 0.25:
+                                    res["category_id"] = 1
+                                else:
+                                    continue
+
+                        ocr_res_list_dict['layout_res'].extend(ocr_result_list)
 
         # 表格识别 table recognition
         if self.table_enable: