utils.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. from typing import List
  16. __all__ = [
  17. "convert_to_dict_message",
  18. "fn_args_to_dict",
  19. ]
  20. def convert_to_dict_message(conversation: List[List[str]]):
  21. """Convert the list of chat messages to a role dictionary chat messages."""
  22. conversations = []
  23. for index, item in enumerate(conversation):
  24. assert (
  25. 1 <= len(item) <= 2
  26. ), "Each Rounds in conversation should have 1 or 2 elements."
  27. if isinstance(item[0], str):
  28. conversations.append({"role": "user", "content": item[0]})
  29. if len(item) == 2 and isinstance(item[1], str):
  30. conversations.append({"role": "assistant", "content": item[1]})
  31. else:
  32. # If there is only one element in item, it must be the last round.
  33. # If it is not the last round, it must be an error.
  34. if index != len(conversation) - 1:
  35. raise ValueError(f"Round {index} has error round")
  36. else:
  37. raise ValueError("Each round in list should be string")
  38. return conversations
  39. def fn_args_to_dict(func, *args, **kwargs):
  40. """
  41. Inspect function `func` and its arguments for running, and extract a
  42. dict mapping between argument names and keys.
  43. """
  44. (spec_args, spec_varargs, spec_varkw, spec_defaults, _, _, _) = (
  45. inspect.getfullargspec(func)
  46. )
  47. # add positional argument values
  48. init_dict = dict(zip(spec_args, args))
  49. # add default argument values
  50. kwargs_dict = (
  51. dict(zip(spec_args[-len(spec_defaults) :], spec_defaults))
  52. if spec_defaults
  53. else {}
  54. )
  55. for k in list(kwargs_dict.keys()):
  56. if k in init_dict:
  57. kwargs_dict.pop(k)
  58. kwargs_dict.update(kwargs)
  59. init_dict.update(kwargs_dict)
  60. return init_dict