-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmedical_chatbot.py
91 lines (68 loc) · 3.13 KB
/
medical_chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import streamlit as st
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from dotenv import load_dotenv, find_dotenv
# Load variables from the .env file located by find_dotenv() into the process
# environment; main() later reads HF_TOKEN from there via os.getenv.
load_dotenv(find_dotenv())

# On-disk location of the persisted FAISS index that get_vectorstore() loads.
DB_FAISS_PATH = "vectorstore/db_faiss"
@st.cache_resource
def get_vectorstore():
    """Load the persisted FAISS index stored at ``DB_FAISS_PATH``.

    Wrapped in ``st.cache_resource`` so the embedding model and index are
    loaded once and reused instead of being rebuilt for every question.

    Returns:
        The loaded ``FAISS`` vector store.
    """
    # all-MiniLM-L6-v2 is not a BGE model. HuggingFaceBgeEmbeddings prepends a
    # BGE-specific instruction to every *query* by default while documents are
    # embedded without it, which pulls query vectors away from the indexed
    # ones. Disable the prefix so queries are embedded like the documents.
    embedding_model = HuggingFaceBgeEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        query_instruction="",
    )
    # NOTE(security): allow_dangerous_deserialization=True unpickles the stored
    # index metadata. That is only acceptable because DB_FAISS_PATH is a local,
    # trusted artifact; never point this at a user-supplied or downloaded path.
    return FAISS.load_local(
        DB_FAISS_PATH,
        embedding_model,
        allow_dangerous_deserialization=True,
    )
def set_custom_prompt(custom_prompt_template):
    """Wrap a raw template string in a ``PromptTemplate``.

    The template is declared with two input variables, ``context`` and
    ``question``, so it should contain ``{context}`` and ``{question}``
    placeholders.

    Args:
        custom_prompt_template: The template text.

    Returns:
        A ``PromptTemplate`` built from the given text.
    """
    expected_vars = ["context", "question"]
    return PromptTemplate(
        template=custom_prompt_template,
        input_variables=expected_vars,
    )
def load_llm(huggingface_repo_id, HF_TOKEN):
    """Create a Hugging Face Inference endpoint LLM for text generation.

    Args:
        huggingface_repo_id: Hub repo id of the model to call, e.g.
            ``"mistralai/Mistral-7B-Instruct-v0.3"``.
        HF_TOKEN: Hugging Face API token. It may be ``None``; HuggingFaceEndpoint
            then presumably falls back to a token in the environment. Verify
            this for the installed langchain-huggingface version.

    Returns:
        A configured ``HuggingFaceEndpoint`` instance.
    """
    # The token belongs in ``huggingfacehub_api_token``. Everything placed in
    # ``model_kwargs`` is forwarded as a generation parameter, so the previous
    # {"token": ..., "max_length": "512"} sent an unknown "token" argument and
    # a string-typed, wrongly named length. The output cap is ``max_new_tokens``.
    return HuggingFaceEndpoint(
        repo_id=huggingface_repo_id,
        task="text-generation",
        temperature=0.5,
        max_new_tokens=512,
        huggingfacehub_api_token=HF_TOKEN,
    )
# Prompt used by the "stuff" chain: retrieved chunks fill {context}, the
# user's message fills {question}.
CUSTOM_PROMPT_TEMPLATE = """
Use the pieces of information provided in the context to answer user's question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Don't provide anything out of the given context
Context: {context}
Question: {question}
Start the answer directly. No small talk please.
"""

HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"


def _build_qa_chain(vectorstore):
    """Assemble a RetrievalQA chain over ``vectorstore`` (top-3 retrieval, "stuff" combine)."""
    return RetrievalQA.from_chain_type(
        llm=load_llm(
            huggingface_repo_id=HUGGINGFACE_REPO_ID,
            HF_TOKEN=os.getenv("HF_TOKEN"),
        ),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
        chain_type_kwargs={"prompt": set_custom_prompt(CUSTOM_PROMPT_TEMPLATE)},
    )


def main():
    """Render the chat UI and answer a newly submitted question.

    Redraws the conversation kept in ``st.session_state.messages``. If the user
    submitted a prompt, it retrieves the 3 most similar chunks from the FAISS
    store and asks the Hugging Face model to answer from them. Both the
    question and the answer are appended to the stored history. Failures are
    shown in the UI with ``st.error`` rather than raised.
    """
    st.title("Ask Chatbot")

    # History lives in session state so earlier turns are redrawn on each run.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    for message in st.session_state.messages:
        st.chat_message(message["role"]).markdown(message["content"])

    prompt = st.chat_input("pass your prompt here")
    if not prompt:
        return

    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    try:
        vectorstore = get_vectorstore()
        if vectorstore is None:
            st.error("Failed to load the vector store")
            # Without this return the chain construction below would call
            # ``None.as_retriever`` and surface a confusing AttributeError.
            return

        response = _build_qa_chain(vectorstore).invoke({"query": prompt})
        result = response["result"]

        st.chat_message("assistant").markdown(result)
        st.session_state.messages.append({"role": "assistant", "content": result})
    except Exception as e:  # Top-level UI boundary: report instead of crashing the app.
        st.error(f"Error: {e}")
if __name__ == "__main__":
    # Script entry point: build and render the chat app.
    main()