predictor.py

# Copyright (c) Opendatalab. All rights reserved.
import time

from loguru import logger

from .base_predictor import (
    DEFAULT_MAX_NEW_TOKENS,
    DEFAULT_NO_REPEAT_NGRAM_SIZE,
    DEFAULT_PRESENCE_PENALTY,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_K,
    DEFAULT_TOP_P,
    BasePredictor,
)
from .sglang_client_predictor import SglangClientPredictor
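
# The transformers-based backend is optional: hf_loaded records whether the
# import below succeeded, and get_predictor() checks it before using it.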
hf_loaded = False
try:
    from .hf_predictor import HuggingfacePredictor
    hf_loaded = True
except ImportError as e:
    logger.warning("hf is not installed. If you are not using transformers, you can ignore this warning.")
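
# Likewise, the sglang engine backend is optional: engine_loaded records
# whether a local sglang installation is available.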
engine_loaded = False
try:
    from sglang.srt.server_args import ServerArgs
    from .sglang_engine_predictor import SglangEnginePredictor
    engine_loaded = True
except Exception as e:
    logger.exception(e)
    logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")


def get_predictor(
    backend: str = "sglang-client",
    model_path: str | None = None,
    server_url: str | None = None,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    top_k: int = DEFAULT_TOP_K,
    repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
    presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
    no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    http_timeout: int = 600,
    **kwargs,
) -> BasePredictor:
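    """Instantiate a predictor for the requested backend.

    "transformers" and "sglang-engine" require a local model_path (the former
    also needs the transformers package, the latter an sglang installation);
    "sglang-client" only needs the server_url of a running sglang server.
    Sampling parameters default to the values imported from base_predictor.
    """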
    start_time = time.time()

    if backend == "transformers":
        if not model_path:
            raise ValueError("model_path must be provided for transformers backend.")
        if not hf_loaded:
            raise ImportError(
                "transformers is not installed, so huggingface backend cannot be used. "
                "If you need to use huggingface backend, please install transformers first."
            )
        predictor = HuggingfacePredictor(
            model_path=model_path,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
            **kwargs,
        )
    elif backend == "sglang-engine":
        if not model_path:
            raise ValueError("model_path must be provided for sglang-engine backend.")
        if not engine_loaded:
            raise ImportError(
                "sglang is not installed, so sglang-engine backend cannot be used. "
                "If you need to use sglang-engine backend for inference, "
                "please install sglang[all]==0.4.7 or a newer version."
            )
        predictor = SglangEnginePredictor(
            server_args=ServerArgs(model_path, **kwargs),
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
        )
    elif backend == "sglang-client":
        if not server_url:
            raise ValueError("server_url must be provided for sglang-client backend.")
        predictor = SglangClientPredictor(
            server_url=server_url,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            presence_penalty=presence_penalty,
            no_repeat_ngram_size=no_repeat_ngram_size,
            max_new_tokens=max_new_tokens,
            http_timeout=http_timeout,
        )
    else:
        raise ValueError(f"Unsupported backend: {backend}. Supports: transformers, sglang-engine, sglang-client.")

    elapsed = round(time.time() - start_time, 2)
    logger.info(f"get_predictor cost: {elapsed}s")
    return predictor
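

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it assumes an
    # sglang server is already reachable at the placeholder URL below, and
    # that this file is run as a module inside its package (it uses relative
    # imports). It only exercises the get_predictor() factory defined above;
    # inference itself goes through whatever interface BasePredictor exposes.
    predictor = get_predictor(
        backend="sglang-client",
        server_url="http://127.0.0.1:30000",  # placeholder address, adjust to your deployment
        temperature=DEFAULT_TEMPERATURE,
        max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
    )
    logger.info(f"Loaded {type(predictor).__name__}")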