base_predictor.py

import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterable, Iterable, List, Optional, Union

DEFAULT_SYSTEM_PROMPT = (
    "A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers."
)
DEFAULT_USER_PROMPT = "Document Parsing:"

# Default sampling parameters; each can be overridden per instance or per call.
DEFAULT_TEMPERATURE = 0.0
DEFAULT_TOP_P = 0.8
DEFAULT_TOP_K = 20
DEFAULT_REPETITION_PENALTY = 1.0
DEFAULT_PRESENCE_PENALTY = 0.0
DEFAULT_NO_REPEAT_NGRAM_SIZE = 100
DEFAULT_MAX_NEW_TOKENS = 16384


class BasePredictor(ABC):
    """Common interface for image-to-text predictors.

    Concrete subclasses implement the synchronous ``predict``,
    ``batch_predict``, and ``stream_predict`` methods; the ``aio_*``
    coroutines defined here wrap them for use from asyncio code.
    """

    system_prompt = DEFAULT_SYSTEM_PROMPT

    def __init__(
        self,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        top_k: int = DEFAULT_TOP_K,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    ) -> None:
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        self.presence_penalty = presence_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size
        self.max_new_tokens = max_new_tokens

    @abstractmethod
    def predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Run inference on a single image and return the full answer.

        ``image`` is either a string reference (e.g. a path, URL, or encoded
        payload, depending on the implementation) or raw bytes. Sampling
        arguments left as ``None`` fall back to the instance defaults.
        """

    @abstractmethod
    def batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Run inference on a batch of images.

        ``prompts`` is either a single prompt shared by every image or a
        list aligned one-to-one with ``images``.
        """

    @abstractmethod
    def stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> Iterable[str]:
        """Run inference on a single image, yielding the answer in chunks."""

    async def aio_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> str:
        """Async wrapper: run the blocking ``predict`` in a worker thread."""
        return await asyncio.to_thread(
            self.predict,
            image,
            prompt,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

    async def aio_batch_predict(
        self,
        images: List[str] | List[bytes],
        prompts: Union[List[str], str] = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> List[str]:
        """Async wrapper: run the blocking ``batch_predict`` in a worker thread."""
        return await asyncio.to_thread(
            self.batch_predict,
            images,
            prompts,
            temperature,
            top_p,
            top_k,
            repetition_penalty,
            presence_penalty,
            no_repeat_ngram_size,
            max_new_tokens,
        )

    async def aio_stream_predict(
        self,
        image: str | bytes,
        prompt: str = "",
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        presence_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
    ) -> AsyncIterable[str]:
        """Async wrapper around the blocking ``stream_predict`` generator.

        A worker thread consumes ``stream_predict`` and forwards each chunk
        to an ``asyncio.Queue`` owned by the event loop; ``None`` serves as
        a sentinel marking the end of the stream.
        """
        queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def synced_predict():
            # Runs in a worker thread: push chunks onto the loop's queue.
            for chunk in self.stream_predict(
                image=image,
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                presence_penalty=presence_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                max_new_tokens=max_new_tokens,
            ):
                asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)

        asyncio.create_task(
            asyncio.to_thread(synced_predict),
        )

        while True:
            chunk = await queue.get()
            if chunk is None:
                return
            assert isinstance(chunk, str)
            yield chunk

    def build_prompt(self, prompt: str) -> str:
        """Wrap ``prompt`` in the chat template.

        Prompts that already start with ``<|im_start|>`` are assumed to be
        fully formatted and are returned unchanged; an empty prompt falls
        back to ``DEFAULT_USER_PROMPT``.
        """
        if prompt.startswith("<|im_start|>"):
            return prompt
        if not prompt:
            prompt = DEFAULT_USER_PROMPT
        return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
        # Modify here: appending <|box_start|> at the end of the prompt forces
        # the model to generate bounding boxes.
        # if "Document OCR" in prompt:
        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n<|box_start|>"
        # else:
        #     return f"<|im_start|>system\n{self.system_prompt}<|im_end|><|im_start|>user\n<image>\n{prompt}<|im_end|><|im_start|>assistant\n"
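
    # For reference, with the defaults the template above renders as (literal
    # \n shown as line breaks):
    #
    #   <|im_start|>system
    #   A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.<|im_end|><|im_start|>user
    #   <image>
    #   Document Parsing:<|im_end|><|im_start|>assistant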

    def close(self):
        """Release any resources held by the predictor. Default: no-op."""
        pass
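

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original interface): a toy subclass
# showing how the abstract methods and async wrappers fit together. The
# EchoPredictor name and the demo below are hypothetical; a real subclass
# would run model inference instead of echoing the formatted prompt.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class EchoPredictor(BasePredictor):
        def predict(self, image, prompt="", *args, **kwargs) -> str:
            # Ignore sampling parameters and echo the formatted prompt.
            return self.build_prompt(prompt)

        def batch_predict(self, images, prompts="", *args, **kwargs) -> List[str]:
            if isinstance(prompts, str):
                prompts = [prompts] * len(images)
            return [self.predict(img, p) for img, p in zip(images, prompts)]

        def stream_predict(self, image, prompt="", *args, **kwargs) -> Iterable[str]:
            # Emit the answer in small chunks to exercise the streaming path.
            text = self.predict(image, prompt)
            for i in range(0, len(text), 16):
                yield text[i : i + 16]

    async def _demo() -> None:
        predictor = EchoPredictor()
        # aio_stream_predict is an async generator; consume it with async for.
        async for chunk in predictor.aio_stream_predict("page.png"):
            print(chunk, end="")
        print()

    asyncio.run(_demo())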