소스 검색

feat(ocr): implement language-specific OCR processing

- Add support for multiple languages in OCR processing
- Create separate lists for each language to improve processing efficiency
- Update OCR model initialization to use PytorchPaddleOCR instead of ModifiedPaddleOCR
- Modify get_ocr_result_list function to include language information
- Improve logging for OCR detection and recognition
myhloli 7 달 전
부모
커밋
d7d85a2881

+ 48 - 17
magic_pdf/model/batch_analyze.py

@@ -124,7 +124,7 @@ class BatchAnalyze:
 
                 # Integration results
                 if ocr_res:
-                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image)
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, _lang)
                     layout_res.extend(ocr_result_list)
             det_time += time.time() - det_start
             det_count += len(ocr_res_list)
@@ -177,27 +177,58 @@ class BatchAnalyze:
         if self.model.apply_table:
             logger.info(f'table time: {round(table_time, 2)}, image num: {table_count}')
 
-        need_ocr_list = []
-        img_crop_list = []
+        # Create dictionaries to store items by language
+        need_ocr_lists_by_lang = {}  # Dict of lists for each language
+        img_crop_lists_by_lang = {}  # Dict of lists for each language
+
         for layout_res in images_layout_res:
             for layout_res_item in layout_res:
                 if layout_res_item['category_id'] in [15]:
-                    if 'np_img' in layout_res_item:
-                        need_ocr_list.append(layout_res_item)
-                        img_crop_list.append(layout_res_item['np_img'])
+                    if 'np_img' in layout_res_item and 'lang' in layout_res_item:
+                        lang = layout_res_item['lang']
+
+                        # Initialize lists for this language if not exist
+                        if lang not in need_ocr_lists_by_lang:
+                            need_ocr_lists_by_lang[lang] = []
+                            img_crop_lists_by_lang[lang] = []
+
+                        # Add to the appropriate language-specific lists
+                        need_ocr_lists_by_lang[lang].append(layout_res_item)
+                        img_crop_lists_by_lang[lang].append(layout_res_item['np_img'])
+
+                        # Remove the fields after adding to lists
                         layout_res_item.pop('np_img')
+                        layout_res_item.pop('lang')
+
+
+        if len(img_crop_lists_by_lang) > 0:
+
+            # Process OCR by language
+            rec_time = 0
+            rec_start = time.time()
+            total_processed = 0
+
+            # Process each language separately
+            for lang, img_crop_list in img_crop_lists_by_lang.items():
+                if len(img_crop_list) > 0:
+                    # Get OCR results for this language's images
+                    ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
+                    need_ocr_list = need_ocr_lists_by_lang[lang]
+
+                    # Verify we have matching counts
+                    assert len(ocr_res_list) == len(
+                        need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)} for lang: {lang}'
+
+                    # Process OCR results for this language
+                    for index, layout_res_item in enumerate(need_ocr_list):
+                        ocr_text, ocr_score = ocr_res_list[index]
+                        layout_res_item['text'] = ocr_text
+                        layout_res_item['score'] = float(round(ocr_score, 2))
+
+                    total_processed += len(img_crop_list)
 
-        rec_time = 0
-        rec_start = time.time()
-        if len(img_crop_list) > 0:
-            ocr_res_list = self.model.ocr_model.ocr(img_crop_list, det=False)[0]
-            assert len(ocr_res_list)==len(need_ocr_list), f'ocr_res_list: {len(ocr_res_list)}, need_ocr_list: {len(need_ocr_list)}'
-            for index, layout_res_item in enumerate(need_ocr_list):
-                ocr_text, ocr_score = ocr_res_list[index]
-                layout_res_item['text'] = ocr_text
-                layout_res_item['score'] = float(round(ocr_score, 2))
-        rec_time += time.time() - rec_start
-        logger.info(f'ocr-rec time: {round(rec_time, 2)}, image num: {len(img_crop_list)}')
+            rec_time += time.time() - rec_start
+            logger.info(f'ocr-rec time: {round(rec_time, 2)}, total images processed: {total_processed}')
 
 
 

+ 1 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -14,7 +14,7 @@ from magic_pdf.model.model_list import AtomicModel
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
-from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
 
 

+ 31 - 28
magic_pdf/model/sub_modules/model_init.py

@@ -7,32 +7,33 @@ from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv
 from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import DocLayoutYOLOModel
 from magic_pdf.model.sub_modules.mfd.yolov8.YOLOv8 import YOLOv8MFDModel
 from magic_pdf.model.sub_modules.mfr.unimernet.Unimernet import UnimernetModel
-
-try:
-    from magic_pdf_ascend_plugin.libs.license_verifier import (
-        LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
-        load_license)
-    from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
-    from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
-    license_key = load_license()
-    logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
-                f' License expired at {license_key["payload"]["date"]["end_date"]}')
-except Exception as e:
-    if isinstance(e, ImportError):
-        pass
-    elif isinstance(e, LicenseFormatError):
-        logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
-    elif isinstance(e, LicenseSignatureError):
-        logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
-    elif isinstance(e, LicenseExpiredError):
-        logger.error('Ascend Plugin: License has expired. Please renew your license.')
-    elif isinstance(e, FileNotFoundError):
-        logger.error('Ascend Plugin: Not found License file.')
-    else:
-        logger.error(f'Ascend Plugin: {e}')
-    from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
-    # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
-    from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
+from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
+# try:
+#     from magic_pdf_ascend_plugin.libs.license_verifier import (
+#         LicenseExpiredError, LicenseFormatError, LicenseSignatureError,
+#         load_license)
+#     from magic_pdf_ascend_plugin.model_plugin.ocr.paddleocr.ppocr_273_npu import ModifiedPaddleOCR
+#     from magic_pdf_ascend_plugin.model_plugin.table.rapidtable.rapid_table_npu import RapidTableModel
+#     license_key = load_license()
+#     logger.info(f'Using Ascend Plugin Success, License id is {license_key["payload"]["id"]},'
+#                 f' License expired at {license_key["payload"]["date"]["end_date"]}')
+# except Exception as e:
+#     if isinstance(e, ImportError):
+#         pass
+#     elif isinstance(e, LicenseFormatError):
+#         logger.error('Ascend Plugin: Invalid license format. Please check the license file.')
+#     elif isinstance(e, LicenseSignatureError):
+#         logger.error('Ascend Plugin: Invalid signature. The license may be tampered with.')
+#     elif isinstance(e, LicenseExpiredError):
+#         logger.error('Ascend Plugin: License has expired. Please renew your license.')
+#     elif isinstance(e, FileNotFoundError):
+#         logger.error('Ascend Plugin: Not found License file.')
+#     else:
+#         logger.error(f'Ascend Plugin: {e}')
+#     from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_273_mod import ModifiedPaddleOCR
+#     # from magic_pdf.model.sub_modules.ocr.paddleocr.ppocr_291_mod import ModifiedPaddleOCR
+#     from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
@@ -94,7 +95,8 @@ def ocr_model_init(show_log: bool = False,
                    det_db_unclip_ratio=1.8,
                    ):
     if lang is not None and lang != '':
-        model = ModifiedPaddleOCR(
+        # model = ModifiedPaddleOCR(
+        model = PytorchPaddleOCR(
             show_log=show_log,
             det_db_box_thresh=det_db_box_thresh,
             lang=lang,
@@ -102,7 +104,8 @@ def ocr_model_init(show_log: bool = False,
             det_db_unclip_ratio=det_db_unclip_ratio,
         )
     else:
-        model = ModifiedPaddleOCR(
+        # model = ModifiedPaddleOCR(
+        model = PytorchPaddleOCR(
             show_log=show_log,
             det_db_box_thresh=det_db_box_thresh,
             use_dilation=use_dilation,

+ 3 - 2
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/ocr_utils.py

@@ -261,7 +261,7 @@ def get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list):
     return adjusted_mfdetrec_res
 
 
-def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
+def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image, lang):
     paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
     ocr_result_list = []
     ori_im = new_image.copy()
@@ -307,9 +307,10 @@ def get_ocr_result_list(ocr_res, useful_list, ocr_enable, new_image):
             ocr_result_list.append({
                 'category_id': 15,
                 'poly': p1 + p2 + p3 + p4,
-                'score': float(round(score, 2)),
+                'score': 1,
                 'text': text,
                 'np_img': img_crop,
+                'lang': lang,
             })
         else:
             ocr_result_list.append({

+ 2 - 0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -66,6 +66,7 @@ class PytorchPaddleOCR(TextSystem):
             for img in imgs:
                 img = preprocess_image(img)
                 dt_boxes, elapse = self.text_detector(img)
+                logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
                 if dt_boxes is None:
                     ocr_res.append(None)
                     continue
@@ -84,6 +85,7 @@ class PytorchPaddleOCR(TextSystem):
                     img = preprocess_image(img)
                     img = [img]
                 rec_res, elapse = self.text_recognizer(img)
+                logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
                 ocr_res.append(rec_res)
             return ocr_res