|
|
@@ -468,7 +468,7 @@ class UnimernetModel(VisionEncoderDecoderModel):
|
|
|
).loss
|
|
|
return {"loss": loss}
|
|
|
|
|
|
- def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95):
|
|
|
+ def generate(self, samples, do_sample: bool = False, temperature: float = 0.2, top_p: float = 0.95, batch_size=64):
|
|
|
pixel_values = samples["image"]
|
|
|
num_channels = pixel_values.shape[1]
|
|
|
if num_channels == 1:
|
|
|
@@ -478,7 +478,13 @@ class UnimernetModel(VisionEncoderDecoderModel):
|
|
|
if do_sample:
|
|
|
kwargs["temperature"] = temperature
|
|
|
kwargs["top_p"] = top_p
|
|
|
-
|
|
|
+
|
|
|
+ if self.tokenizer.tokenizer.model_max_length > 1152:
|
|
|
+ if batch_size <= 32:
|
|
|
+ self.tokenizer.tokenizer.model_max_length = 1152 # 6g
|
|
|
+ else:
|
|
|
+ self.tokenizer.tokenizer.model_max_length = 1344 # 8g
|
|
|
+
|
|
|
outputs = super().generate(
|
|
|
pixel_values=pixel_values,
|
|
|
max_new_tokens=self.tokenizer.tokenizer.model_max_length, # required
|