# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import warnings
from typing import Any, Dict, Optional, Union

from paddle.common_ops_import import convert_dtype

from ......utils import logging
from ..transformers.configuration_utils import PretrainedConfig
from ..utils import GENERATION_CONFIG_NAME, resolve_file_path

DEFAULT_MAX_NEW_TOKENS = 20


class GenerationConfig:
  24. r"""
  25. Arg:
  26. > Parameters that control the length of the output
  27. max_length (int, optional): The maximum length of the sequence to
  28. be generated. Default to 20.
  29. min_length (int, optional): The minimum length of the sequence to
  30. be generated. Default to 0.
  31. decode_strategy (str, optional): The decoding strategy in generation.
  32. Currently, there are three decoding strategies supported:
  33. "greedy_search", "sampling" and "beam_search". Default to
  34. "greedy_search".
  35. temperature (float, optional): The value used to module the next
  36. token probabilities in the "sampling" strategy. Default to 1.0,
  37. which means no effect.
  38. top_k (int, optional): The number of highest probability tokens to
  39. keep for top-k-filtering in the "sampling" strategy. Default to
  40. 0, which means no effect.
  41. top_p (float, optional): The cumulative probability for
  42. top-p-filtering in the "sampling" strategy. The value should
  43. satisfy :math:`0 <= top\_p < 1`. Default to 1.0, which means no
  44. effect.
  45. repetition_penalty (float, optional):
  46. The parameter for repetition penalty. 1.0 means no penalty. See `this paper
  47. <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details. Defaults to 1.0.
  48. num_beams (int, optional): The number of beams in the "beam_search"
  49. strategy. Default to 1.
  50. num_beam_groups (int, optional):
  51. Number of groups to divide `num_beams` into in order to use DIVERSE
  52. BEAM SEARCH. See `this paper <https://arxiv.org/pdf/1610.02424.pdf>`__
  53. for more details. Default to 1.
  54. length_penalty (float, optional): The exponential penalty to the
  55. sequence length in the "beam_search" strategy. The larger this
  56. param is, the more that the model would generate shorter
  57. sequences. Default to 0.0, which means no penalty.
  58. early_stopping (bool, optional): Whether to stop searching in the
  59. "beam_search" strategy when at least `num_beams` sentences are
  60. finished per batch or not. Default to False.
  61. bos_token_id (int, optional): The id of the `bos_token`. Default to
  62. None.
  63. eos_token_id (int, optional): The id of the `eos_token`. Default to
  64. None.
  65. pad_token_id (int, optional): The id of the `pad_token`. Default to
  66. None.
  67. decoder_start_token_id (int, optional): The start token id for
  68. encoder-decoder models. Default to None.
  69. forced_bos_token_id (int, optional): The id of the token to force as
  70. the first generated token. Usually use for multilingual models.
  71. Default to None.
  72. forced_eos_token_id (int, optional): The id of the token to force as
  73. the last generated token. Default to None.
  74. num_return_sequences (int, optional): The number of returned
  75. sequences for each sequence in the batch. Default to 1.
  76. diversity_rate (float, optional): If num_beam_groups is 1, this is the
  77. diversity_rate for Diverse Siblings Search. See
  78. `this paper https://arxiv.org/abs/1611.08562`__ for more details.
  79. If not, this is the diversity_rate for DIVERSE BEAM SEARCH.
  80. use_cache: (bool, optional): Whether to use the model cache to
  81. speed up decoding. Default to True.
  82. use_fast: (bool, optional): Whether to use fast entry of model
  83. for FastGeneration. Default to False.
  84. use_fp16_decoding: (bool, optional): Whether to use fp16 for decoding.
  85. Only works when fast entry is available. Default to False.
  86. trunc_input: (bool, optional): Whether to truncate the inputs from
  87. output sequences . Default to True.
  88. model_kwargs (dict): It can be used to specify additional kwargs
  89. passed to the model.
  90. """

    def _get_generation_mode(self):
        if hasattr(self, "num_beams") and self.num_beams == 1:
            if hasattr(self, "do_sample") and self.do_sample is True:
                generation_mode = "sampling"
            else:
                generation_mode = "greedy_search"
        else:
            generation_mode = "beam_search"
        return generation_mode
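
    # Illustrative note (added comment, not part of the original source): the
    # decode strategy is inferred from whichever attributes have already been
    # set on the instance, roughly:
    #   num_beams == 1 and do_sample is True        -> "sampling"
    #   num_beams == 1 and do_sample unset/False    -> "greedy_search"
    #   num_beams > 1 (or num_beams not yet set)    -> "beam_search"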

    def __init__(self, **kwargs):
        # Parameters that control the length of the output
        self.max_new_tokens = kwargs.get("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
        if "min_new_token" in kwargs:
            logging.warning(
                "<min_new_token> field is deprecated. Please use <min_new_tokens> instead."
            )
            kwargs["min_new_tokens"] = kwargs.pop("min_new_token")
        self.min_new_tokens = kwargs.pop("min_new_tokens", 0)
        self.max_length = kwargs.pop("max_length", 0)
        self.min_length = kwargs.pop("min_length", 0)
        self.early_stopping = kwargs.pop("early_stopping", False)
        self.trunc_input = kwargs.pop("trunc_input", True)

        # Parameters for manipulation of the model output logits
        self.diversity_rate = kwargs.pop("diversity_rate", 0.0)
        self.temperature = kwargs.pop("temperature", 1.0)
        self.top_k = kwargs.pop("top_k", 50)
        self.top_p = kwargs.pop("top_p", 1.0)
        self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
        self.length_penalty = kwargs.pop("length_penalty", 1.0)
        self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", None)
        self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
        self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
        self.num_beams = kwargs.pop("num_beams", 1)
        self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
        self.use_cache = kwargs.pop("use_cache", True)

        # Parameters that define the output variables of `generate`
        self.num_return_sequences = kwargs.pop("num_return_sequences", 1)

        # Special tokens that can be used at generation time
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # Generation parameters exclusive to encoder-decoder models
        self.use_fast = kwargs.pop("use_fast", False)
        self.use_fp16_decoding = kwargs.pop("use_fp16_decoding", False)
        self.fast_ptq_sampling = kwargs.pop("fast_ptq_sampling", False)
        self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
        self._from_model_config = kwargs.pop("_from_model_config", False)

        # Additional attributes without default values
        if not self._from_model_config:
            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
            # model's default configuration file
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logging.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        # Parameters that control the generation strategy used
        if "decode_strategy" in kwargs:
            self.decode_strategy = kwargs.pop("decode_strategy")
        else:
            self.decode_strategy = self._get_generation_mode()

        # Validate the values of the attributes
        self.validate(is_init=True)
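
    # Illustrative usage sketch (added comment, not part of the original
    # source):
    #
    #   >>> cfg = GenerationConfig(do_sample=True, top_k=5, max_new_tokens=64)
    #   >>> cfg.decode_strategy
    #   'sampling'
    #   >>> cfg.top_k
    #   5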

    def __eq__(self, other):
        if not isinstance(other, GenerationConfig):
            return False
        self_dict = self.__dict__.copy()
        other_dict = other.__dict__.copy()
        # ignore metadata
        for metadata_field in ["_from_model_config", "paddlenlp_version"]:
            self_dict.pop(metadata_field, None)
            other_dict.pop(metadata_field, None)
        return self_dict == other_dict

    def __repr__(self):
        return f"{self.__class__.__name__} {self.to_json_string()}"

    def validate(self, is_init=False):
        """
        Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
        of parameterization that can be detected as incorrect from the configuration instance alone.

        Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
        model, such as parameters related to the generation length.
        """
        # Validation of individual attributes
        if self.early_stopping not in {True, False, "never"}:
            raise ValueError(
                f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}."
            )

        # Validation of attribute relations:
        fix_location = ""
        if is_init:
            fix_location = (
                " This was detected when initializing the generation config instance, which means the corresponding "
                "file may hold incorrect parameterization and should be fixed."
            )

        # 1. detect sampling-only parameterization when not in sampling mode
        if self.decode_strategy == "greedy_search":
            greedy_wrong_parameter_msg = (
                '`decode_strategy` is set to "greedy_search". However, `{flag_name}` is set to `{flag_value}` -- '
                "this flag is only used in sample-based generation modes. You should set "
                '`decode_strategy="sampling"` or unset `{flag_name}`.' + fix_location
            )
            if self.temperature != 1.0:
                warnings.warn(
                    greedy_wrong_parameter_msg.format(
                        flag_name="temperature", flag_value=self.temperature
                    ),
                    UserWarning,
                )
            if self.top_p != 1.0:
                warnings.warn(
                    greedy_wrong_parameter_msg.format(
                        flag_name="top_p", flag_value=self.top_p
                    ),
                    UserWarning,
                )

        # 2. detect beam-only parameterization when not in beam mode
        if self.decode_strategy != "beam_search":
            single_beam_wrong_parameter_msg = (
                "`num_beams` is set to 1. However, `{flag_name}` is set to `{flag_value}` -- this flag is only used "
                "in beam-based generation modes. You should set `num_beams>1` or unset `{flag_name}`."
                + fix_location
            )
            if self.early_stopping is not False:
                warnings.warn(
                    single_beam_wrong_parameter_msg.format(
                        flag_name="early_stopping", flag_value=self.early_stopping
                    ),
                    UserWarning,
                )
            if self.num_beam_groups != 1:
                warnings.warn(
                    single_beam_wrong_parameter_msg.format(
                        flag_name="num_beam_groups", flag_value=self.num_beam_groups
                    ),
                    UserWarning,
                )
            if self.length_penalty != 1.0:
                warnings.warn(
                    single_beam_wrong_parameter_msg.format(
                        flag_name="length_penalty", flag_value=self.length_penalty
                    ),
                    UserWarning,
                )

        # 3. check `num_return_sequences`
        if self.num_return_sequences != 1:
            if self.decode_strategy == "greedy_search":
                raise ValueError(
                    "Greedy methods without beam search do not support `num_return_sequences` different than 1 "
                    f"(got {self.num_return_sequences})."
                )
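
    # Illustrative sketch of what `validate` accepts and rejects (added
    # comment, not part of the original source):
    #
    #   >>> GenerationConfig(decode_strategy="greedy_search", top_p=0.7)
    #   # emits a UserWarning: `top_p` only affects sample-based generation
    #   >>> GenerationConfig(decode_strategy="greedy_search", num_return_sequences=2)
    #   # raises ValueError: greedy search only supports num_return_sequences=1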

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        from_hf_hub: bool = False,
        from_aistudio: bool = False,
        config_file_name: Optional[Union[str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        force_download: bool = False,
        **kwargs,
    ) -> "GenerationConfig":
        r"""
        Instantiate a [`GenerationConfig`] from a generation configuration file.

        Args:
            pretrained_model_name_or_path (`str` or `os.PathLike`):
                This can be either:
                - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
                  the paddlenlp bos server. Valid model ids can be located at the root-level, like `bert-base-uncased`,
                  or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
                - a path to a *directory* containing a configuration file saved using the
                  [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
                - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
            from_hf_hub (bool, *optional*):
                Whether to load the config from the Hugging Face Hub: https://huggingface.co/models
            cache_dir (`str` or `os.PathLike`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force to (re-)download the configuration files and override the cached versions if
                they exist.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                If `False`, then this function returns just the final configuration object.
                If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                part of `kwargs` which has not been used to update `config` and is otherwise ignored.
            kwargs (`Dict[str, Any]`, *optional*):
                The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
                by the `return_unused_kwargs` keyword parameter.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from this pretrained model.

        Examples:
        ```python
        >>> from paddlenlp.transformers import GenerationConfig

        >>> generation_config = GenerationConfig.from_pretrained("gpt2")

        >>> # E.g. config was saved using *save_pretrained('./test/saved_model/')*
        >>> generation_config.save_pretrained("./test/saved_model/")
        >>> generation_config = GenerationConfig.from_pretrained("./test/saved_model/")

        >>> # You can also specify configuration names to your generation configuration file
        >>> generation_config.save_pretrained("./test/saved_model/", config_file_name="my_configuration.json")
        >>> generation_config = GenerationConfig.from_pretrained(
        ...     "./test/saved_model/", config_file_name="my_configuration.json"
        ... )

        >>> # If you'd like to try a minor variation to an existing configuration, you can also pass generation
        >>> # arguments to `.from_pretrained()`. Be mindful that typos and unused arguments will be ignored
        >>> generation_config, unused_kwargs = GenerationConfig.from_pretrained(
        ...     "gpt2", top_k=1, foo=False, do_sample=True, return_unused_kwargs=True
        ... )
        >>> generation_config.top_k
        1

        >>> unused_kwargs
        {'foo': False}
        ```"""
        config_file_name = (
            config_file_name if config_file_name is not None else GENERATION_CONFIG_NAME
        )

        subfolder = kwargs.pop("subfolder", "")
        if subfolder is None:
            subfolder = ""

        # NOTE: adapted to locate the config file via resolve_file_path
        resolved_config_file = resolve_file_path(
            pretrained_model_name_or_path,
            [config_file_name],
            subfolder,
            cache_dir=cache_dir,
            force_download=force_download,
            from_aistudio=from_aistudio,
            from_hf_hub=from_hf_hub,
        )
        assert (
            resolved_config_file is not None
        ), f"please make sure that {config_file_name} exists under {pretrained_model_name_or_path}"
        try:
            logging.info(f"Loading configuration file {resolved_config_file}")
            # Load config dict
            config_dict = cls._dict_from_json_file(resolved_config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(
                f"Config file '{resolved_config_file}' is not a valid JSON file."
            )

        return cls.from_dict(config_dict, **kwargs)

    @classmethod
    def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
        with open(json_file, "r", encoding="utf-8") as reader:
            text = reader.read()
        return json.loads(text)

    def dict_paddle_dtype_to_str(self, d: Dict[str, Any]) -> None:
        """
        Checks whether the passed dictionary and its nested dicts have a *dtype* key and, if it's not None,
        converts paddle.dtype to a string of just the type. For example, `paddle.float32` gets converted into the
        *"float32"* string, which can then be stored in the json format.
        """
        if d.get("dtype", None) is not None and not isinstance(d["dtype"], str):
            d["dtype"] = convert_dtype(d["dtype"])
        for value in d.values():
            if isinstance(value, dict):
                self.dict_paddle_dtype_to_str(value)
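
    # Illustrative sketch of the dtype normalization above (added comment, not
    # part of the original source):
    #
    #   >>> import paddle
    #   >>> d = {"dtype": paddle.float32, "nested": {"dtype": paddle.int64}}
    #   >>> GenerationConfig().dict_paddle_dtype_to_str(d)
    #   >>> d
    #   {'dtype': 'float32', 'nested': {'dtype': 'int64'}}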

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.

        Args:
            config_dict (`Dict[str, Any]`):
                Dictionary that will be used to instantiate the configuration object.
            kwargs (`Dict[str, Any]`):
                Additional parameters from which to initialize the configuration object.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from those parameters.
        """
        return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)

        config = cls(**{**config_dict, **kwargs})
        unused_kwargs = config.update(**kwargs)

        # logging.info(f"Generate config {config}")
        if return_unused_kwargs:
            return config, unused_kwargs
        else:
            return config
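
    # Illustrative usage sketch for `from_dict` (added comment, not part of
    # the original source):
    #
    #   >>> cfg = GenerationConfig.from_dict({"top_k": 8, "num_beams": 4})
    #   >>> cfg.top_k, cfg.num_beams, cfg.decode_strategy
    #   (8, 4, 'beam_search')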

    def to_diff_dict(self) -> Dict[str, Any]:
        """
        Removes all attributes from the config which correspond to the default config attributes for better
        readability, and serializes to a Python dictionary.

        Returns:
            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        config_dict = self.to_dict()

        # get the default config dict
        default_config_dict = GenerationConfig().to_dict()

        serializable_config_dict = {}

        # only serialize values that differ from the default config
        for key, value in config_dict.items():
            if (
                key not in default_config_dict
                or key == "transformers_version"
                or value != default_config_dict[key]
            ):
                serializable_config_dict[key] = value

        self.dict_paddle_dtype_to_str(serializable_config_dict)
        return serializable_config_dict

    def to_json_string(self, use_diff: bool = True) -> str:
        """
        Serializes this instance to a JSON string.

        Args:
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
                is serialized to JSON string.

        Returns:
            `str`: String containing all the attributes that make up this configuration instance in JSON format.
        """
        if use_diff is True:
            config_dict = self.to_diff_dict()
        else:
            config_dict = self.to_dict()
        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

    def to_json_file(
        self, json_file_path: Union[str, os.PathLike], use_diff: bool = True
    ):
        """
        Save this instance to a JSON file.

        Args:
            json_file_path (`str` or `os.PathLike`):
                Path to the JSON file in which this configuration instance's parameters will be saved.
            use_diff (`bool`, *optional*, defaults to `True`):
                If set to `True`, only the difference between the config instance and the default `GenerationConfig()`
                is serialized to JSON file.
        """
        with open(json_file_path, "w", encoding="utf-8") as writer:
            writer.write(self.to_json_string(use_diff=use_diff))
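
    # Illustrative save sketch (added comment, not part of the original
    # source; the file paths are hypothetical):
    #
    #   >>> cfg = GenerationConfig(top_k=8)
    #   >>> cfg.to_json_file("/tmp/generation_config.json")  # only non-default fields (use_diff=True)
    #   >>> cfg.to_json_file("/tmp/full_config.json", use_diff=False)  # every attribute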

    @classmethod
    def from_model_config(cls, model_config: PretrainedConfig) -> "GenerationConfig":
        """
        Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
        [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].

        Args:
            model_config (`PretrainedConfig`):
                The model config that will be used to instantiate the generation config.

        Returns:
            [`GenerationConfig`]: The configuration object instantiated from those parameters.
        """
        config_dict = model_config.to_dict()
        config_dict.pop("_from_model_config", None)
        config = cls.from_dict(
            config_dict, return_unused_kwargs=False, _from_model_config=True
        )

        # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
        # generation config.
        for decoder_name in ("decoder", "generator", "text_config"):
            if decoder_name in config_dict:
                default_generation_config = GenerationConfig()
                decoder_config = config_dict[decoder_name]
                for attr in config.to_dict().keys():
                    if attr in decoder_config and getattr(config, attr) == getattr(
                        default_generation_config, attr
                    ):
                        setattr(config, attr, decoder_config[attr])

        return config
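
    # Illustrative note on `from_model_config` (added comment, not part of the
    # original source; the values below are hypothetical): generation fields
    # stored on a legacy model config are copied over, and fields nested under
    # a decoder sub-config only win when the top-level value is still at its
    # default. Roughly:
    #
    #   >>> model_config = PretrainedConfig(eos_token_id=2, decoder={"num_beams": 4})
    #   >>> gen_cfg = GenerationConfig.from_model_config(model_config)
    #   >>> gen_cfg.eos_token_id, gen_cfg.num_beams
    #   (2, 4)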

    def update(self, **kwargs):
        """
        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
        returning all the unused kwargs.

        Args:
            kwargs (`Dict[str, Any]`):
                Dictionary of attributes to tentatively update this class.

        Returns:
            `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
        """
        to_remove = []
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
                to_remove.append(key)

        # remove all the attributes that were updated, without modifying the input dict
        unused_kwargs = {
            key: value for key, value in kwargs.items() if key not in to_remove
        }
        return unused_kwargs
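
    # Illustrative usage sketch for `update` (added comment, not part of the
    # original source; `not_a_real_field` is a hypothetical key):
    #
    #   >>> cfg = GenerationConfig()
    #   >>> cfg.update(top_k=3, not_a_real_field=1)
    #   {'not_a_real_field': 1}
    #   >>> cfg.top_k
    #   3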