|
|
@@ -14,7 +14,7 @@
|
|
|
|
|
|
import copy
|
|
|
import inspect
|
|
|
-from typing import Optional, Union
|
|
|
+from typing import List, Optional, Union
|
|
|
|
|
|
import paddle
|
|
|
import paddle.distributed as dist
|
|
|
@@ -86,7 +86,7 @@ def get_scale_by_dtype(dtype: str = None, return_positive: bool = True) -> float
|
|
|
def get_unfinished_flag(
|
|
|
input_ids: Tensor,
|
|
|
unfinished_flag: Tensor,
|
|
|
- eos_token_id: Union[int, list[int], list[list[int]]],
|
|
|
+ eos_token_id: Union[int, List[int], List[List[int]]],
|
|
|
) -> Tensor:
|
|
|
"""get unfinished flag for generation step
|
|
|
|