mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 05:39:06 +00:00
Move to newer chat models to extract questions & summarize chats
Deprecate usage of the older gpt3 models in-place of the newer chat based models - text-davinci-003 is only 50% cheaper than gpt4 and less reliable for question extraction - Using gpt-3.50turbo for summarization should reduce cost of chat - Keep conversation.chat_session as a list instead of a string - Update completion_with_backoff func to use ChatML format
This commit is contained in:
@@ -205,7 +205,7 @@ def configure_conversation_processor(conversation_processor_config):
|
||||
else:
|
||||
# Initialize Conversation Logs
|
||||
conversation_processor.meta_log = {}
|
||||
conversation_processor.chat_session = ""
|
||||
conversation_processor.chat_session = []
|
||||
|
||||
return conversation_processor
|
||||
|
||||
@@ -225,9 +225,9 @@ def save_chat_session():
|
||||
chat_session = state.processor_config.conversation.chat_session
|
||||
openai_api_key = state.processor_config.conversation.openai_api_key
|
||||
conversation_log = state.processor_config.conversation.meta_log
|
||||
model = state.processor_config.conversation.model
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
session = {
|
||||
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
|
||||
"summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
|
||||
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
||||
"session-end": len(conversation_log["chat"]),
|
||||
}
|
||||
@@ -242,7 +242,7 @@ def save_chat_session():
|
||||
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
|
||||
json.dump(conversation_log, logfile, indent=2)
|
||||
|
||||
state.processor_config.conversation.chat_session = None
|
||||
state.processor_config.conversation.chat_session = []
|
||||
logger.info("📩 Saved current chat session to conversation logs")
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
# External Packages
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.processor.conversation import prompts
|
||||
@@ -16,22 +19,16 @@ from khoj.processor.conversation.utils import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
|
||||
def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
|
||||
"""
|
||||
Summarize user input using OpenAI's GPT
|
||||
Summarize conversation session using the specified OpenAI chat model
|
||||
"""
|
||||
# Setup Prompt based on Summary Type
|
||||
if summary_type == "chat":
|
||||
prompt = prompts.summarize_chat.format(text=text)
|
||||
elif summary_type == "notes":
|
||||
prompt = prompts.summarize_notes.format(text=text, user_query=user_query)
|
||||
else:
|
||||
raise ValueError(f"Invalid summary type: {summary_type}")
|
||||
messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
|
||||
|
||||
# Get Response from GPT
|
||||
logger.debug(f"Prompt for GPT: {prompt}")
|
||||
logger.debug(f"Prompt for GPT: {messages}")
|
||||
response = completion_with_backoff(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
@@ -41,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
return str(response).replace("\n\n", "")
|
||||
return str(response.content).replace("\n\n", "")
|
||||
|
||||
|
||||
def extract_questions(
|
||||
text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100
|
||||
text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
@@ -74,10 +71,11 @@ def extract_questions(
|
||||
chat_history=chat_history,
|
||||
text=text,
|
||||
)
|
||||
messages = [ChatMessage(content=prompt, role="assistant")]
|
||||
|
||||
# Get Response from GPT
|
||||
response = completion_with_backoff(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
@@ -88,7 +86,7 @@ def extract_questions(
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
questions = (
|
||||
response.strip(empty_escape_sequences)
|
||||
response.content.strip(empty_escape_sequences)
|
||||
.replace("['", '["')
|
||||
.replace("']", '"]')
|
||||
.replace("', '", '", "')
|
||||
|
||||
@@ -37,12 +37,7 @@ Question: {query}
|
||||
## Summarize Chat
|
||||
## --
|
||||
summarize_chat = PromptTemplate.from_template(
|
||||
"""
|
||||
You are an AI. Summarize the conversation below from your perspective:
|
||||
|
||||
{text}
|
||||
|
||||
Summarize the conversation from the AI's first-person perspective:"""
|
||||
f"{personality.format()} Summarize the conversation from your first person perspective"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import json
|
||||
|
||||
# External Packages
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.llms import OpenAI
|
||||
from langchain.schema import ChatMessage
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
@@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(**kwargs):
|
||||
prompt = kwargs.pop("prompt")
|
||||
if "api_key" in kwargs:
|
||||
kwargs["openai_api_key"] = kwargs.get("api_key")
|
||||
else:
|
||||
messages = kwargs.pop("messages")
|
||||
if not "openai_api_key" in kwargs:
|
||||
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||
llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||
return llm(prompt)
|
||||
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||
return llm(messages=messages)
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callback_manager=BaseCallbackManager([callback_handler]),
|
||||
model_name=model_name,
|
||||
model_name=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
client=None,
|
||||
)
|
||||
|
||||
chat(messages=messages)
|
||||
@@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair):
|
||||
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
||||
|
||||
|
||||
def message_to_prompt(
|
||||
user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
|
||||
):
|
||||
"""Create prompt for GPT from messages and conversation history"""
|
||||
gpt_message = f" {gpt_message}" if gpt_message else ""
|
||||
|
||||
return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"
|
||||
|
||||
|
||||
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
|
||||
@@ -15,7 +15,7 @@ from sentence_transformers import util
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log, message_to_prompt
|
||||
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -448,10 +448,9 @@ async def chat(
|
||||
user_message_time: str,
|
||||
compiled_references: List[str],
|
||||
inferred_queries: List[str],
|
||||
chat_session: str,
|
||||
meta_log,
|
||||
):
|
||||
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
q,
|
||||
gpt_response,
|
||||
@@ -470,7 +469,6 @@ async def chat(
|
||||
)
|
||||
|
||||
# Load Conversation History
|
||||
chat_session = state.processor_config.conversation.chat_session
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# If user query is empty, return nothing
|
||||
@@ -479,7 +477,6 @@ async def chat(
|
||||
|
||||
# Initialize Variables
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
model = state.processor_config.conversation.model
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||
@@ -489,7 +486,7 @@ async def chat(
|
||||
if conversation_type == "notes":
|
||||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log)
|
||||
inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
with timer("Searching knowledge base took", logger):
|
||||
@@ -525,7 +522,6 @@ async def chat(
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
chat_session=chat_session,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ class ConversationProcessorConfigModel:
|
||||
self.model = processor_config.model
|
||||
self.chat_model = processor_config.chat_model
|
||||
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
||||
self.chat_session = ""
|
||||
self.chat_session: List[str] = []
|
||||
self.meta_log: dict = {}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user