mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-07 21:29:13 +00:00
Move to newer chat models to extract questions & summarize chats
Deprecate usage of the older gpt3 models in-place of the newer chat based models - text-davinci-003 is only 50% cheaper than gpt4 and less reliable for question extraction - Using gpt-3.50turbo for summarization should reduce cost of chat - Keep conversation.chat_session as a list instead of a string - Update completion_with_backoff func to use ChatML format
This commit is contained in:
@@ -205,7 +205,7 @@ def configure_conversation_processor(conversation_processor_config):
|
|||||||
else:
|
else:
|
||||||
# Initialize Conversation Logs
|
# Initialize Conversation Logs
|
||||||
conversation_processor.meta_log = {}
|
conversation_processor.meta_log = {}
|
||||||
conversation_processor.chat_session = ""
|
conversation_processor.chat_session = []
|
||||||
|
|
||||||
return conversation_processor
|
return conversation_processor
|
||||||
|
|
||||||
@@ -225,9 +225,9 @@ def save_chat_session():
|
|||||||
chat_session = state.processor_config.conversation.chat_session
|
chat_session = state.processor_config.conversation.chat_session
|
||||||
openai_api_key = state.processor_config.conversation.openai_api_key
|
openai_api_key = state.processor_config.conversation.openai_api_key
|
||||||
conversation_log = state.processor_config.conversation.meta_log
|
conversation_log = state.processor_config.conversation.meta_log
|
||||||
model = state.processor_config.conversation.model
|
chat_model = state.processor_config.conversation.chat_model
|
||||||
session = {
|
session = {
|
||||||
"summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
|
"summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
|
||||||
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
||||||
"session-end": len(conversation_log["chat"]),
|
"session-end": len(conversation_log["chat"]),
|
||||||
}
|
}
|
||||||
@@ -242,7 +242,7 @@ def save_chat_session():
|
|||||||
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
|
with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
|
||||||
json.dump(conversation_log, logfile, indent=2)
|
json.dump(conversation_log, logfile, indent=2)
|
||||||
|
|
||||||
state.processor_config.conversation.chat_session = None
|
state.processor_config.conversation.chat_session = []
|
||||||
logger.info("📩 Saved current chat session to conversation logs")
|
logger.info("📩 Saved current chat session to conversation logs")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,9 @@ import logging
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
# External Packages
|
||||||
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.utils.constants import empty_escape_sequences
|
from khoj.utils.constants import empty_escape_sequences
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
@@ -16,22 +19,16 @@ from khoj.processor.conversation.utils import (
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
|
def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
|
||||||
"""
|
"""
|
||||||
Summarize user input using OpenAI's GPT
|
Summarize conversation session using the specified OpenAI chat model
|
||||||
"""
|
"""
|
||||||
# Setup Prompt based on Summary Type
|
messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
|
||||||
if summary_type == "chat":
|
|
||||||
prompt = prompts.summarize_chat.format(text=text)
|
|
||||||
elif summary_type == "notes":
|
|
||||||
prompt = prompts.summarize_notes.format(text=text, user_query=user_query)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid summary type: {summary_type}")
|
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
logger.debug(f"Prompt for GPT: {prompt}")
|
logger.debug(f"Prompt for GPT: {messages}")
|
||||||
response = completion_with_backoff(
|
response = completion_with_backoff(
|
||||||
prompt=prompt,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
@@ -41,11 +38,11 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
return str(response).replace("\n\n", "")
|
return str(response.content).replace("\n\n", "")
|
||||||
|
|
||||||
|
|
||||||
def extract_questions(
|
def extract_questions(
|
||||||
text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100
|
text, model: Optional[str] = "gpt-4", conversation_log={}, api_key=None, temperature=0, max_tokens=100
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
@@ -74,10 +71,11 @@ def extract_questions(
|
|||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
messages = [ChatMessage(content=prompt, role="assistant")]
|
||||||
|
|
||||||
# Get Response from GPT
|
# Get Response from GPT
|
||||||
response = completion_with_backoff(
|
response = completion_with_backoff(
|
||||||
prompt=prompt,
|
messages=messages,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
@@ -88,7 +86,7 @@ def extract_questions(
|
|||||||
# Extract, Clean Message from GPT's Response
|
# Extract, Clean Message from GPT's Response
|
||||||
try:
|
try:
|
||||||
questions = (
|
questions = (
|
||||||
response.strip(empty_escape_sequences)
|
response.content.strip(empty_escape_sequences)
|
||||||
.replace("['", '["')
|
.replace("['", '["')
|
||||||
.replace("']", '"]')
|
.replace("']", '"]')
|
||||||
.replace("', '", '", "')
|
.replace("', '", '", "')
|
||||||
|
|||||||
@@ -37,12 +37,7 @@ Question: {query}
|
|||||||
## Summarize Chat
|
## Summarize Chat
|
||||||
## --
|
## --
|
||||||
summarize_chat = PromptTemplate.from_template(
|
summarize_chat = PromptTemplate.from_template(
|
||||||
"""
|
f"{personality.format()} Summarize the conversation from your first person perspective"
|
||||||
You are an AI. Summarize the conversation below from your perspective:
|
|
||||||
|
|
||||||
{text}
|
|
||||||
|
|
||||||
Summarize the conversation from the AI's first-person perspective:"""
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import json
|
|||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.llms import OpenAI
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||||
from langchain.callbacks.base import BaseCallbackManager
|
from langchain.callbacks.base import BaseCallbackManager
|
||||||
@@ -89,13 +88,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||||||
reraise=True,
|
reraise=True,
|
||||||
)
|
)
|
||||||
def completion_with_backoff(**kwargs):
|
def completion_with_backoff(**kwargs):
|
||||||
prompt = kwargs.pop("prompt")
|
messages = kwargs.pop("messages")
|
||||||
if "api_key" in kwargs:
|
if not "openai_api_key" in kwargs:
|
||||||
kwargs["openai_api_key"] = kwargs.get("api_key")
|
|
||||||
else:
|
|
||||||
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||||
llm = OpenAI(**kwargs, request_timeout=20, max_retries=1)
|
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||||
return llm(prompt)
|
return llm(messages=messages)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
@@ -126,11 +123,12 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
|||||||
streaming=True,
|
streaming=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
callback_manager=BaseCallbackManager([callback_handler]),
|
callback_manager=BaseCallbackManager([callback_handler]),
|
||||||
model_name=model_name,
|
model_name=model_name, # type: ignore
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||||
request_timeout=20,
|
request_timeout=20,
|
||||||
max_retries=1,
|
max_retries=1,
|
||||||
|
client=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat(messages=messages)
|
chat(messages=messages)
|
||||||
@@ -196,15 +194,6 @@ def reciprocal_conversation_to_chatml(message_pair):
|
|||||||
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
||||||
|
|
||||||
|
|
||||||
def message_to_prompt(
|
|
||||||
user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
|
|
||||||
):
|
|
||||||
"""Create prompt for GPT from messages and conversation history"""
|
|
||||||
gpt_message = f" {gpt_message}" if gpt_message else ""
|
|
||||||
|
|
||||||
return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"
|
|
||||||
|
|
||||||
|
|
||||||
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
|
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
|
||||||
"""Create json logs from messages, metadata for conversation log"""
|
"""Create json logs from messages, metadata for conversation log"""
|
||||||
default_khoj_message_metadata = {
|
default_khoj_message_metadata = {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from sentence_transformers import util
|
|||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_processor, configure_search
|
from khoj.configure import configure_processor, configure_search
|
||||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||||
from khoj.processor.conversation.utils import message_to_log, message_to_prompt
|
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||||
from khoj.search_type import image_search, text_search
|
from khoj.search_type import image_search, text_search
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
@@ -448,10 +448,9 @@ async def chat(
|
|||||||
user_message_time: str,
|
user_message_time: str,
|
||||||
compiled_references: List[str],
|
compiled_references: List[str],
|
||||||
inferred_queries: List[str],
|
inferred_queries: List[str],
|
||||||
chat_session: str,
|
|
||||||
meta_log,
|
meta_log,
|
||||||
):
|
):
|
||||||
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
|
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
||||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||||
q,
|
q,
|
||||||
gpt_response,
|
gpt_response,
|
||||||
@@ -470,7 +469,6 @@ async def chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Load Conversation History
|
# Load Conversation History
|
||||||
chat_session = state.processor_config.conversation.chat_session
|
|
||||||
meta_log = state.processor_config.conversation.meta_log
|
meta_log = state.processor_config.conversation.meta_log
|
||||||
|
|
||||||
# If user query is empty, return nothing
|
# If user query is empty, return nothing
|
||||||
@@ -479,7 +477,6 @@ async def chat(
|
|||||||
|
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
api_key = state.processor_config.conversation.openai_api_key
|
api_key = state.processor_config.conversation.openai_api_key
|
||||||
model = state.processor_config.conversation.model
|
|
||||||
chat_model = state.processor_config.conversation.chat_model
|
chat_model = state.processor_config.conversation.chat_model
|
||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||||
@@ -489,7 +486,7 @@ async def chat(
|
|||||||
if conversation_type == "notes":
|
if conversation_type == "notes":
|
||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
inferred_queries = extract_questions(q, model=model, api_key=api_key, conversation_log=meta_log)
|
inferred_queries = extract_questions(q, api_key=api_key, conversation_log=meta_log)
|
||||||
|
|
||||||
# Collate search results as context for GPT
|
# Collate search results as context for GPT
|
||||||
with timer("Searching knowledge base took", logger):
|
with timer("Searching knowledge base took", logger):
|
||||||
@@ -525,7 +522,6 @@ async def chat(
|
|||||||
user_message_time=user_message_time,
|
user_message_time=user_message_time,
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
chat_session=chat_session,
|
|
||||||
meta_log=meta_log,
|
meta_log=meta_log,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,7 @@ class ConversationProcessorConfigModel:
|
|||||||
self.model = processor_config.model
|
self.model = processor_config.model
|
||||||
self.chat_model = processor_config.chat_model
|
self.chat_model = processor_config.chat_model
|
||||||
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
||||||
self.chat_session = ""
|
self.chat_session: List[str] = []
|
||||||
self.meta_log: dict = {}
|
self.meta_log: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user