Forráskód Böngészése

feat(ocr): implement separate detection and recognition processes

- Split OCR process into detection and recognition stages
- Update batch analysis and document analysis pipelines
- Modify OCR result formatting and handling
- Remove unused imports and optimize code structure
myhloli 7 hónapja
szülő
commit
a330651d64

+ 41 - 17
magic_pdf/model/batch_analyze.py

@@ -8,7 +8,7 @@ from magic_pdf.config.constants import MODEL_NAME
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 1
@@ -85,8 +85,8 @@ class BatchAnalyze:
         # 清理显存
         clean_vram(self.model.device, vram_threshold=8)
 
-        ocr_time = 0
-        ocr_count = 0
+        det_time = 0
+        det_count = 0
         table_time = 0
         table_count = 0
         # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
@@ -100,7 +100,7 @@ class BatchAnalyze:
                 get_res_list_from_layout_res(layout_res)
             )
             # ocr识别
-            ocr_start = time.time()
+            det_start = time.time()
             # Process each area that requires OCR processing
             for res in ocr_res_list:
                 new_image, useful_list = crop_img(
@@ -113,21 +113,21 @@ class BatchAnalyze:
                 # OCR recognition
                 new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
 
-                if ocr_enable:
-                    ocr_res = self.model.ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res
-                    )[0]
-                else:
-                    ocr_res = self.model.ocr_model.ocr(
-                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
-                    )[0]
+                # if ocr_enable:
+                #     ocr_res = self.model.ocr_model.ocr(
+                #         new_image, mfd_res=adjusted_mfdetrec_res
+                #     )[0]
+                # else:
+                ocr_res = self.model.ocr_model.ocr(
+                    new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                )[0]
 
                 # Integration results
                 if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image)
                     layout_res.extend(ocr_result_list)
-            ocr_time += time.time() - ocr_start
-            ocr_count += len(ocr_res_list)
+            det_time += time.time() - det_start
+            det_count += len(ocr_res_list)
 
             # 表格识别 table recognition
             if self.model.apply_table:
@@ -172,9 +172,33 @@ class BatchAnalyze:
                 table_time += time.time() - table_start
                 table_count += len(table_res_list)
 
-        if self.model.apply_ocr:
-            logger.info(f'det or det time costs: {round(ocr_time, 2)}, image num: {ocr_count}')
+
+        logger.info(f'ocr-det time: {round(det_time, 2)}, image num: {det_count}')
         if self.model.apply_table:
             logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
 
+        need_ocr_list = []
+        img_crop_list = []
+        for layout_res in images_layout_res:
+            for layout_res_item in layout_res:
+                if layout_res_item['category_id'] in [15]:
+                    if 'np_img' in layout_res_item:
+                        need_ocr_list.append(layout_res_item)
+                        img_crop_list.append(layout_res_item['np_img'])
+                        layout_res_item.pop('np_img')
+
+        rec_time = 0
+        rec_start = time.time()
+        if len(img_crop_list) > 0:
+            ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
+            assert len(ocr_res_list)==len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
+            for index, layout_res_item in enumerate(need_ocr_list):
+                ocr_text, ocr_score = ocr_res_list[index]
+                layout_res_item['text'] = ocr_text
+                layout_res_item['score'] = float(round(ocr_score, 2))
+        rec_time += time.time() - rec_start
+        logger.info(f'ocr-rec time: {round(rec_time, 2)}, image num: {len(img_crop_list)}')
+
+
+
         return images_layout_res

+ 1 - 1
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -141,7 +141,7 @@ def doc_analyze(
         else len(dataset) - 1
     )
 
-    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
+    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 200))
     images = []
     page_wh_list = []
     for index in range(len(dataset)):

+ 24 - 7
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/ocr_utils.py

@@ -1,4 +1,6 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import copy
+
 import cv2
 import numpy as np
 from magic_pdf.pre_proc.ocr_dict_merge import merge_spans_to_line
@@ -259,9 +261,10 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
     return adjusted_mfdetrec_res
 
 
-def get_ocr_result_list(ocr_res, useful_list):
+def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
     ocr_result_list = []
+    ori_im = new_image.copy()
     for box_ocr_res in ocr_res:
 
         if len(box_ocr_res) == 2:
@@ -273,6 +276,11 @@ def get_ocr_result_list(ocr_res, useful_list):
         else:
             p1, p2, p3, p4 = box_ocr_res
             text, score = "", 1
+
+            if ocr_enable:
+                tmp_box = copy.deepcopy(np.array([p1, p2, p3, p4]).astype('float32'))
+                img_crop = get_rotate_crop_image(ori_im, tmp_box)
+
         # average_angle_degrees = calculate_angle_degrees(box_ocr_res[0])
         # if average_angle_degrees > 0.5:
         poly = [p1, p2, p3, p4]
@@ -295,12 +303,21 @@ def get_ocr_result_list(ocr_res, useful_list):
         p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
         p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
 
-        ocr_result_list.append({
-            'category_id': 15,
-            'poly': p1 + p2 + p3 + p4,
-            'score': float(round(score, 2)),
-            'text': text,
-        })
+        if ocr_enable:
+            ocr_result_list.append({
+                'category_id': 15,
+                'poly': p1 + p2 + p3 + p4,
+                'score': float(round(score, 2)),
+                'text': text,
+                'np_img': img_crop,
+            })
+        else:
+            ocr_result_list.append({
+                'category_id': 15,
+                'poly': p1 + p2 + p3 + p4,
+                'score': float(round(score, 2)),
+                'text': text,
+            })
 
     return ocr_result_list
 

+ 0 - 3
magic_pdf/pdf_parse_union_core_v2.py

@@ -21,12 +21,9 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir, get_l
 from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
-from magic_pdf.libs.performance_stats import measure_time, PerformanceStats
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.post_proc.llm_aided import llm_aided_formula, llm_aided_text, llm_aided_title
 
-from concurrent.futures import ThreadPoolExecutor
-
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.post_proc.para_split_v3 import para_split
 from magic_pdf.pre_proc.construct_page_dict import ocr_construct_page_component_v2