Allow Offline Chat model calling functions to save conversation traces

Debanjum Singh Solanky
2024-10-24 14:26:57 -07:00
parent eb6424f14d
commit a3022b7556

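The pattern this commit threads through every offline chat helper: each function that talks to the model accepts a tracer dict, stamps metadata onto it (the model name, and for non-streaming calls the temperature), and, when Khoj runs in debug mode or at verbosity above 1, hands the prompt messages, the final response, and the tracer to commit_conversation_trace. A minimal sketch of that flow; the stand-in commit_conversation_trace below is hypothetical, since the real helper lives in khoj.processor.conversation.utils and its storage details are not part of this diff:

    from typing import Any, Dict, List

    def commit_conversation_trace(messages: List[Any], response: str, tracer: Dict[str, Any]) -> None:
        # Hypothetical stand-in: the real helper persists the exchange for later inspection.
        print({"messages": messages, "response": response, **tracer})

    def chat(messages: List[str], model: str, debug: bool, tracer: dict = {}) -> str:
        tracer["chat_model"] = model  # metadata accrues on the shared dict
        response = "..."  # model call elided
        if debug:  # mirrors `in_debug_mode() or state.verbose > 1` in the diff
            commit_conversation_trace(messages, response, tracer)
        return response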

@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    commit_conversation_trace,
     generate_chatml_messages_with_context,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
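One Python detail worth keeping in mind while reading these signatures: the tracer: dict = {} default is created once, at function definition time, so metadata written into an omitted tracer persists across calls. Callers that want isolated traces pass their own dict, as the call sites in this commit do. A small demonstration of the semantics:

    # A mutable default dict is created once and shared by every call that omits it.
    def record(tracer: dict = {}) -> dict:
        tracer["calls"] = tracer.get("calls", 0) + 1
        return tracer

    print(record())           # {'calls': 1}
    print(record())           # {'calls': 2} -- same dict as the first call
    print(record(tracer={}))  # {'calls': 1} -- a fresh dict gives an isolated trace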
@@ -94,6 +96,7 @@ def extract_questions_offline(
             max_prompt_size=max_prompt_size,
             temperature=temperature,
             response_type="json_object",
+            tracer=tracer,
         )
     finally:
         state.chat_lock.release()
@@ -146,6 +149,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -154,6 +158,7 @@ def converse_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
+    tracer["chat_model"] = model
 
     current_date = datetime.now()
@@ -213,13 +218,14 @@ def converse_offline(
     logger.debug(f"Conversation Context for {model}: {truncated_messages}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
     t.start()
     return g
 
 
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
     stop_phrases = ["<s>", "INST]", "Notes:"]
+    aggregated_response = ""
 
     state.chat_lock.acquire()
     try:
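Because converse_offline hands the very same tracer dict to the worker thread, metadata the caller stamped on it stays visible inside llm_thread, and anything llm_thread records is visible to the caller once the thread finishes. A minimal sketch of that sharing, with hypothetical names:

    from threading import Thread

    def worker(tracer: dict) -> None:
        # Mutations land in the caller's dict; passing it through Thread args makes no copy.
        tracer["aggregated_response_length"] = 42  # hypothetical key

    tracer = {"chat_model": "example-model"}  # hypothetical metadata
    t = Thread(target=worker, args=(tracer,))
    t.start()
    t.join()
    print(tracer)  # both the caller's and the worker's keys are present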
@@ -227,7 +233,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
             messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
-            g.send(response["choices"][0]["delta"].get("content", ""))
+            response_delta = response["choices"][0]["delta"].get("content", "")
+            aggregated_response += response_delta
+            g.send(response_delta)
+
+        # Save conversation trace
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
+
     finally:
         state.chat_lock.release()
         g.close()
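A streamed completion only ever exists as deltas, so llm_thread now accumulates them into aggregated_response to have a complete string to hand to commit_conversation_trace once the iterator is exhausted. The same aggregation pattern in isolation, with hand-built chunks shaped like the OpenAI-style stream llama-cpp-python emits:

    # Hand-built chunks in the {"choices": [{"delta": {...}}]} shape used above.
    chunks = [
        {"choices": [{"delta": {"role": "assistant"}}]},  # first chunk may carry no content
        {"choices": [{"delta": {"content": "Hello"}}]},
        {"choices": [{"delta": {"content": ", world"}}]},
    ]

    aggregated_response = ""
    for chunk in chunks:
        response_delta = chunk["choices"][0]["delta"].get("content", "")
        aggregated_response += response_delta  # keep the full text for the trace
        # g.send(response_delta) would forward the delta to the client here

    assert aggregated_response == "Hello, world"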
@@ -242,6 +255,7 @@ def send_message_to_model_offline(
     stop=[],
     max_prompt_size: int = None,
     response_type: str = "text",
+    tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -249,7 +263,17 @@ def send_message_to_model_offline(
     response = offline_chat_model.create_chat_completion(
         messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
+
     if streaming:
         return response
+
-    else:
-        return response["choices"][0]["message"].get("content", "")
+    response_text = response["choices"][0]["message"].get("content", "")
+
+    # Save conversation trace for non-streaming responses
+    # Streamed responses need to be saved by the calling function
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
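Taken together: a caller seeds a tracer, send_message_to_model_offline stamps chat_model and temperature onto it, and in debug or verbose mode the exchange is persisted before the text is returned; for streaming calls the caller must commit the trace itself, as llm_thread does above. An illustrative call site, not taken from the commit; it assumes ChatMessage comes from langchain.schema as elsewhere in this module, and the message content and tracer inspection are examples only:

    from langchain.schema import ChatMessage

    from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline

    tracer: dict = {}
    response_text = send_message_to_model_offline(
        messages=[ChatMessage(content="What is Khoj?", role="user")],
        tracer=tracer,
    )
    print(response_text)
    print(tracer)  # now includes the "chat_model" and "temperature" stamped inside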