فهرست منبع

feat(ocr): implement dynamic OCR processing for text spans with low contrast

- Comment out OCR model initialization and execution for low-contrast spans
- Add batch OCR processing for collected image spans
- Adjust contrast threshold for OCR processing
- Remove unnecessary OCR processing for high-contrast spans
- Implement more efficient OCR workflow by processing multiple spans at once
myhloli 7 ماه پیش
والد
کامیت
a024c30fc4
1 فایلهای تغییر یافته به همراه 66 افزوده شده و 20 حذف شده
  1. 66 20
      magic_pdf/pdf_parse_union_core_v2.py

+ 66 - 20
magic_pdf/pdf_parse_union_core_v2.py

@@ -193,7 +193,7 @@ def calculate_contrast(img, img_mode) -> float:
     std_dev = np.std(gray_img)
     # 对比度定义为标准差除以平均值(加上小常数避免除零错误)
     contrast = std_dev / (mean_value + 1e-6)
-    # logger.info(f"contrast: {contrast}")
+    # logger.debug(f"contrast: {contrast}")
     return round(contrast, 2)
 
 # @measure_time
@@ -286,33 +286,39 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
     if len(need_ocr_spans) > 0:
 
         # 初始化ocr模型
-        atom_model_manager = AtomModelSingleton()
-        ocr_model = atom_model_manager.get_atom_model(
-            atom_model_name='ocr',
-            ocr_show_log=False,
-            det_db_box_thresh=0.3,
-            lang=lang
-        )
+        # atom_model_manager = AtomModelSingleton()
+        # ocr_model = atom_model_manager.get_atom_model(
+        #     atom_model_name='ocr',
+        #     ocr_show_log=False,
+        #     det_db_box_thresh=0.3,
+        #     lang=lang
+        # )
 
         for span in need_ocr_spans:
             # 对span的bbox截图再ocr
             span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')
 
             # 计算span的对比度,低于0.20的span不进行ocr
-            if calculate_contrast(span_img, img_mode='bgr') <= 0.20:
+            if calculate_contrast(span_img, img_mode='bgr') <= 0.17:
                 spans.remove(span)
                 continue
+                # pass
+
+            span['content'] = ''
+            span['score'] = 1
+            span['np_img'] = span_img
 
-            ocr_res = ocr_model.ocr(span_img, det=False)
-            if ocr_res and len(ocr_res) > 0:
-                if len(ocr_res[0]) > 0:
-                    ocr_text, ocr_score = ocr_res[0][0]
-                    # logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}")
-                    if ocr_score > 0.5 and len(ocr_text) > 0:
-                        span['content'] = ocr_text
-                        span['score'] = float(round(ocr_score, 2))
-                    else:
-                        spans.remove(span)
+
+            # ocr_res = ocr_model.ocr(span_img, det=False)
+            # if ocr_res and len(ocr_res) > 0:
+            #     if len(ocr_res[0]) > 0:
+            #         ocr_text, ocr_score = ocr_res[0][0]
+            #         # logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}")
+            #         if ocr_score > 0.5 and len(ocr_text) > 0:
+            #             span['content'] = ocr_text
+            #             span['score'] = float(round(ocr_score, 2))
+            #         else:
+            #             spans.remove(span)
 
     return spans
 
@@ -952,7 +958,47 @@ def pdf_parse_union(
             )
         pdf_info_dict[f'page_{page_id}'] = page_info
 
-    # PerformanceStats.print_stats()
+    need_ocr_list = []
+    img_crop_list = []
+    text_block_list = []
+    for pange_id, page_info in pdf_info_dict.items():
+        for block in page_info['preproc_blocks']:
+            if block['type'] in ['table', 'image']:
+                for sub_block in block['blocks']:
+                    if sub_block['type'] in ['image_caption', 'image_footnote', 'table_caption', 'table_footnote']:
+                        text_block_list.append(sub_block)
+            elif block['type'] in ['text', 'title']:
+                text_block_list.append(block)
+        for block in page_info['discarded_blocks']:
+            text_block_list.append(block)
+    for block in text_block_list:
+        for line in block['lines']:
+            for span in line['spans']:
+                if 'np_img' in span:
+                    need_ocr_list.append(span)
+                    img_crop_list.append(span['np_img'])
+                    span.pop('np_img')
+    if len(img_crop_list) > 0:
+        # Get OCR results for this language's images
+        atom_model_manager = AtomModelSingleton()
+        ocr_model = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.3,
+            lang=lang
+        )
+        rec_start = time.time()
+        ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
+        # Verify we have matching counts
+        assert len(ocr_res_list) == len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
+        # Process OCR results for this language
+        for index, span in enumerate(need_ocr_list):
+            ocr_text, ocr_score = ocr_res_list[index]
+            span['content'] = ocr_text
+            span['score'] = float(round(ocr_score, 2))
+        rec_time = time.time() - rec_start
+        logger.info(f'ocr-dynamic-rec time: {round(rec_time, 2)}, total images processed: {len(img_crop_list)}')
+
 
     """分段"""
     para_split(pdf_info_dict)