predictor.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import time
  3. from loguru import logger
  4. from .base_predictor import (
  5. DEFAULT_MAX_NEW_TOKENS,
  6. DEFAULT_NO_REPEAT_NGRAM_SIZE,
  7. DEFAULT_PRESENCE_PENALTY,
  8. DEFAULT_REPETITION_PENALTY,
  9. DEFAULT_TEMPERATURE,
  10. DEFAULT_TOP_K,
  11. DEFAULT_TOP_P,
  12. BasePredictor,
  13. )
  14. from .sglang_client_predictor import SglangClientPredictor
  15. hf_loaded = False
  16. try:
  17. from .hf_predictor import HuggingfacePredictor
  18. hf_loaded = True
  19. except ImportError as e:
  20. logger.warning("hf is not installed. If you are not using transformers, you can ignore this warning.")
  21. engine_loaded = False
  22. try:
  23. from sglang.srt.server_args import ServerArgs
  24. from .sglang_engine_predictor import SglangEnginePredictor
  25. engine_loaded = True
  26. except Exception as e:
  27. logger.warning("sglang is not installed. If you are not using sglang, you can ignore this warning.")
  28. def get_predictor(
  29. backend: str = "sglang-client",
  30. model_path: str | None = None,
  31. server_url: str | None = None,
  32. temperature: float = DEFAULT_TEMPERATURE,
  33. top_p: float = DEFAULT_TOP_P,
  34. top_k: int = DEFAULT_TOP_K,
  35. repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
  36. presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
  37. no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
  38. max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
  39. http_timeout: int = 600,
  40. **kwargs,
  41. ) -> BasePredictor:
  42. start_time = time.time()
  43. if backend == "transformers":
  44. if not model_path:
  45. raise ValueError("model_path must be provided for transformers backend.")
  46. if not hf_loaded:
  47. raise ImportError(
  48. "transformers is not installed, so huggingface backend cannot be used. "
  49. "If you need to use huggingface backend, please install transformers first."
  50. )
  51. predictor = HuggingfacePredictor(
  52. model_path=model_path,
  53. temperature=temperature,
  54. top_p=top_p,
  55. top_k=top_k,
  56. repetition_penalty=repetition_penalty,
  57. presence_penalty=presence_penalty,
  58. no_repeat_ngram_size=no_repeat_ngram_size,
  59. max_new_tokens=max_new_tokens,
  60. **kwargs,
  61. )
  62. elif backend == "sglang-engine":
  63. if not model_path:
  64. raise ValueError("model_path must be provided for sglang-engine backend.")
  65. if not engine_loaded:
  66. raise ImportError(
  67. "sglang is not installed, so sglang-engine backend cannot be used. "
  68. "If you need to use sglang-engine backend for inference, "
  69. "please install sglang[all]==0.4.8 or a newer version."
  70. )
  71. predictor = SglangEnginePredictor(
  72. server_args=ServerArgs(model_path, **kwargs),
  73. temperature=temperature,
  74. top_p=top_p,
  75. top_k=top_k,
  76. repetition_penalty=repetition_penalty,
  77. presence_penalty=presence_penalty,
  78. no_repeat_ngram_size=no_repeat_ngram_size,
  79. max_new_tokens=max_new_tokens,
  80. )
  81. elif backend == "sglang-client":
  82. if not server_url:
  83. raise ValueError("server_url must be provided for sglang-client backend.")
  84. predictor = SglangClientPredictor(
  85. server_url=server_url,
  86. temperature=temperature,
  87. top_p=top_p,
  88. top_k=top_k,
  89. repetition_penalty=repetition_penalty,
  90. presence_penalty=presence_penalty,
  91. no_repeat_ngram_size=no_repeat_ngram_size,
  92. max_new_tokens=max_new_tokens,
  93. http_timeout=http_timeout,
  94. )
  95. else:
  96. raise ValueError(f"Unsupported backend: {backend}. Supports: transformers, sglang-engine, sglang-client.")
  97. elapsed = round(time.time() - start_time, 2)
  98. logger.info(f"get_predictor cost: {elapsed}s")
  99. return predictor