configuration_utils.py

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import os
import warnings
from typing import Any, Dict, Optional, Union

from paddle.common_ops_import import convert_dtype

from ......utils import logging
from ..transformers.configuration_utils import PretrainedConfig
from ..utils import GENERATION_CONFIG_NAME, resolve_file_path

DEFAULT_MAX_NEW_TOKENS = 20


class GenerationConfig:
    r"""
    Configuration class that holds and validates the parameters controlling generation.

    Args:
        > Parameters that control the length of the output

        max_new_tokens (int, optional): The maximum number of new tokens to
            generate. Default to 20 (``DEFAULT_MAX_NEW_TOKENS``).
        min_new_tokens (int, optional): The minimum number of new tokens to
            generate. Default to 0.
        max_length (int, optional): The maximum length of the sequence to
            be generated. Default to 0.
        min_length (int, optional): The minimum length of the sequence to
            be generated. Default to 0.
        decode_strategy (str, optional): The decoding strategy in generation.
            Currently, there are three decoding strategies supported:
            "greedy_search", "sampling" and "beam_search". Default to
            "greedy_search".
        temperature (float, optional): The value used to modulate the next
            token probabilities in the "sampling" strategy. Default to 1.0,
            which means no effect.
        top_k (int, optional): The number of highest probability tokens to
            keep for top-k-filtering in the "sampling" strategy. Default to 50.
        top_p (float, optional): The cumulative probability for
            top-p-filtering in the "sampling" strategy. The value should
            satisfy :math:`0 <= top\_p < 1`. Default to 1.0, which means no
            effect.
        repetition_penalty (float, optional):
            The parameter for repetition penalty. 1.0 means no penalty. See `this paper
            <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. Defaults to 1.0.
        num_beams (int, optional): The number of beams in the "beam_search"
            strategy. Default to 1.
        num_beam_groups (int, optional):
            Number of groups to divide `num_beams` into in order to use DIVERSE
            BEAM SEARCH. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__
            for more details. Default to 1.
        length_penalty (float, optional): The exponential penalty applied to the
            sequence length in the "beam_search" strategy. The larger this
            param is, the more the model tends to generate shorter
            sequences. Default to 1.0.
        early_stopping (bool, optional): Whether to stop searching in the
            "beam_search" strategy when at least `num_beams` sentences are
            finished per batch or not. Default to False.
        bos_token_id (int, optional): The id of the `bos_token`. Default to
            None.
        eos_token_id (int, optional): The id of the `eos_token`. Default to
            None.
        pad_token_id (int, optional): The id of the `pad_token`. Default to
            None.
        decoder_start_token_id (int, optional): The start token id for
            encoder-decoder models. Default to None.
        forced_bos_token_id (int, optional): The id of the token to force as
            the first generated token. Usually used for multilingual models.
            Default to None.
        forced_eos_token_id (int, optional): The id of the token to force as
            the last generated token. Default to None.
        num_return_sequences (int, optional): The number of returned
            sequences for each sequence in the batch. Default to 1.
        diversity_rate (float, optional): If `num_beam_groups` is 1, this is the
            diversity rate for Diverse Siblings Search. See
            `this paper <https://arxiv.org/abs/1611.08562>`__ for more details.
            Otherwise, this is the diversity rate for DIVERSE BEAM SEARCH.
        use_cache (bool, optional): Whether to use the model cache to
            speed up decoding. Default to True.
        use_fast (bool, optional): Whether to use the fast entry of the model
            for FastGeneration. Default to False.
        use_fp16_decoding (bool, optional): Whether to use fp16 for decoding.
            Only works when the fast entry is available. Default to False.
        trunc_input (bool, optional): Whether to truncate the inputs from the
            output sequences. Default to True.
        model_kwargs (dict): It can be used to specify additional kwargs
            passed to the model.
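
    Example (an illustrative sketch; any keyword not documented above is simply
    stored as an extra attribute by ``__init__``):

    ```python
    >>> config = GenerationConfig(decode_strategy="sampling", top_k=8, temperature=0.7)
    >>> config.decode_strategy
    'sampling'
    >>> config.top_k
    8
    ```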
    """

    def _get_generation_mode(self):
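        """Infer the decoding strategy from the attributes that are already set.

        This is a best-effort inference used when ``decode_strategy`` is not given
        explicitly: a single beam with ``do_sample=True`` maps to "sampling", a single
        beam without sampling maps to "greedy_search", and anything else maps to
        "beam_search".
        """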
        if hasattr(self, "num_beams") and self.num_beams == 1:
            if hasattr(self, "do_sample") and self.do_sample is True:
                generation_mode = "sampling"
            else:
                generation_mode = "greedy_search"
        else:
            generation_mode = "beam_search"
        return generation_mode

    def __init__(self, **kwargs):
        # Parameters that control the length of the output
        self.max_new_tokens = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        if "min_new_token" in kwargs:
            logging.warning(
                "<min_new_token> field is deprecated. Please use <min_new_tokens> instead."
            )
            kwargs["min_new_tokens"] = kwargs.pop("min_new_token")
        self.min_new_tokens = kwargs.pop("min_new_tokens", 0)
        self.max_length = kwargs.pop("max_length", 0)
        self.min_length = kwargs.pop("min_length", 0)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.trunc_input = kwargs.pop("trunc_input", True)

        # Parameters for manipulation of the model output logits
        self.diversity_rate = kwargs.pop("diversity_rate", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", None)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.use_cache = kwargs.pop("use_cache", True)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Special tokens that can be used at generation time
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # Generation parameters exclusive to encoder-decoder models
        self.use_fast = kwargs.pop("use_fast", False)
        self.use_fp16_decoding = kwargs.pop("use_fp16_decoding", False)
        self.fast_ptq_sampling = kwargs.pop("fast_ptq_sampling", False)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)

        self._from_model_config = kwargs.pop("_from_model_config", False)

        # Additional attributes without default values
        if not self._from_model_config:
            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
            # model's default configuration file
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logging.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        # Parameters that control the generation strategy used
        if "decode_strategy" in kwargs:
            self.decode_strategy = kwargs.pop("decode_strategy")
        else:
            self.decode_strategy = self._get_generation_mode()

        # Validate the values of the attributes
        self.validate(is_init=True)

    def to_dict(self):
        return copy.deepcopy(self.__dict__)

    def __eq__(self, other):
        if not isinstance(other, GenerationConfig):
            return False
        self_dict = self.__dict__.copy()
        other_dict = other.__dict__.copy()
        # ignore metadata
        for metadata_field in ["_from_model_config", "paddlenlp_version"]:
            self_dict.pop(metadata_field, None)
            other_dict.pop(metadata_field, None)
        return self_dict == other_dict

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def validate(self, is_init=False):
        """
        Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
        of parameterization that can be detected as incorrect from the configuration instance alone.

        Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
        model, such as parameters related to the generation length.
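
        Example (illustrative; the error below comes from the first check in this method):

        ```python
        >>> GenerationConfig(early_stopping="always")
        Traceback (most recent call last):
            ...
        ValueError: `early_stopping` must be a boolean or 'never', but is always.
        ```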
        """
        # Validation of individual attributes
        if self.early_stopping not in {True, False, "never"}:
            raise ValueError(
                f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}."
            )

        # Validation of attribute relations:
        fix_location = ""
        if is_init:
            fix_location = (
                " This was detected when initializing the generation config instance, which means the corresponding "
                "file may hold incorrect parameterization and should be fixed."
            )

        # 1. detect sampling-only parameterization when not in sampling mode
        if self.decode_strategy == "greedy_search":
            greedy_wrong_parameter_msg = (
                "Using the greedy search strategy. However, `{flag_name}` is set to `{flag_value}` -- this flag is only "
                'used in sample-based generation modes. You should set `decode_strategy="sampling"` or unset `{flag_name}`.'
                + fix_location
            )
            if self.temperature != 1.0:
                warnings.warn(
                    greedy_wrong_parameter_msg.format(
                        flag_name="temperature", flag_value=self.temperature
                    ),
                    UserWarning,
                )
            if self.top_p != 1.0:
                warnings.warn(
                    greedy_wrong_parameter_msg.format(
                        flag_name="top_p", flag_value=self.top_p
                    ),
                    UserWarning,
                )

        # 2. detect beam-only parameterization when not in beam mode
        if self.decode_strategy != "beam_search":
            single_beam_wrong_parameter_msg = (
                "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
                "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
                + fix_location
            )
            if self.early_stopping is not False:
                warnings.warn(
                    single_beam_wrong_parameter_msg.format(
                        flag_name="early_stopping", flag_value=self.early_stopping
                    ),
                    UserWarning,
                )
            if self.num_beam_groups != 1:
                warnings.warn(
                    single_beam_wrong_parameter_msg.format(
                        flag_name="num_beam_groups", flag_value=self.num_beam_groups
                    ),
                    UserWarning,
                )
            if self.length_penalty != 1.0:
                warnings.warn(
                    single_beam_wrong_parameter_msg.format(
                        flag_name="length_penalty", flag_value=self.length_penalty
                    ),
                    UserWarning,
                )

        # 4. check `num_return_sequences`
        if self.num_return_sequences != 1:
            if self.decode_strategy == "greedy_search":
                raise ValueError(
                    "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
                    f"(got {self.num_return_sequences})."
                )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        from_hf_hub: bool = False,
        from_aistudio: bool = False,
        config_file_name: Optional[Union[str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        **kwargs,
    ) -> "GenerationConfig":
        r"""
        Instantiate a [`GenerationConfig`] from a generation configuration file.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:

                - a string, the *model id* of a pretrained model configuration hosted inside a model repo on the
                  paddlenlp bos server. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
                  namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                - a path to a *directory* containing a configuration file saved using the
                  [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
                - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
            from_hf_hub (bool, *optional*):
                Load the config from the Hugging Face Hub: https://huggingface.co/models
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force (re-)downloading the configuration files and overriding the cached versions if
                they exist.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.
                If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            kwargs (`Dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from this pretrained model.

        Examples:

        ```python
        >>> from paddlenlp.transformers import GenerationConfig

        >>> generation_config = GenerationConfig.from_pretrained("gpt2")

        >>> # E.g. config was saved using *save_pretrained('./test/saved_model/')*
        >>> generation_config.save_pretrained("./test/saved_model/")
        >>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/")

        >>> # You can also specify configuration names for your generation configuration file
        >>> generation_config.save_pretrained("./test/saved_model/", config_file_name="my_configuration.json")
        >>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/", config_file_name="my_configuration.json")

        >>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation
        >>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored
        >>> generation_config, unused_kwargs = GenerationConfig.from_pretrained(
        ...     "gpt2", top_k=1, foo=False, do_sample=True, return_unused_kwargs=True
        ... )
        >>> generation_config.top_k
        1
        >>> unused_kwargs
        {'foo': False}
        ```"""
        config_file_name = (
            config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
        )

        subfolder = kwargs.pop("subfolder", "")
        if subfolder is None:
            subfolder = ""

        # NOTE: adapted to the resolve_file_path interface
        resolved_config_file = resolve_file_path(
            pretrained_model_name_or_path,
            [config_file_name],
            subfolder,
            cache_dir=cache_dir,
            force_download=force_download,
            from_aistudio=from_aistudio,
            from_hf_hub=from_hf_hub,
        )
        assert (
            resolved_config_file is not None
        ), f"please make sure that {config_file_name} exists under {pretrained_model_name_or_path}"

        try:
            logging.info(f"Loading configuration file {resolved_config_file}")
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(
                f"Config file <'{resolved_config_file}'> is not a valid JSON file."
            )

        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def dict_paddle_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """
        Checks whether the passed dictionary and its nested dicts have a *dtype* key and, if it is not None,
        converts the paddle dtype to a string of just the type. For example, `paddle.float32` gets converted into the
        *"float32"* string, which can then be stored in the json format.
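
        Example (a minimal sketch; it assumes ``paddle`` is importable and relies on the
        ``convert_dtype`` helper imported at the top of this module):

        ```python
        >>> import paddle
        >>> d = {"dtype": paddle.float32, "nested": {"dtype": "float16"}}
        >>> GenerationConfig().dict_paddle_dtype_to_str(d)
        >>> d["dtype"]
        'float32'
        >>> d["nested"]["dtype"]  # already a string, left untouched
        'float16'
        ```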
        """
        if d.get("dtype", None) is not None and not isinstance(d["dtype"], str):
            d["dtype"] = convert_dtype(d["dtype"])
        for value in d.values():
            if isinstance(value, dict):
                self.dict_paddle_dtype_to_str(value)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from those parameters.
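
        Example (an illustrative sketch; `kwargs` override the values in `config_dict`,
        and `return_unused_kwargs` is handled as shown in the body below):

        ```python
        >>> config = GenerationConfig.from_dict({"top_k": 4, "num_beams": 2})
        >>> config.top_k
        4
        >>> config.decode_strategy
        'beam_search'
        ```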
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**{**config_dict, **kwargs})
        unused_kwargs = config.update(**kwargs)

        # logging.info(f"Generate config {config}")
        if return_unused_kwargs:
            return config, unused_kwargs
        else:
            return config

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from the config which correspond to the default config attributes, for better
        readability, and serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
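
        Example (illustrative; only values that differ from a default ``GenerationConfig()`` survive):

        ```python
        >>> GenerationConfig(top_k=4).to_diff_dict()
        {'top_k': 4}
        ```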
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = GenerationConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or key == "transformers_version"
                or value != default_config_dict[key]
            ):
                serializable_config_dict[key] = value

        self.dict_paddle_dtype_to_str(serializable_config_dict)
        return serializable_config_dict

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
                is serialized to JSON string.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(
        self, json_file_path: Union[str, os.PathLike], use_diff: bool = True
    ):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
                is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))

    @classmethod
    def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
        [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].

        Args:
            model_config (`PretrainedConfig`):
                The model config that will be used to instantiate the generation config.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from those parameters.
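
        Example (a minimal sketch; it assumes, as with the paddlenlp base config, that
        `PretrainedConfig` stores unknown keyword arguments such as `num_beams` as plain
        attributes and includes them in `to_dict()`):

        ```python
        >>> model_config = PretrainedConfig(num_beams=4)
        >>> generation_config = GenerationConfig.from_model_config(model_config)
        >>> generation_config.num_beams
        4
        ```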
        """
        config_dict = model_config.to_dict()
        config_dict.pop("_from_model_config", None)
        config = cls.from_dict(
            config_dict, return_unused_kwargs=False, _from_model_config=True
        )

        # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
        # generation config.
        for decoder_name in ("decoder", "generator", "text_config"):
            if decoder_name in config_dict:
                default_generation_config = GenerationConfig()
                decoder_config = config_dict[decoder_name]
                for attr in config.to_dict().keys():
                    if attr in decoder_config and getattr(config, attr) == getattr(
                        default_generation_config, attr
                    ):
                        setattr(config, attr, decoder_config[attr])

        return config

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
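
        Example (an illustrative sketch; `foo` is a made-up key that is not a config attribute):

        ```python
        >>> config = GenerationConfig()
        >>> unused = config.update(top_k=1, foo="bar")
        >>> config.top_k
        1
        >>> unused
        {'foo': 'bar'}
        ```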
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # remove all the attributes that were updated, without modifying the input dict
        unused_kwargs = {
            key: value for key, value in kwargs.items() if key not in to_remove
        }
        return unused_kwargs