浏览代码

feat(ocr): add area ratio calculation for OCR results and enhance get_coords_and_area function

myhloli 5 月之前
父节点
当前提交
a2b848136b
共有 2 个文件被更改,包括 18 次插入5 次删除
  1. 14 1
      magic_pdf/model/batch_analyze.py
  2. 4 4
      magic_pdf/model/sub_modules/model_utils.py

+ 14 - 1
magic_pdf/model/batch_analyze.py

@@ -6,7 +6,7 @@ from tqdm import tqdm
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.model.sub_modules.model_utils import (
 from magic_pdf.model.sub_modules.model_utils import (
-    clean_vram, crop_img, get_res_list_from_layout_res)
+    clean_vram, crop_img, get_res_list_from_layout_res, get_coords_and_area)
 from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
 from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
     get_adjusted_mfdetrec_res, get_ocr_result_list)
 
 
@@ -148,6 +148,19 @@ class BatchAnalyze:
                 # Integration results
                 # Integration results
                 if ocr_res:
                 if ocr_res:
                     ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
                     ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_res_list_dict['ocr_enable'], new_image, _lang)
+
+                    if res["category_id"] == 3:
+                        # ocr_result_list中所有bbox的面积之和
+                        ocr_res_area = sum(get_coords_and_area(ocr_res_item)[4] for ocr_res_item in ocr_result_list if 'poly' in ocr_res_item)
+                        # 求ocr_res_area和res的面积的比值
+                        res_area = get_coords_and_area(res)[4]
+                        if res_area > 0:
+                            ratio = ocr_res_area / res_area
+                            if ratio > 0.45:
+                                res["category_id"] = 1
+                            else:
+                                continue
+
                     ocr_res_list_dict['layout_res'].extend(ocr_result_list)
                     ocr_res_list_dict['layout_res'].extend(ocr_result_list)
 
 
             # det_count += len(ocr_res_list_dict['ocr_res_list'])
             # det_count += len(ocr_res_list_dict['ocr_res_list'])

+ 4 - 4
magic_pdf/model/sub_modules/model_utils.py

@@ -31,10 +31,10 @@ def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
     return return_image, return_list
     return return_image, return_list
 
 
 
 
-def get_coords_and_area(table):
+def get_coords_and_area(block_with_poly):
     """Extract coordinates and area from a table."""
     """Extract coordinates and area from a table."""
-    xmin, ymin = int(table['poly'][0]), int(table['poly'][1])
-    xmax, ymax = int(table['poly'][4]), int(table['poly'][5])
+    xmin, ymin = int(block_with_poly['poly'][0]), int(block_with_poly['poly'][1])
+    xmax, ymax = int(block_with_poly['poly'][4]), int(block_with_poly['poly'][5])
     area = (xmax - xmin) * (ymax - ymin)
     area = (xmax - xmin) * (ymax - ymin)
     return xmin, ymin, xmax, ymax, area
     return xmin, ymin, xmax, ymax, area
 
 
@@ -243,7 +243,7 @@ def get_res_list_from_layout_res(layout_res, iou_threshold=0.7, overlap_threshol
                 "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                 "bbox": [int(res['poly'][0]), int(res['poly'][1]),
                          int(res['poly'][4]), int(res['poly'][5])],
                          int(res['poly'][4]), int(res['poly'][5])],
             })
             })
-        elif category_id in [0, 2, 4, 6, 7]:  # OCR regions
+        elif category_id in [0, 2, 4, 6, 7, 3]:  # OCR regions
             ocr_res_list.append(res)
             ocr_res_list.append(res)
         elif category_id == 5:  # Table regions
         elif category_id == 5:  # Table regions
             table_res_list.append(res)
             table_res_list.append(res)