|
@@ -23,6 +23,7 @@ from dataclasses import dataclass, fields, is_dataclass
|
|
|
|
|
|
|
|
from sympy import totient
|
|
from sympy import totient
|
|
|
|
|
|
|
|
|
|
+from mineru.utils.config_reader import get_device
|
|
|
from .rec_unimernet_head import (
|
|
from .rec_unimernet_head import (
|
|
|
MBartForCausalLM,
|
|
MBartForCausalLM,
|
|
|
MBartDecoder,
|
|
MBartDecoder,
|
|
@@ -797,6 +798,7 @@ class PPFormulaNet_Head(UniMERNetHead):
|
|
|
generation_config["forced_eos_token_id"],
|
|
generation_config["forced_eos_token_id"],
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
|
|
+ self.device = torch.device(get_device())
|
|
|
|
|
|
|
|
def prepare_inputs_for_generation(
|
|
def prepare_inputs_for_generation(
|
|
|
self,
|
|
self,
|
|
@@ -891,8 +893,8 @@ class PPFormulaNet_Head(UniMERNetHead):
|
|
|
|
|
|
|
|
def stopping_criteria(self, input_ids):
|
|
def stopping_criteria(self, input_ids):
|
|
|
if self.is_export:
|
|
if self.is_export:
|
|
|
- return input_ids[:, -1] == torch.Tensor([self.eos_token_id])
|
|
|
|
|
- is_done = torch.isin(input_ids[:, -1], torch.Tensor([self.eos_token_id]))
|
|
|
|
|
|
|
+ return input_ids[:, -1].cpu() == torch.Tensor([self.eos_token_id])
|
|
|
|
|
+ is_done = torch.isin(input_ids[:, -1].cpu(), torch.Tensor([self.eos_token_id]))
|
|
|
return is_done
|
|
return is_done
|
|
|
|
|
|
|
|
def stopping_criteria_parallel(self, input_ids):
|
|
def stopping_criteria_parallel(self, input_ids):
|
|
@@ -997,6 +999,7 @@ class PPFormulaNet_Head(UniMERNetHead):
|
|
|
torch.ones(
|
|
torch.ones(
|
|
|
(batch_size, parallel_step),
|
|
(batch_size, parallel_step),
|
|
|
dtype=torch.int64,
|
|
dtype=torch.int64,
|
|
|
|
|
+ device=self.device,
|
|
|
)
|
|
)
|
|
|
* decoder_start_token_id
|
|
* decoder_start_token_id
|
|
|
)
|
|
)
|
|
@@ -1005,6 +1008,7 @@ class PPFormulaNet_Head(UniMERNetHead):
|
|
|
torch.ones(
|
|
torch.ones(
|
|
|
(batch_size, 1),
|
|
(batch_size, 1),
|
|
|
dtype=torch.int64,
|
|
dtype=torch.int64,
|
|
|
|
|
+ device=self.device,
|
|
|
)
|
|
)
|
|
|
* decoder_start_token_id
|
|
* decoder_start_token_id
|
|
|
)
|
|
)
|
|
@@ -1078,11 +1082,11 @@ class PPFormulaNet_Head(UniMERNetHead):
|
|
|
eos_token = self.eos_token_id
|
|
eos_token = self.eos_token_id
|
|
|
if use_parallel:
|
|
if use_parallel:
|
|
|
unfinished_sequences = torch.ones(
|
|
unfinished_sequences = torch.ones(
|
|
|
- [batch_size, parallel_step], dtype=torch.int64
|
|
|
|
|
|
|
+ [batch_size, parallel_step], dtype=torch.int64, device=self.device
|
|
|
)
|
|
)
|
|
|
parallel_length = math.ceil(self.max_seq_len // parallel_step)
|
|
parallel_length = math.ceil(self.max_seq_len // parallel_step)
|
|
|
else:
|
|
else:
|
|
|
- unfinished_sequences = torch.ones(batch_size, dtype=torch.int64)
|
|
|
|
|
|
|
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.int64, device=self.device)
|
|
|
parallel_length = self.max_seq_len
|
|
parallel_length = self.max_seq_len
|
|
|
|
|
|
|
|
i_idx = 0
|
|
i_idx = 0
|
|
@@ -1103,7 +1107,7 @@ class PPFormulaNet_Head(UniMERNetHead):
|
|
|
model_inputs = self.prepare_inputs_for_generation_export(
|
|
model_inputs = self.prepare_inputs_for_generation_export(
|
|
|
past_key_values=past_key_values, **model_kwargs
|
|
past_key_values=past_key_values, **model_kwargs
|
|
|
)
|
|
)
|
|
|
- decoder_attention_mask = torch.ones(input_ids.shape)
|
|
|
|
|
|
|
+ decoder_attention_mask = torch.ones(input_ids.shape, device=self.device,)
|
|
|
|
|
|
|
|
outputs = self.generate_single_iter(
|
|
outputs = self.generate_single_iter(
|
|
|
decoder_input_ids=decoder_input_ids,
|
|
decoder_input_ids=decoder_input_ids,
|
|
@@ -1147,12 +1151,12 @@ class PPFormulaNet_Head(UniMERNetHead):
|
|
|
if use_parallel:
|
|
if use_parallel:
|
|
|
unfinished_sequences = (
|
|
unfinished_sequences = (
|
|
|
unfinished_sequences
|
|
unfinished_sequences
|
|
|
- & ~self.stopping_criteria_parallel(input_ids).to(torch.int64)
|
|
|
|
|
|
|
+ & ~self.stopping_criteria_parallel(input_ids).to(torch.int64).to(self.device)
|
|
|
)
|
|
)
|
|
|
else:
|
|
else:
|
|
|
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
|
|
unfinished_sequences = unfinished_sequences & ~self.stopping_criteria(
|
|
|
input_ids
|
|
input_ids
|
|
|
- ).to(torch.int64)
|
|
|
|
|
|
|
+ ).to(torch.int64).to(self.device)
|
|
|
|
|
|
|
|
if (
|
|
if (
|
|
|
eos_token is not None
|
|
eos_token is not None
|