diff --git a/config/khoj_docker.yml b/config/khoj_docker.yml index 4f5216d3..176a7b1e 100644 --- a/config/khoj_docker.yml +++ b/config/khoj_docker.yml @@ -51,4 +51,5 @@ search-type: processor: #conversation: # openai-api-key: null + # model: "text-davinci-003" # conversation-logfile: "/data/embeddings/conversation_logs.json" \ No newline at end of file diff --git a/config/khoj_sample.yml b/config/khoj_sample.yml index 7c8a2ddf..18296e20 100644 --- a/config/khoj_sample.yml +++ b/config/khoj_sample.yml @@ -52,4 +52,5 @@ search-type: processor: conversation: openai-api-key: # "YOUR_OPENAI_API_KEY" + model: "text-davinci-003" conversation-logfile: "~/.khoj/processor/conversation/conversation_logs.json" diff --git a/src/processor/conversation/gpt.py b/src/processor/conversation/gpt.py index 66481b66..be433d66 100644 --- a/src/processor/conversation/gpt.py +++ b/src/processor/conversation/gpt.py @@ -10,7 +10,7 @@ import openai from src.utils.constants import empty_escape_sequences -def summarize(text, summary_type, user_query=None, api_key=None, temperature=0.5, max_tokens=100): +def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=100): """ Summarize user input using OpenAI's GPT """ @@ -35,8 +35,8 @@ Summarize the notes in second person perspective and use past tense:''' # Get Response from GPT response = openai.Completion.create( - engine="davinci-instruct-beta-v3", prompt=prompt, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=1, @@ -49,7 +49,7 @@ Summarize the notes in second person perspective and use past tense:''' return str(story).replace("\n\n", "") -def extract_search_type(text, api_key=None, temperature=0.5, max_tokens=100, verbose=0): +def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0): """ Extract search type from user query using OpenAI's GPT """ @@ -84,8 +84,8 @@ A:{ "search-type": "notes" }''' # Get Response from GPT response = openai.Completion.create( - engine="davinci", prompt=prompt, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=1, @@ -98,7 +98,7 @@ A:{ "search-type": "notes" }''' return json.loads(story.strip(empty_escape_sequences)) -def understand(text, api_key=None, temperature=0.5, max_tokens=100, verbose=0): +def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0): """ Understand user input using OpenAI's GPT """ @@ -155,8 +155,8 @@ A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance # Get Response from GPT response = openai.Completion.create( - engine="davinci", prompt=prompt, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=1, @@ -169,7 +169,7 @@ A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance return json.loads(story.strip(empty_escape_sequences)) -def converse(text, conversation_history=None, api_key=None, temperature=0.9, max_tokens=150): +def converse(text, model, conversation_history=None, api_key=None, temperature=0.9, max_tokens=150): """ Converse with user using OpenAI's GPT """ @@ -189,8 +189,8 @@ The following is a conversation with an AI assistant. The assistant is helpful, # Get Response from GPT response = openai.Completion.create( - engine="davinci", prompt=prompt, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=1, diff --git a/src/routers/api_beta.py b/src/routers/api_beta.py index 0719e817..35983805 100644 --- a/src/routers/api_beta.py +++ b/src/routers/api_beta.py @@ -38,9 +38,11 @@ def chat(q: str): # Load Conversation History chat_session = state.processor_config.conversation.chat_session meta_log = state.processor_config.conversation.meta_log + model = state.processor_config.conversation.model + api_key = state.processor_config.conversation.openai_api_key # Converse with OpenAI GPT - metadata = understand(q, api_key=state.processor_config.conversation.openai_api_key, verbose=state.verbose) + metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose) logger.debug(f'Understood: {get_from_dict(metadata, "intent")}') if get_from_dict(metadata, "intent", "memory-type") == "notes": @@ -48,9 +50,9 @@ def chat(q: str): result_list = search(query, n=1, t=SearchType.Org, r=True) collated_result = "\n".join([item.entry for item in result_list]) logger.debug(f'Semantically Similar Notes:\n{collated_result}') - gpt_response = summarize(collated_result, summary_type="notes", user_query=q, api_key=state.processor_config.conversation.openai_api_key) + gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key) else: - gpt_response = converse(q, chat_session, api_key=state.processor_config.conversation.openai_api_key) + gpt_response = converse(q, model, chat_session, api_key=api_key) # Update Conversation History state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response) @@ -70,8 +72,9 @@ def shutdown_event(): chat_session = state.processor_config.conversation.chat_session openai_api_key = state.processor_config.conversation.openai_api_key conversation_log = state.processor_config.conversation.meta_log + model = state.processor_config.conversation.model session = { - "summary": summarize(chat_session, summary_type="chat", api_key=openai_api_key), + "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key), "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], "session-end": len(conversation_log["chat"]) } diff --git a/src/utils/config.py b/src/utils/config.py index c417b2bf..bfbf65d9 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -51,6 +51,7 @@ class SearchModels(): class ConversationProcessorConfigModel(): def __init__(self, processor_config: ConversationProcessorConfig): self.openai_api_key = processor_config.openai_api_key + self.model = processor_config.model self.conversation_logfile = Path(processor_config.conversation_logfile) self.chat_session = '' self.meta_log: dict = {} diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 5ed3a9eb..c814726e 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -66,6 +66,7 @@ class SearchConfig(ConfigBase): class ConversationProcessorConfig(ConfigBase): openai_api_key: str conversation_logfile: Path + model: Optional[str] = "text-davinci-003" class ProcessorConfig(ConfigBase): conversation: Optional[ConversationProcessorConfig]