Browse Source

refactor(model): integrate AtomModelSingleton for OCR and improve OCR result handling

- Replace direct OCR model access with AtomModelSingleton for better model management
- Round OCR scores to 2 decimal places for consistency
- Improve error handling and logging in batch analysis
- Simplify OCR result processing in pdf_parse_union_core_v2.py
myhloli 7 tháng trước cách đây
mục cha
commit
59d6b195b0
2 tập tin đã thay đổi với 12 bổ sung6 xóa
  1. 11 5
      magic_pdf/model/batch_analyze.py
  2. 1 1
      magic_pdf/pdf_parse_union_core_v2.py

+ 11 - 5
magic_pdf/model/batch_analyze.py

@@ -5,7 +5,7 @@ import torch
 from loguru import logger
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
@@ -212,15 +212,21 @@ class BatchAnalyze:
             for lang, img_crop_list in img_crop_lists_by_lang.items():
                 if len(img_crop_list) > 0:
                     # Get OCR results for this language's images
-                    ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
-                    need_ocr_list = need_ocr_lists_by_lang[lang]
+                    atom_model_manager = AtomModelSingleton()
+                    ocr_model = atom_model_manager.get_atom_model(
+                        atom_model_name='ocr',
+                        ocr_show_log=False,
+                        det_db_box_thresh=0.3,
+                        lang=lang
+                    )
+                    ocr_res_list = ocr_model.ocr(img_crop_list, det=False)[0]
 
                     # Verify we have matching counts
                     assert len(ocr_res_list) == len(
-                        need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)} for lang: {lang}'
+                        need_ocr_lists_by_lang[lang]), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_lists_by_lang[lang])} for lang: {lang}'
 
                     # Process OCR results for this language
-                    for index, layout_res_item in enumerate(need_ocr_list):
+                    for index, layout_res_item in enumerate(need_ocr_lists_by_lang[lang]):
                         ocr_text, ocr_score = ocr_res_list[index]
                         layout_res_item['text'] = ocr_text
                         layout_res_item['score'] = float(round(ocr_score, 2))

+ 1 - 1
magic_pdf/pdf_parse_union_core_v2.py

@@ -309,7 +309,7 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
                     # logger.info(f"ocr_text: {ocr_text}, ocr_score: {ocr_score}")
                     if ocr_score > 0.5 and len(ocr_text) > 0:
                         span['content'] = ocr_text
-                        span['score'] = ocr_score
+                        span['score'] = float(round(ocr_score, 2))
                     else:
                         spans.remove(span)