Quellcode durchsuchen

feat(table): add RapidOCR support for RapidTable model

- Integrate RapidOCR with RapidTable model for table recognition
- Improve memory management for devices with <= 8GB VRAM
- Update table recognition process to use RapidOCR for RapidTable
- Add rapidocr-paddle dependency in setup.py
myhloli vor 1 Jahr
Ursprung
Commit
fe2c2c0d8e
2 geänderte Dateien mit 35 neuen und 25 gelöschten Zeilen
  1. 34 25
      magic_pdf/model/pdf_extract_kit.py
  2. 1 0
      setup.py

+ 34 - 25
magic_pdf/model/pdf_extract_kit.py

@@ -26,6 +26,7 @@ try:
     from unimernet.processors import load_processor
     from doclayout_yolo import YOLOv10
     from rapid_table import RapidTable
+    from rapidocr_paddle import RapidOCR
 
 except ImportError as e:
     logger.exception(e)
@@ -42,6 +43,7 @@ from magic_pdf.model.ppTableModel import ppTableModel
 
 
 def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
+    ocr_engine = None
     if table_model_type == MODEL_NAME.STRUCT_EQTABLE:
         table_model = StructTableModel(model_path, max_time=max_time)
     elif table_model_type == MODEL_NAME.TABLE_MASTER:
@@ -52,11 +54,15 @@ def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
         table_model = ppTableModel(config)
     elif table_model_type == MODEL_NAME.RAPID_TABLE:
         table_model = RapidTable()
+        ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
     else:
         logger.error("table model type not allow")
         exit(1)
 
-    return table_model
+    if ocr_engine:
+        return [table_model, ocr_engine]
+    else:
+        return table_model
 
 
 def mfd_model_init(weight):
@@ -283,23 +289,32 @@ class CustomPEKModel:
                 doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name]))
             )
         # 初始化ocr
-        # if self.apply_ocr:
-        self.ocr_model = atom_model_manager.get_atom_model(
-            atom_model_name=AtomicModel.OCR,
-            ocr_show_log=show_log,
-            det_db_box_thresh=0.3,
-            lang=self.lang
-        )
+        if self.apply_ocr:
+            self.ocr_model = atom_model_manager.get_atom_model(
+                atom_model_name=AtomicModel.OCR,
+                ocr_show_log=show_log,
+                det_db_box_thresh=0.3,
+                lang=self.lang
+            )
         # init table model
         if self.apply_table:
             table_model_dir = self.configs["weights"][self.table_model_name]
-            self.table_model = atom_model_manager.get_atom_model(
-                atom_model_name=AtomicModel.Table,
-                table_model_name=self.table_model_name,
-                table_model_path=str(os.path.join(models_dir, table_model_dir)),
-                table_max_time=self.table_max_time,
-                device=self.device
-            )
+            if self.table_model_name in [MODEL_NAME.STRUCT_EQTABLE, MODEL_NAME.TABLE_MASTER]:
+                self.table_model = atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.Table,
+                    table_model_name=self.table_model_name,
+                    table_model_path=str(os.path.join(models_dir, table_model_dir)),
+                    table_max_time=self.table_max_time,
+                    device=self.device
+                )
+            elif self.table_model_name in [MODEL_NAME.RAPID_TABLE]:
+                self.table_model, self.ocr_engine =atom_model_manager.get_atom_model(
+                    atom_model_name=AtomicModel.Table,
+                    table_model_name=self.table_model_name,
+                    table_model_path=str(os.path.join(models_dir, table_model_dir)),
+                    table_max_time=self.table_max_time,
+                    device=self.device
+                )
 
         logger.info('DocAnalysis init done!')
 
@@ -381,9 +396,8 @@ class CustomPEKModel:
                 table_res_list.append(res)
 
         if torch.cuda.is_available() and self.device != 'cpu':
-            properties = torch.cuda.get_device_properties(self.device)
-            total_memory = properties.total_memory / (1024 ** 3)  # 将字节转换为 GB
-            if total_memory <= 10:
+            total_memory = torch.cuda.get_device_properties(self.device).total_memory / (1024 ** 3)  # 将字节转换为 GB
+            if total_memory <= 8:
                 gc_start = time.time()
                 clean_memory()
                 gc_time = round(time.time() - gc_start, 2)
@@ -456,13 +470,8 @@ class CustomPEKModel:
                 elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
                     html_code = self.table_model.img2html(new_image)
                 elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
-                    new_image_bgr = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
-                    ocr_result = self.ocr_model.ocr(new_image_bgr)[0]
-                    new_ocr_result = []
-                    for box_ocr_res in ocr_result:
-                        text, score = box_ocr_res[1]
-                        new_ocr_result.append([box_ocr_res[0], text, score])
-                    html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), new_ocr_result)
+                    ocr_result, _ = self.ocr_engine(np.asarray(new_image))
+                    html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(new_image), ocr_result)
 
                 run_time = time.time() - single_table_start_time
                 # logger.info(f"------------table recognition processing ends within {run_time}s-----")

+ 1 - 0
setup.py

@@ -47,6 +47,7 @@ if __name__ == '__main__':
                      "einops",  # struct-eqtable依赖
                      "accelerate",  # struct-eqtable依赖
                      "doclayout_yolo==0.0.2",  # doclayout_yolo
+                     "rapidocr-paddle",  # rapidocr-paddle
                      "rapid_table",  # rapid_table
                      "detectron2"
                      ],