-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconnect_memory_with_llm.py
60 lines (48 loc) · 1.95 KB
/
connect_memory_with_llm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
# Set up the LLM (Mistral with HuggingFace)
# Access token read from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Hub repository id of the model handed to load_llm.
HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
def load_llm(huggingface_repo_id):
    """Build a Hugging Face text-generation endpoint for the given model.

    Args:
        huggingface_repo_id: Hub repository id of the model to query,
            e.g. ``"mistralai/Mistral-7B-Instruct-v0.3"``.

    Returns:
        A ``HuggingFaceEndpoint`` configured for the ``text-generation``
        task with temperature 0.5 and at most 512 newly generated tokens.
    """
    return HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        task="text-generation",
        temperature=0.5,
        # The generation limit is a dedicated integer field on the endpoint;
        # it was previously passed as the string "512" under a ``max_length``
        # key inside ``model_kwargs``.
        max_new_tokens=512,
        # Credentials go through the dedicated token field instead of
        # ``model_kwargs``, which is meant for extra generation parameters.
        huggingfacehub_api_token=HF_TOKEN,
    )
# Connect LLM with FAISS and Create chain
# Prompt text handed to set_custom_prompt when the QA chain is built. It uses
# two placeholders, {context} and {question}, the same names that
# set_custom_prompt declares as the template's input variables.
# NOTE(review): the literal is runtime text, so it is kept exactly as written
# (including its spelling); editing it changes what is sent to the model.
CUSTOM_PROMPT_TEMPLATE = """
Use the pieces of information provided in the context to answer user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Dont provide anything out of the given context
Context: {context}
Question: {question}
Start the answer direclty. No small talk please.
"""
def set_custom_prompt(custom_prompt_template):
    """Wrap a raw template string in a ``PromptTemplate``.

    Args:
        custom_prompt_template: Template text containing ``{context}`` and
            ``{question}`` placeholders.

    Returns:
        A ``PromptTemplate`` whose input variables are ``context`` and
        ``question``.
    """
    return PromptTemplate(
        input_variables=["context", "question"],
        template=custom_prompt_template,
    )
# Load the FAISS vector store from disk
DB_FAISS_PATH = "vectorstore/db_faiss"
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
# NOTE(security): allow_dangerous_deserialization=True opts in to loading a
# pickled index from disk; only point DB_FAISS_PATH at files you trust.
db = FAISS.load_local(
    DB_FAISS_PATH,
    embedding_model,
    allow_dangerous_deserialization=True,
)

# Build the retrieval QA chain: the three closest documents are retrieved for
# each query and the custom prompt is used for the "stuff" chain.
qa_chain = RetrievalQA.from_chain_type(
    llm=load_llm(HUGGINGFACE_REPO_ID),
    chain_type="stuff",
    retriever=db.as_retriever(search_kwargs=dict(k=3)),
    return_source_documents=True,
    chain_type_kwargs=dict(prompt=set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)),
)
# Invoke the chain with a single query read from standard input.
# Surrounding whitespace is dropped so a blank line counts as "no query".
user_query = input("Write Query Here: ").strip()
if not user_query:
    # Do not spend a retrieval + LLM call on an empty question.
    raise SystemExit("No query entered; nothing to ask.")
response = qa_chain.invoke({"query": user_query})
print("RESULT: ", response["result"])
# Label spelling corrected (was "SOURSE DOCUMETS").
print("SOURCE DOCUMENTS: ", response["source_documents"])