Move to newer chat models to extract questions & summarize chats

Deprecate usage of the older GPT-3 models in favor of the newer chat-based
models
- text-davinci-003 is only 50% cheaper than gpt-4 and less reliable for
  question extraction
- Using gpt-3.5-turbo for summarization should reduce the cost of chat

- Keep conversation.chat_session as a list instead of a string
- Update completion_with_backoff func to use the ChatML message format
  (see the sketch after the commit details below)
Debanjum Singh Solanky
2023-07-07 17:14:23 -07:00
parent 171ce19e1f
commit af30d01e85
6 changed files with 28 additions and 50 deletions
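For reference, a minimal sketch of the ChatML-style message list this commit standardizes on, assuming langchain's ChatMessage schema as imported in the diffs below; the conversation contents are illustrative:

# Illustrative only: a chat_session list in the ChatML-style format the
# commit adopts, using langchain's ChatMessage schema.
from langchain.schema import ChatMessage

chat_session = [
    ChatMessage(content="When is my next meeting?", role="user"),
    ChatMessage(content="Your next meeting is at 3pm today.", role="assistant"),
]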

View File

@@ -205,7 +205,7 @@ def configure_conversation_processor(conversation_processor_config):
     else:
         # Initialize Conversation Logs
         conversation_processor.meta_log = {}
-        conversation_processor.chat_session = ""
+        conversation_processor.chat_session = []

     return conversation_processor
@@ -225,9 +225,9 @@ def save_chat_session():
     chat_session = state.processor_config.conversation.chat_session
     openai_api_key = state.processor_config.conversation.openai_api_key
     conversation_log = state.processor_config.conversation.meta_log
-    model = state.processor_config.conversation.model
+    chat_model = state.processor_config.conversation.chat_model
     session = {
-        "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
+        "summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
         "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
         "session-end": len(conversation_log["chat"]),
     }
@@ -242,7 +242,7 @@ def save_chat_session():
     with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
         json.dump(conversation_log, logfile, indent=2)
-    state.processor_config.conversation.chat_session = None
+    state.processor_config.conversation.chat_session = []
     logger.info("📩 Saved current chat session to conversation logs")

View File

@@ -3,6 +3,9 @@ import logging
 from datetime import datetime
 from typing import Optional

+# External Packages
+from langchain.schema import ChatMessage

 # Internal Packages
 from khoj.utils.constants import empty_escape_sequences
 from khoj.processor.conversation import prompts
@@ -16,22 +19,16 @@ from khoj.processor.conversation.utils import (
 logger = logging.getLogger(__name__)


-def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
+def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
     """
-    Summarize user input using OpenAI's GPT
+    Summarize conversation session using the specified OpenAI chat model
     """
-    # Setup Prompt based on Summary Type
-    if summary_type == "chat":
-        prompt = prompts.summarize_chat.format(text=text)
-    elif summary_type == "notes":
-        prompt = prompts.summarize_notes.format(text=text, user_query=user_query)
-    else:
-        raise ValueError(f"Invalid summary type: {summary_type}")
+    messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session

     # Get Response from GPT
-    logger.debug(f"Prompt for GPT: {prompt}")
+    logger.debug(f"Prompt for GPT: {messages}")
     response = completion_with_backoff(
-        prompt=prompt,
+        messages=messages,
         model_name=model,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -41,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
     )

     # Extract, Clean Message from GPT's Response
-    return str(response).replace("\n\n", "")
+    return str(response.content).replace("\n\n", "")


 def extract_questions(
-    text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100
+    text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -74,10 +71,11 @@ def extract_questions(
         chat_history=chat_history,
         text=text,
     )
+    messages = [ChatMessage(content=prompt, role="assistant")]

     # Get Response from GPT
     response = completion_with_backoff(
-        prompt=prompt,
+        messages=messages,
         model_name=model,
         temperature=temperature,
         max_tokens=max_tokens,
@@ -88,7 +86,7 @@ def extract_questions(
     # Extract, Clean Message from GPT's Response
     try:
         questions = (
-            response.strip(empty_escape_sequences)
+            response.content.strip(empty_escape_sequences)
             .replace("['", '["')
             .replace("']", '"]')
             .replace("', '", '", "')

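To make the new summarize() signature concrete, a hypothetical call under the revised API above; chat_session and openai_api_key are assumed to exist as in save_chat_session:

# Hypothetical usage of the revised summarize() above. chat_session is a list
# of langchain ChatMessage objects; gpt-3.5-turbo is the chat model the commit
# message proposes for summarization.
summary = summarize(chat_session, model="gpt-3.5-turbo", api_key=openai_api_key)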
View File

@@ -37,12 +37,7 @@ Question: {query}
 ## Summarize Chat
 ## --
 summarize_chat = PromptTemplate.from_template(
-    """
-You are an AI. Summarize the conversation below from your perspective:
-
-{text}
-
-Summarize the conversation from the AI's first-person perspective:"""
+    f"{personality.format()} Summarize the conversation from your first person perspective"
 )

View File

@@ -9,7 +9,6 @@ import json
 # External Packages
 from langchain.chat_models import ChatOpenAI
-from langchain.llms import OpenAI
 from langchain.schema import ChatMessage
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain.callbacks.base import BaseCallbackManager
@@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
     reraise=True,
 )
 def completion_with_backoff(**kwargs):
-    prompt = kwargs.pop("prompt")
-    if "api_key" in kwargs:
-        kwargs["openai_api_key"] = kwargs.get("api_key")
-    else:
+    messages = kwargs.pop("messages")
+    if not "openai_api_key" in kwargs:
         kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
-    llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
-    return llm(prompt)
+    llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
+    return llm(messages=messages)


 @retry(
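A hypothetical call through the reworked helper above; since ChatOpenAI returns a langchain message object rather than a string, callers in the gpt.py diff now read response.content:

# Hypothetical only: calling the reworked completion_with_backoff above.
# ChatOpenAI returns a message object, so callers read response.content.
response = completion_with_backoff(
    messages=messages, model_name="gpt-3.5-turbo", temperature=0.5, max_tokens=200
)
summary_text = response.content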
@@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
         streaming=True,
         verbose=True,
         callback_manager=BaseCallbackManager([callback_handler]),
-        model_name=model_name,
+        model_name=model_name,  # type: ignore
         temperature=temperature,
         openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
         request_timeout=20,
         max_retries=1,
+        client=None,
     )

     chat(messages=messages)
@@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair):
     return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]


-def message_to_prompt(
-    user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
-):
-    """Create prompt for GPT from messages and conversation history"""
-    gpt_message = f" {gpt_message}" if gpt_message else ""
-    return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"
-
-
 def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
     """Create json logs from messages, metadata for conversation log"""
     default_khoj_message_metadata = {

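For clarity, a usage sketch of reciprocal_conversation_to_chatml, whose implementation appears unchanged above; the message pair is illustrative:

# Illustrative round-trip through reciprocal_conversation_to_chatml: a
# (user, assistant) string pair becomes a ChatML-style ChatMessage pair.
pair = ["When is my next meeting?", "Your next meeting is at 3pm today."]
chatml_pair = reciprocal_conversation_to_chatml(pair)
# -> [ChatMessage(content=pair[0], role="user"),
#     ChatMessage(content=pair[1], role="assistant")]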
View File

@@ -15,7 +15,7 @@ from sentence_transformers import util
 # Internal Packages
 from khoj.configure import configure_processor, configure_search
 from khoj.processor.conversation.gpt import converse, extract_questions
-from khoj.processor.conversation.utils import message_to_log, message_to_prompt
+from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
 from khoj.search_type import image_search, text_search
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.file_filter import FileFilter
@@ -448,10 +448,9 @@ async def chat(
         user_message_time: str,
         compiled_references: List[str],
         inferred_queries: List[str],
-        chat_session: str,
         meta_log,
     ):
-        state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
+        state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
         state.processor_config.conversation.meta_log["chat"] = message_to_log(
             q,
             gpt_response,
@@ -470,7 +469,6 @@ async def chat(
     )

     # Load Conversation History
-    chat_session = state.processor_config.conversation.chat_session
     meta_log = state.processor_config.conversation.meta_log

     # If user query is empty, return nothing
@@ -479,7 +477,6 @@ async def chat(
     # Initialize Variables
     api_key = state.processor_config.conversation.openai_api_key
-    model = state.processor_config.conversation.model
     chat_model = state.processor_config.conversation.chat_model
     user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     conversation_type = "general" if q.startswith("@general") else "notes"
@@ -489,7 +486,7 @@ async def chat(
     if conversation_type == "notes":
         # Infer search queries from user message
         with timer("Extracting search queries took", logger):
-            inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log)
+            inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log)

     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):
@@ -525,7 +522,6 @@ async def chat(
         user_message_time=user_message_time,
         compiled_references=compiled_references,
         inferred_queries=inferred_queries,
-        chat_session=chat_session,
         meta_log=meta_log,
     )

View File

@@ -72,7 +72,7 @@ class ConversationProcessorConfigModel:
         self.model = processor_config.model
         self.chat_model = processor_config.chat_model
         self.conversation_logfile = Path(processor_config.conversation_logfile)
-        self.chat_session = ""
+        self.chat_session: List[str] = []
         self.meta_log: dict = {}