Răsfoiți Sursa

refactor(magic_pdf): optimize table recognition and layout detection

- Update table recognition logic to process each table individually
- Refactor layout detection to use tqdm for progress tracking
- Optimize OCR recognition by using a single tqdm wrapper
- Improve MFR prediction with a more accurate progress bar
- Simplify MFD prediction by removing unnecessary total calculation
myhloli 7 luni în urmă
părinte
comite
1fd72f5f3a

+ 23 - 23
magic_pdf/model/batch_analyze.py

@@ -102,10 +102,13 @@ class BatchAnalyze:
                                           'single_page_mfdetrec_res':single_page_mfdetrec_res,
                                           'layout_res':layout_res,
                                           })
-            table_res_list_all_page.append({'table_res_list':table_res_list,
-                                            'lang':_lang,
-                                            'np_array_img':np_array_img,
-                                          })
+
+            for table_res in table_res_list:
+                table_img, _ = crop_img(table_res, np_array_img)
+                table_res_list_all_page.append({'table_res':table_res,
+                                                'lang':_lang,
+                                                'table_img':table_img,
+                                              })
 
         # 文本框检测
         det_start = time.time()
@@ -149,8 +152,8 @@ class BatchAnalyze:
             table_start = time.time()
             table_count = 0
             # for table_res_list_dict in table_res_list_all_page:
-            for table_res_list_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
-                _lang = table_res_list_dict['lang']
+            for table_res_dict in tqdm(table_res_list_all_page, desc="Table Predict"):
+                _lang = table_res_dict['lang']
                 atom_model_manager = AtomModelSingleton()
                 ocr_engine = atom_model_manager.get_atom_model(
                     atom_model_name='ocr',
@@ -168,26 +171,23 @@ class BatchAnalyze:
                     ocr_engine=ocr_engine,
                     table_sub_model_name='slanet_plus'
                 )
-                for res in table_res_list_dict['table_res_list']:
-                    new_image, _ = crop_img(res, table_res_list_dict['np_array_img'])
-                    html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(new_image)
-                    # 判断是否返回正常
-                    if html_code:
-                        expected_ending = html_code.strip().endswith(
-                            '</html>'
-                        ) or html_code.strip().endswith('</table>')
-                        if expected_ending:
-                            res['html'] = html_code
-                        else:
-                            logger.warning(
-                                'table recognition processing fails, not found expected HTML table end'
-                            )
+                html_code, table_cell_bboxes, logic_points, elapse = table_model.predict(table_res_dict['table_img'])
+                # 判断是否返回正常
+                if html_code:
+                    expected_ending = html_code.strip().endswith(
+                        '</html>'
+                    ) or html_code.strip().endswith('</table>')
+                    if expected_ending:
+                        table_res_dict['table_res']['html'] = html_code
                     else:
                         logger.warning(
-                            'table recognition processing fails, not get html return'
+                            'table recognition processing fails, not found expected HTML table end'
                         )
-                table_count += len(table_res_list_dict['table_res_list'])
-            # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {table_count}')
+                else:
+                    logger.warning(
+                        'table recognition processing fails, not get html return'
+                    )
+            # logger.info(f'table time: {round(time.time() - table_start, 2)}, image num: {len(table_res_list_all_page)}')
 
         # Create dictionaries to store items by language
         need_ocr_lists_by_lang = {}  # Dict of lists for each language

+ 1 - 1
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -33,7 +33,7 @@ class DocLayoutYOLOModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_layout_res = []
         # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size), total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0), desc="Layout Predict"):
+        for index in tqdm(range(0, len(images), batch_size), desc="Layout Predict"):
             doclayout_yolo_res = [
                 image_res.cpu()
                 for image_res in self.model.predict(

+ 1 - 3
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -16,9 +16,7 @@ class YOLOv8MFDModel(object):
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_mfd_res = []
         # for index in range(0, len(images), batch_size):
-        for index in tqdm(range(0, len(images), batch_size),
-                          total=len(images) // batch_size + (1 if len(images) % batch_size != 0 else 0),
-                          desc="MFD Predict"):
+        for index in tqdm(range(0, len(images), batch_size), desc="MFD Predict"):
             mfd_res = [
                 image_res.cpu()
                 for image_res in self.mfd_model.predict(

+ 12 - 6
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -109,12 +109,18 @@ class UnimernetModel(object):
         # Process batches and store results
         mfr_res = []
         # for mf_img in dataloader:
-        for mf_img in tqdm(dataloader, desc="MFR Predict"):
-            mf_img = mf_img.to(dtype=self.model.dtype)
-            mf_img = mf_img.to(self.device)
-            with torch.no_grad():
-                output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["fixed_str"])
+
+        with tqdm(total=len(sorted_images), desc="MFR Predict") as pbar:
+            for index, mf_img in enumerate(dataloader):
+                mf_img = mf_img.to(dtype=self.model.dtype)
+                mf_img = mf_img.to(self.device)
+                with torch.no_grad():
+                    output = self.model.generate({"image": mf_img})
+                mfr_res.extend(output["fixed_str"])
+
+                # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
+                current_batch_size = min(batch_size, len(sorted_images) - index * batch_size)
+                pbar.update(current_batch_size)
 
         # Restore original order
         unsorted_results = [""] * len(mfr_res)

+ 1 - 0
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -72,6 +72,7 @@ class PytorchPaddleOCR(TextSystem):
         kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
         kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
+        # kwargs['rec_batch_num'] = 8
 
         kwargs['device'] = get_device()
 

+ 129 - 121
magic_pdf/model/sub_modules/ocr/paddleocr2pytorch/tools/infer/predict_rec.py

@@ -302,131 +302,139 @@ class TextRecognizer(BaseOCRV20):
         batch_num = self.rec_batch_num
         elapse = 0
         # for beg_img_no in range(0, img_num, batch_num):
-        for beg_img_no in tqdm(range(0, img_num, batch_num), desc='OCR-rec Predict', disable=not tqdm_enable):
-            end_img_no = min(img_num, beg_img_no + batch_num)
-            norm_img_batch = []
-            max_wh_ratio = 0
-            for ino in range(beg_img_no, end_img_no):
-                # h, w = img_list[ino].shape[0:2]
-                h, w = img_list[indices[ino]].shape[0:2]
-                wh_ratio = w * 1.0 / h
-                max_wh_ratio = max(max_wh_ratio, wh_ratio)
-            for ino in range(beg_img_no, end_img_no):
-                if self.rec_algorithm == "SAR":
-                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
-                        img_list[indices[ino]], self.rec_image_shape)
-                    norm_img = norm_img[np.newaxis, :]
-                    valid_ratio = np.expand_dims(valid_ratio, axis=0)
-                    valid_ratios = []
-                    valid_ratios.append(valid_ratio)
-                    norm_img_batch.append(norm_img)
-
-                elif self.rec_algorithm == "SVTR":
-                    norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
-                                                         self.rec_image_shape)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-                elif self.rec_algorithm == "SRN":
-                    norm_img = self.process_image_srn(img_list[indices[ino]],
-                                                      self.rec_image_shape, 8,
-                                                      self.max_text_length)
-                    encoder_word_pos_list = []
-                    gsrm_word_pos_list = []
-                    gsrm_slf_attn_bias1_list = []
-                    gsrm_slf_attn_bias2_list = []
-                    encoder_word_pos_list.append(norm_img[1])
-                    gsrm_word_pos_list.append(norm_img[2])
-                    gsrm_slf_attn_bias1_list.append(norm_img[3])
-                    gsrm_slf_attn_bias2_list.append(norm_img[4])
-                    norm_img_batch.append(norm_img[0])
+        with tqdm(total=img_num, desc='OCR-rec Predict', disable=not tqdm_enable) as pbar:
+            index = 0
+            for beg_img_no in range(0, img_num, batch_num):
+                end_img_no = min(img_num, beg_img_no + batch_num)
+                norm_img_batch = []
+                max_wh_ratio = 0
+                for ino in range(beg_img_no, end_img_no):
+                    # h, w = img_list[ino].shape[0:2]
+                    h, w = img_list[indices[ino]].shape[0:2]
+                    wh_ratio = w * 1.0 / h
+                    max_wh_ratio = max(max_wh_ratio, wh_ratio)
+                for ino in range(beg_img_no, end_img_no):
+                    if self.rec_algorithm == "SAR":
+                        norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+                            img_list[indices[ino]], self.rec_image_shape)
+                        norm_img = norm_img[np.newaxis, :]
+                        valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                        valid_ratios = []
+                        valid_ratios.append(valid_ratio)
+                        norm_img_batch.append(norm_img)
+
+                    elif self.rec_algorithm == "SVTR":
+                        norm_img = self.resize_norm_img_svtr(img_list[indices[ino]],
+                                                             self.rec_image_shape)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                    elif self.rec_algorithm == "SRN":
+                        norm_img = self.process_image_srn(img_list[indices[ino]],
+                                                          self.rec_image_shape, 8,
+                                                          self.max_text_length)
+                        encoder_word_pos_list = []
+                        gsrm_word_pos_list = []
+                        gsrm_slf_attn_bias1_list = []
+                        gsrm_slf_attn_bias2_list = []
+                        encoder_word_pos_list.append(norm_img[1])
+                        gsrm_word_pos_list.append(norm_img[2])
+                        gsrm_slf_attn_bias1_list.append(norm_img[3])
+                        gsrm_slf_attn_bias2_list.append(norm_img[4])
+                        norm_img_batch.append(norm_img[0])
+                    elif self.rec_algorithm == "CAN":
+                        norm_img = self.norm_img_can(img_list[indices[ino]],
+                                                     max_wh_ratio)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                        norm_image_mask = np.ones(norm_img.shape, dtype='float32')
+                        word_label = np.ones([1, 36], dtype='int64')
+                        norm_img_mask_batch = []
+                        word_label_list = []
+                        norm_img_mask_batch.append(norm_image_mask)
+                        word_label_list.append(word_label)
+                    else:
+                        norm_img = self.resize_norm_img(img_list[indices[ino]],
+                                                        max_wh_ratio)
+                        norm_img = norm_img[np.newaxis, :]
+                        norm_img_batch.append(norm_img)
+                norm_img_batch = np.concatenate(norm_img_batch)
+                norm_img_batch = norm_img_batch.copy()
+
+                if self.rec_algorithm == "SRN":
+                    starttime = time.time()
+                    encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+                    gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+                    gsrm_slf_attn_bias1_list = np.concatenate(
+                        gsrm_slf_attn_bias1_list)
+                    gsrm_slf_attn_bias2_list = np.concatenate(
+                        gsrm_slf_attn_bias2_list)
+
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
+                        gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
+                        gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
+                        gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
+
+                        inp = inp.to(self.device)
+                        encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
+                        gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
+                        gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
+                        gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
+
+                        backbone_out = self.net.backbone(inp) # backbone_feat
+                        prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
+                    # preds = {"predict": prob_out[2]}
+                    preds = {"predict": prob_out["predict"]}
+
+                elif self.rec_algorithm == "SAR":
+                    starttime = time.time()
+                    # valid_ratios = np.concatenate(valid_ratios)
+                    # inputs = [
+                    #     norm_img_batch,
+                    #     valid_ratios,
+                    # ]
+
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        inp = inp.to(self.device)
+                        preds = self.net(inp)
+
                 elif self.rec_algorithm == "CAN":
-                    norm_img = self.norm_img_can(img_list[indices[ino]],
-                                                 max_wh_ratio)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-                    norm_image_mask = np.ones(norm_img.shape, dtype='float32')
-                    word_label = np.ones([1, 36], dtype='int64')
-                    norm_img_mask_batch = []
-                    word_label_list = []
-                    norm_img_mask_batch.append(norm_image_mask)
-                    word_label_list.append(word_label)
-                else:
-                    norm_img = self.resize_norm_img(img_list[indices[ino]],
-                                                    max_wh_ratio)
-                    norm_img = norm_img[np.newaxis, :]
-                    norm_img_batch.append(norm_img)
-            norm_img_batch = np.concatenate(norm_img_batch)
-            norm_img_batch = norm_img_batch.copy()
-
-            if self.rec_algorithm == "SRN":
-                starttime = time.time()
-                encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
-                gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
-                gsrm_slf_attn_bias1_list = np.concatenate(
-                    gsrm_slf_attn_bias1_list)
-                gsrm_slf_attn_bias2_list = np.concatenate(
-                    gsrm_slf_attn_bias2_list)
-
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    encoder_word_pos_inp = torch.from_numpy(encoder_word_pos_list)
-                    gsrm_word_pos_inp = torch.from_numpy(gsrm_word_pos_list)
-                    gsrm_slf_attn_bias1_inp = torch.from_numpy(gsrm_slf_attn_bias1_list)
-                    gsrm_slf_attn_bias2_inp = torch.from_numpy(gsrm_slf_attn_bias2_list)
-
-                    inp = inp.to(self.device)
-                    encoder_word_pos_inp = encoder_word_pos_inp.to(self.device)
-                    gsrm_word_pos_inp = gsrm_word_pos_inp.to(self.device)
-                    gsrm_slf_attn_bias1_inp = gsrm_slf_attn_bias1_inp.to(self.device)
-                    gsrm_slf_attn_bias2_inp = gsrm_slf_attn_bias2_inp.to(self.device)
-
-                    backbone_out = self.net.backbone(inp) # backbone_feat
-                    prob_out = self.net.head(backbone_out, [encoder_word_pos_inp, gsrm_word_pos_inp, gsrm_slf_attn_bias1_inp, gsrm_slf_attn_bias2_inp])
-                # preds = {"predict": prob_out[2]}
-                preds = {"predict": prob_out["predict"]}
-
-            elif self.rec_algorithm == "SAR":
-                starttime = time.time()
-                # valid_ratios = np.concatenate(valid_ratios)
-                # inputs = [
-                #     norm_img_batch,
-                #     valid_ratios,
-                # ]
-
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    inp = inp.to(self.device)
-                    preds = self.net(inp)
-
-            elif self.rec_algorithm == "CAN":
-                starttime = time.time()
-                norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
-                word_label_list = np.concatenate(word_label_list)
-                inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
-
-                inp = [torch.from_numpy(e_i) for e_i in inputs]
-                inp = [e_i.to(self.device) for e_i in inp]
-                with torch.no_grad():
-                    outputs = self.net(inp)
-                    outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
-
-                preds = outputs
+                    starttime = time.time()
+                    norm_img_mask_batch = np.concatenate(norm_img_mask_batch)
+                    word_label_list = np.concatenate(word_label_list)
+                    inputs = [norm_img_batch, norm_img_mask_batch, word_label_list]
 
-            else:
-                starttime = time.time()
+                    inp = [torch.from_numpy(e_i) for e_i in inputs]
+                    inp = [e_i.to(self.device) for e_i in inp]
+                    with torch.no_grad():
+                        outputs = self.net(inp)
+                        outputs = [v.cpu().numpy() for k, v in enumerate(outputs)]
 
-                with torch.no_grad():
-                    inp = torch.from_numpy(norm_img_batch)
-                    inp = inp.to(self.device)
-                    prob_out = self.net(inp)
+                    preds = outputs
 
-                if isinstance(prob_out, list):
-                    preds = [v.cpu().numpy() for v in prob_out]
                 else:
-                    preds = prob_out.cpu().numpy()
+                    starttime = time.time()
+
+                    with torch.no_grad():
+                        inp = torch.from_numpy(norm_img_batch)
+                        inp = inp.to(self.device)
+                        prob_out = self.net(inp)
+
+                    if isinstance(prob_out, list):
+                        preds = [v.cpu().numpy() for v in prob_out]
+                    else:
+                        preds = prob_out.cpu().numpy()
+
+                rec_result = self.postprocess_op(preds)
+                for rno in range(len(rec_result)):
+                    rec_res[indices[beg_img_no + rno]] = rec_result[rno]
+                elapse += time.time() - starttime
+
+                # 更新进度条,每次增加batch_size,但要注意最后一个batch可能不足batch_size
+                current_batch_size = min(batch_num, img_num - index * batch_num)
+                index += 1
+                pbar.update(current_batch_size)
 
-            rec_result = self.postprocess_op(preds)
-            for rno in range(len(rec_result)):
-                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
-            elapse += time.time() - starttime
         return rec_res, elapse