
feat: adjust batch size calculation and enhance device management in model heads

myhloli · 4 weeks ago · commit 915ba87f7d

mineru/model/mfr/pp_formulanet_plus_m/predict_formula.py (+2 -0)

@@ -65,6 +65,7 @@ class FormulaRecognizer(BaseOCRV20):
         )
 
     def predict(self, img_list, batch_size: int = 64):
+        batch_size = int(0.5 * batch_size)
         batch_imgs = self.pre_tfs["UniMERNetImgDecode"](imgs=img_list)
         batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
         batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
@@ -78,6 +79,7 @@ class FormulaRecognizer(BaseOCRV20):
                     batch_data = inp[index: index + batch_size]
                     batch_preds = [self.net(batch_data)]
                     batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
+                    batch_preds = [bp.cpu().numpy() for bp in batch_preds]
                     rec_formula += self.post_op(batch_preds)
                     pbar.update(len(batch_preds))
         return rec_formula
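
Two things change here: the requested batch size is halved before batching (lowering peak memory during formula recognition), and each prediction tensor is copied to the CPU before the NumPy-based post-op runs. A minimal sketch of why that copy matters, with an illustrative `preds` tensor standing in for the decoder output:

```python
import torch

# Illustrative stand-in for a batch of decoder predictions.
device = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.randint(0, 50000, (4, 128), device=device)

# preds.numpy() raises TypeError on a CUDA tensor, since NumPy can only
# view host memory; copying to the CPU first, as the patch does, is safe:
preds_np = preds.cpu().numpy()
print(preds_np.shape)  # (4, 128)
```

One caveat worth noting: `int(0.5 * batch_size)` floors to zero for `batch_size=1`, so the default of 64 is safe, but a very small caller-supplied batch size would presumably want a `max(1, ...)` guard.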

mineru/model/utils/pytorchocr/modeling/heads/rec_ppformulanet_head.py (+11 -7)

@@ -23,6 +23,7 @@ from dataclasses import dataclass, fields, is_dataclass
 
 from sympy import totient
 
+from mineru.utils.config_reader import get_device
 from .rec_unimernet_head import (
     MBartForCausalLM,
     MBartDecoder,
@@ -797,6 +798,7 @@ class PPFormulaNet_Head(UniMERNetHead):
                 generation_config["forced_eos_token_id"],
             )
         )
+        self.device = torch.device(get_device())
 
     def prepare_inputs_for_generation(
             self,
@@ -891,8 +893,8 @@ class PPFormulaNet_Head(UniMERNetHead):
 
     def stopping_criteria(self, input_ids):
         if self.is_export:
-            return input_ids[:, -1] == torch.Tensor([self.eos_token_id])
-        is_done = torch.isin(input_ids[:, -1], torch.Tensor([self.eos_token_id]))
+            return input_ids[:, -1].cpu() == torch.Tensor([self.eos_token_id])
+        is_done = torch.isin(input_ids[:, -1].cpu(), torch.Tensor([self.eos_token_id]))
         return is_done
 
     def stopping_criteria_parallel(self, input_ids):
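
Both branches of `stopping_criteria` compare the last generated token against `torch.Tensor([self.eos_token_id])`, which is always created on the CPU; once `input_ids` lives on an accelerator, that comparison mixes devices. Pulling the last-token column to the CPU first, as the patch does, sidesteps the mismatch. A sketch of the failure mode with illustrative values:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = torch.tensor([[5, 9, 2], [7, 3, 4]], device=device)  # illustrative
eos = torch.Tensor([2])  # plain torch.Tensor(...) always lands on the CPU

# torch.isin(input_ids[:, -1], eos) raises a RuntimeError when the two
# tensors sit on different devices; comparing on the CPU side works:
is_done = torch.isin(input_ids[:, -1].cpu(), eos)
print(is_done)  # tensor([ True, False])
```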
@@ -997,6 +999,7 @@ class PPFormulaNet_Head(UniMERNetHead):
                         torch.ones(
                             (batch_size, parallel_step),
                             dtype=torch.int64,
+                            device=self.device,
                         )
                         * decoder_start_token_id
                 )
@@ -1005,6 +1008,7 @@ class PPFormulaNet_Head(UniMERNetHead):
                         torch.ones(
                             (batch_size, 1),
                             dtype=torch.int64,
+                            device=self.device,
                         )
                         * decoder_start_token_id
                 )
@@ -1078,11 +1082,11 @@ class PPFormulaNet_Head(UniMERNetHead):
         eos_token = self.eos_token_id
         if use_parallel:
             unfinished_sequences = torch.ones(
-                [batch_size, parallel_step], dtype=torch.int64
+                [batch_size, parallel_step], dtype=torch.int64, device=self.device
             )
             parallel_length = math.ceil(self.max_seq_len // parallel_step)
         else:
-            unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
+            unfinished_sequences = torch.ones(batch_size, dtype=torch.int64, device=self.device)
             parallel_length = self.max_seq_len
 
         i_idx = 0
@@ -1103,7 +1107,7 @@ class PPFormulaNet_Head(UniMERNetHead):
             model_inputs = self.prepare_inputs_for_generation_export(
                 past_key_values=past_key_values, **model_kwargs
             )
-            decoder_attention_mask = torch.ones(input_ids.shape)
+            decoder_attention_mask = torch.ones(input_ids.shape, device=self.device,)
 
             outputs = self.generate_single_iter(
                 decoder_input_ids=decoder_input_ids,
@@ -1147,12 +1151,12 @@ class PPFormulaNet_Head(UniMERNetHead):
             if use_parallel:
                 unfinished_sequences = (
                         unfinished_sequences
-                        & ~self.stopping_criteria_parallel(input_ids).to(torch.int64)
+                        & ~self.stopping_criteria_parallel(input_ids).to(torch.int64).to(self.device)
                 )
             else:
                 unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
                     input_ids
-                ).to(torch.int64)
+                ).to(torch.int64).to(self.device)
 
             if (
                     eos_token is not None
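
The remaining hunks in this file follow one pattern: every bookkeeping tensor the generation loop allocates (`decoder_input_ids`, `unfinished_sequences`, the decoder attention mask) is now created directly on the device reported by `get_device()`, and the CPU-side stopping-criteria result is moved back onto that device before the bitwise update. A condensed sketch of the update, with `device` standing in for `torch.device(get_device())`:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 3

# Allocated on the target device up front, so the & below never mixes devices.
unfinished = torch.ones(batch_size, dtype=torch.int64, device=device)

is_done = torch.tensor([True, False, False])  # stopping criteria computed on the CPU
unfinished = unfinished & ~is_done.to(torch.int64).to(device)
print(unfinished)  # tensor([0, 1, 1]) on `device`
```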

mineru/model/utils/pytorchocr/modeling/heads/rec_unimernet_head.py (+11 -4)

@@ -14,6 +14,8 @@ from torch import Tensor
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss
 
+from mineru.utils.config_reader import get_device
+
 
 class ModelOutput(OrderedDict):
 
@@ -441,13 +443,14 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
     def __init__(self, num_embeddings, embedding_dim):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
+        self.device = torch.device(get_device())
 
     def forward(self, input_ids, past_key_values_length=0):
         """`input_ids' shape is expected to be [bsz x seqlen]."""
         bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.int64
-        ).expand([bsz, -1])
+        ).expand([bsz, -1]).to(self.device)
         return nn.Embedding.forward(self, positions + self.offset)
 
 
@@ -656,6 +659,7 @@ class MBartDecoderLayer(nn.Module):
         self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
         self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
         self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+        self.device = torch.device(get_device())
 
     def forward(
             self,
@@ -672,9 +676,12 @@ class MBartDecoderLayer(nn.Module):
 
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        self_attn_past_key_value = (
-            past_key_value[:2] if past_key_value is not None else None
-        )
+
+        self_attn_past_key_value = None
+        if past_key_value is not None:
+            self_attn_past_key_value = tuple(
+                t.to(self.device) if isinstance(t, torch.Tensor) else t for t in past_key_value[:2]
+            )
 
         hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
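
Both changes in this file chase the same class of bug. `torch.arange` allocates on the CPU by default, so `MBartLearnedPositionalEmbedding` now moves its position indices onto the configured device before the lookup; and `MBartDecoderLayer` moves the cached `past_key_value[:2]` tensors onto that device before self-attention, with an `isinstance` guard in case a cache entry holds non-tensor metadata. A condensed sketch of the embedding half, where `dev` stands in for `torch.device(get_device())`:

```python
import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Embedding):
    """Condensed stand-in for the patched MBartLearnedPositionalEmbedding."""

    def __init__(self, num_embeddings, embedding_dim, dev="cpu"):
        self.offset = 2  # MBart reserves the first two embedding ids
        super().__init__(num_embeddings + self.offset, embedding_dim)
        self.dev = torch.device(dev)

    def forward(self, input_ids, past_key_values_length=0):
        bsz, seq_len = input_ids.shape[:2]
        # Without the .to(...), positions stay on the CPU and the lookup
        # mixes devices once the embedding weights live on an accelerator.
        positions = torch.arange(
            past_key_values_length, past_key_values_length + seq_len, dtype=torch.int64
        ).expand([bsz, -1]).to(self.dev)
        return nn.Embedding.forward(self, positions + self.offset)

emb = LearnedPositionalEmbedding(512, 1024)             # dev defaults to "cpu"
print(emb(torch.zeros(2, 7, dtype=torch.int64)).shape)  # torch.Size([2, 7, 1024])
```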