Pārlūkot izejas kodu

refactor(ocr): comment out print statements and update table model initialization

- Comment out print statements in base_ocr_v20.py and pytorch_paddle.py
- Update table model initialization to use lang parameter instead of ocr_engine
- Remove unused RapidOCR initialization in rapid_table.py
myhloli 7 mēneši atpakaļ
vecāks
revīzija
5252c46e4c

+ 11 - 3
magic_pdf/model/sub_modules/model_init.py

@@ -36,7 +36,7 @@ from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableM
 #     from magic_pdf.model.sub_modules.table.rapidtable.rapid_table import RapidTableModel
 
 
-def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr_engine=None, table_sub_model_name=None):
+def table_model_init(table_model_type, model_path, max_time, _device_='cpu', lang=None, table_sub_model_name=None):
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         from magic_pdf.model.sub_modules.table.structeqtable.struct_eqtable import StructTableModel
         table_model = StructTableModel(model_path, max_new_tokens=2048, max_time=max_time)
@@ -48,6 +48,14 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu', ocr
         }
         table_model = TableMasterPaddleModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
+        atom_model_manager = AtomModelSingleton()
+        ocr_engine = atom_model_manager.get_atom_model(
+            atom_model_name='ocr',
+            ocr_show_log=False,
+            det_db_box_thresh=0.5,
+            det_db_unclip_ratio=1.6,
+            lang=lang
+        )
         table_model = RapidTableModel(ocr_engine, table_sub_model_name)
     else:
         logger.error('table model type not allow')
@@ -134,7 +142,7 @@ class AtomModelSingleton:
         elif atom_model_name in [AtomicModel.Layout]:
             key = (atom_model_name, layout_model_name)
         elif atom_model_name in [AtomicModel.Table]:
-            key = (atom_model_name, table_model_name)
+            key = (atom_model_name, table_model_name, lang)
         else:
             key = atom_model_name
 
@@ -182,7 +190,7 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('table_model_path'),
             kwargs.get('table_max_time'),
             kwargs.get('device'),
-            kwargs.get('ocr_engine'),
+            kwargs.get('lang'),
             kwargs.get('table_sub_model_name')
         )
     elif model_name == AtomicModel.LangDetect:

+ 4 - 4
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -109,7 +109,7 @@ class PytorchPaddleOCR(TextSystem):
             for img in imgs:
                 img = preprocess_image(img)
                 dt_boxes, elapse = self.text_detector(img)
-                logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
+                # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
                 if dt_boxes is None:
                     ocr_res.append(None)
                     continue
@@ -128,7 +128,7 @@ class PytorchPaddleOCR(TextSystem):
                     img = preprocess_image(img)
                     img = [img]
                 rec_res, elapse = self.text_recognizer(img)
-                logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
+                # logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
                 ocr_res.append(rec_res)
             return ocr_res
 
@@ -146,7 +146,7 @@ class PytorchPaddleOCR(TextSystem):
             return None, None
         else:
             pass
-            logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
+            # logger.debug("dt_boxes num : {}, elapsed : {}".format(len(dt_boxes), elapse))
         img_crop_list = []
 
         dt_boxes = sorted_boxes(dt_boxes)
@@ -163,7 +163,7 @@ class PytorchPaddleOCR(TextSystem):
             img_crop_list.append(img_crop)
 
         rec_res, elapse = self.text_recognizer(img_crop_list)
-        logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
+        # logger.debug("rec_res num  : {}, elapsed : {}".format(len(rec_res), elapse))
 
         filter_boxes, filter_rec_res = [], []
         for box, rec_result in zip(dt_boxes, rec_res):

+ 2 - 2
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorchocr/base_ocr_v20.py

@@ -27,11 +27,11 @@ class BaseOCRV20:
 
     def load_state_dict(self, weights):
         self.net.load_state_dict(weights)
-        print('weights is loaded.')
+        # print('weights is loaded.')
 
     def load_pytorch_weights(self, weights_path):
         self.net.load_state_dict(torch.load(weights_path, weights_only=True))
-        print('model is loaded: {}'.format(weights_path))
+        # print('model is loaded: {}'.format(weights_path))
 
     def inference(self, inputs):
         with torch.no_grad():

+ 9 - 17
magic_pdf/model/sub_modules/table/rapidtable/rapid_table.py

@@ -23,25 +23,17 @@ class RapidTableModel(object):
 
         self.table_model = RapidTable(input_args)
 
-        # if ocr_engine is None:
-        #     self.ocr_model_name = "RapidOCR"
-        #     if torch.cuda.is_available():
-        #         from rapidocr_paddle import RapidOCR
-        #         self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-        #     else:
-        #         from rapidocr_onnxruntime import RapidOCR
-        #         self.ocr_engine = RapidOCR()
+        # self.ocr_model_name = "RapidOCR"
+        # if torch.cuda.is_available():
+        #     from rapidocr_paddle import RapidOCR
+        #     self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
         # else:
-        #     self.ocr_model_name = "PaddleOCR"
-        #     self.ocr_engine = ocr_engine
+        #     from rapidocr_onnxruntime import RapidOCR
+        #     self.ocr_engine = RapidOCR()
+
+        self.ocr_model_name = "PaddleOCR"
+        self.ocr_engine = ocr_engine
 
-        self.ocr_model_name = "RapidOCR"
-        if torch.cuda.is_available():
-            from rapidocr_paddle import RapidOCR
-            self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
-        else:
-            from rapidocr_onnxruntime import RapidOCR
-            self.ocr_engine = RapidOCR()
 
     def predict(self, image):