# query.py

import os

import click
from llama_index.core.vector_stores.types import VectorStoreQuery
from llama_index.embeddings.dashscope import (DashScopeEmbedding,
                                              DashScopeTextEmbeddingModels,
                                              DashScopeTextEmbeddingType)
from llama_index.vector_stores.elasticsearch import (AsyncDenseVectorStrategy,
                                                     ElasticsearchStore)
# ModelScope supplies Qwen-7B-Chat for the answer-generation step
from modelscope import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# vector store backed by the Elasticsearch index that holds the embedded
# documents; connection settings default to a local single-node cluster
es_vector_store = ElasticsearchStore(
    index_name='rag_index',
    es_url=os.getenv('ES_URL', 'http://127.0.0.1:9200'),
    es_user=os.getenv('ES_USER', 'elastic'),
    es_password=os.getenv('ES_PASSWORD', 'llama_index'),
    retrieval_strategy=AsyncDenseVectorStrategy(),
)
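
# Example (assumed deployment, not part of the original script): point the
# script at a remote cluster by overriding the defaults before running:
#   export ES_URL='https://my-es-host:9200' ES_USER='elastic' ES_PASSWORD='...'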


def embed_text(text):
    # embed with DashScope text-embedding-v2; queries must be embedded by the
    # same model that produced the document vectors in the index
    embedder = DashScopeEmbedding(
        model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
        text_type=DashScopeTextEmbeddingType.TEXT_TYPE_DOCUMENT,
    )
    return embedder.get_text_embedding(text)
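
# Assumptions: the dashscope SDK picks up the API key from the
# DASHSCOPE_API_KEY environment variable, and text-embedding-v2 returns a
# 1536-dimensional list of floats, i.e. len(embed_text('hello')) == 1536.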


def search(vector_store: ElasticsearchStore, query: str):
    # dense kNN retrieval: embed the question, query the vector store, and
    # join the retrieved node texts into a single context string
    query_vec = VectorStoreQuery(query_embedding=embed_text(query))
    result = vector_store.query(query_vec)
    return '\n'.join([node.text for node in result.nodes])
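
# VectorStoreQuery returns a single hit by default; to feed more passages into
# the prompt, pass similarity_top_k, e.g.
#   VectorStoreQuery(query_embedding=embed_text(query), similarity_top_k=3)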


@click.command()
@click.option(
    '-q',
    '--question',
    'question',
    required=True,
    help='Ask what you want to know.',
)
def cli(question):
    # load Qwen-7B-Chat (tokenizer + model) from ModelScope; trust_remote_code
    # is required because Qwen ships its own modeling code
    tokenizer = AutoTokenizer.from_pretrained('qwen/Qwen-7B-Chat',
                                              revision='v1.0.5',
                                              trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained('qwen/Qwen-7B-Chat',
                                                 revision='v1.0.5',
                                                 device_map='auto',
                                                 trust_remote_code=True,
                                                 fp32=True).eval()
    model.generation_config = GenerationConfig.from_pretrained(
        'qwen/Qwen-7B-Chat', revision='v1.0.5', trust_remote_code=True)
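    # Note (assumption): fp16/bf16/fp32 are switches handled by Qwen's custom
    # modeling code rather than standard transformers kwargs; on a GPU,
    # fp16=True would roughly halve memory use compared with fp32=True.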

    # prompt template that grounds the LLM's answer in the retrieved context
    def answer_question(question, context, model):
        if context == '':
            prompt = question
        else:
            prompt = f'''Please answer the question based on the content inside ```.
```
{context}
```
My question is: {question}.
'''
        print(prompt)
        response, history = model.chat(tokenizer, prompt, history=None)
        return response

    answer = answer_question(question, search(es_vector_store, question),
                             model)
    print(f'question: {question}\n'
          f'answer: {answer}')
  66. """
  67. python query.py -q 'how about the rights of men'
  68. """

if __name__ == '__main__':
    cli()