Allow Offline Chat model calling functions to save conversation traces
@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    commit_conversation_trace,
     generate_chatml_messages_with_context,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -94,6 +96,7 @@ def extract_questions_offline(
             max_prompt_size=max_prompt_size,
             temperature=temperature,
             response_type="json_object",
+            tracer=tracer,
         )
     finally:
         state.chat_lock.release()
@@ -146,6 +149,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -154,6 +158,7 @@ def converse_offline(
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references})
+    tracer["chat_model"] = model
 
     current_date = datetime.now()
 
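A side note on the signatures introduced above: tracer: dict = {} is a mutable default argument, so every call that omits tracer shares a single dict object, and in-place updates such as tracer["chat_model"] = model persist across those calls. A minimal sketch of that Python behavior, using a hypothetical record_call helper that is not part of this diff:

# Demonstrates the shared mutable-default behavior; record_call is a
# hypothetical stand-in, not a function from this commit.
def record_call(name: str, tracer: dict = {}) -> dict:
    tracer[name] = tracer.get(name, 0) + 1  # mutates the shared default dict
    return tracer

first = record_call("chat_model")
second = record_call("chat_model")
assert first is second                # both calls reused the same dict object
assert second == {"chat_model": 2}    # state leaked from the first call

Passing a fresh dict per conversation, as the calling code can do via the new tracer parameter, sidesteps this sharing.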
@@ -213,13 +218,14 @@ def converse_offline(
     logger.debug(f"Conversation Context for {model}: {truncated_messages}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
     t.start()
     return g
 
 
-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
     stop_phrases = ["<s>", "INST]", "Notes:"]
+    aggregated_response = ""
 
     state.chat_lock.acquire()
     try:
@@ -227,7 +233,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
             messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
-            g.send(response["choices"][0]["delta"].get("content", ""))
+            response_delta = response["choices"][0]["delta"].get("content", "")
+            aggregated_response += response_delta
+            g.send(response_delta)
+
+        # Save conversation trace
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
+
     finally:
         state.chat_lock.release()
         g.close()
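The diff shows only the call site: commit_conversation_trace is imported from khoj.processor.conversation.utils and is not modified here. Purely to illustrate the call contract used above (message history, aggregated response, tracer metadata), a minimal hypothetical stand-in could look like this; the function body and the trace location are assumptions, not Khoj's actual implementation:

import json
from datetime import datetime
from pathlib import Path
from typing import Any, List

# Hypothetical stand-in for khoj.processor.conversation.utils.commit_conversation_trace,
# shown only to illustrate the (messages, response, tracer) call contract in the diff.
def commit_conversation_trace_stub(messages: List[Any], response: str, tracer: dict) -> None:
    trace = {
        "timestamp": datetime.now().isoformat(),
        "messages": [{"role": m.role, "content": m.content} for m in messages],
        "response": response,
        "metadata": tracer,  # e.g. {"chat_model": ..., "temperature": ...}
    }
    trace_dir = Path("~/.khoj_traces").expanduser()  # assumed location, not Khoj's
    trace_dir.mkdir(parents=True, exist_ok=True)
    out_file = trace_dir / f"trace_{datetime.now():%Y%m%d_%H%M%S_%f}.json"
    out_file.write_text(json.dumps(trace, indent=2))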
@@ -242,6 +255,7 @@ def send_message_to_model_offline(
     stop=[],
     max_prompt_size: int = None,
     response_type: str = "text",
+    tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -249,7 +263,17 @@ def send_message_to_model_offline(
     response = offline_chat_model.create_chat_completion(
         messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
+
     if streaming:
         return response
-    else:
-        return response["choices"][0]["message"].get("content", "")
+
+    response_text = response["choices"][0]["message"].get("content", "")
+
+    # Save conversation trace for non-streaming responses
+    # Streamed responses need to be saved by the calling function
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
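Taken together, a caller opts into tracing by passing one dict down the call chain; the functions enrich it with chat_model and temperature and commit the trace when debug mode or verbose logging is enabled. A rough usage sketch follows; the module path, the ChatMessage import, and the tracer keys are illustrative assumptions, not taken from this diff:

from langchain.schema import ChatMessage

# Module path assumed from the imports in this diff.
from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline

messages = [ChatMessage(content="What is the speed of light?", role="user")]

# A fresh tracer per conversation avoids the shared mutable-default pitfall
# noted earlier and lets the caller attach its own metadata up front.
tracer = {"uid": "example-user", "conversation_id": "example-conversation"}

response_text = send_message_to_model_offline(
    messages,
    streaming=False,  # non-streaming, so the trace is committed inside the call
    tracer=tracer,
)
# After the call, tracer also carries the chat_model and temperature set inside.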