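"""
Offline chat model helpers for Khoj.

Runs chat against a local llama.cpp model via llama-cpp-python: infers search
queries from user messages, streams conversational responses on a background
thread, and sends one-off messages to the offline model.
"""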
import json
import logging
import os
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Optional, Union

import pyjson5
from langchain.schema import ChatMessage
from llama_cpp import Llama

from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
    ThreadedGenerator,
    clean_json,
    commit_conversation_trace,
    generate_chatml_messages_with_context,
    messages_to_print,
)
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import (
    ConversationCommand,
    in_debug_mode,
    is_none_or_empty,
    is_promptrace_enabled,
    truncate_code_context,
)
from khoj.utils.rawconfig import FileAttachment, LocationData
from khoj.utils.yaml import yaml_dump

logger = logging.getLogger(__name__)


def extract_questions_offline(
    text: str,
    model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    loaded_model: Union[Any, None] = None,
    conversation_log={},
    use_history: bool = True,
    should_extract_questions: bool = True,
    location_data: LocationData = None,
    user: KhojUser = None,
    max_prompt_size: int = None,
    temperature: float = 0.7,
    personality_context: Optional[str] = None,
    query_files: str = None,
    tracer: dict = {},
) -> List[str]:
    """
    Infer search queries to retrieve relevant notes to answer user query
    """
    # Naive fallback: split the raw user message on question marks
    all_questions = text.split("? ")
    all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]

    if not should_extract_questions:
        return all_questions

    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)

    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract Past User Message and Inferred Questions from Conversation Log
    chat_history = ""

    if use_history:
        for chat in conversation_log.get("chat", [])[-4:]:
            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type", ""):
                chat_history += f"Q: {chat['intent']['query']}\n"
                chat_history += f"Khoj: {chat['message']}\n\n"

    # Get dates relative to today for prompt creation
    today = datetime.today()
    yesterday = (today - timedelta(days=1)).strftime("%Y-%m-%d")
    last_year = today.year - 1
    example_questions = prompts.extract_questions_offline.format(
        query=text,
        chat_history=chat_history,
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        yesterday_date=yesterday,
        last_year=last_year,
        this_year=today.year,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    messages = generate_chatml_messages_with_context(
        example_questions,
        model_name=model,
        loaded_model=offline_chat_model,
        max_prompt_size=max_prompt_size,
        model_type=ChatModelOptions.ModelType.OFFLINE,
        query_files=query_files,
    )

    # Serialize access to the single in-process llama.cpp model
    state.chat_lock.acquire()
    try:
        response = send_message_to_model_offline(
            messages,
            loaded_model=offline_chat_model,
            model=model,
            max_prompt_size=max_prompt_size,
            temperature=temperature,
            response_type="json_object",
            tracer=tracer,
        )
    finally:
        state.chat_lock.release()

    # Extract and clean the chat model's response
    try:
        response = clean_json(response)
        response = pyjson5.loads(response)
        questions = [q.strip() for q in response["queries"] if q.strip()]
        questions = filter_questions(questions)
    except Exception:
        logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
        return all_questions
    logger.debug(f"Questions extracted by {model}: {questions}")
    return questions
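
# A minimal usage sketch (hypothetical caller code; `my_llama_model` is assumed to be
# an already loaded llama_cpp.Llama instance):
#
#   queries = extract_questions_offline(
#       "Where did I travel last year?",
#       loaded_model=my_llama_model,
#       use_history=False,
#   )
#
# On success this returns model-inferred search queries; if the model emits invalid
# JSON it falls back to splitting the raw user message on "? ".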


def filter_questions(questions: List[str]):
    # Skip questions that seem to be apologizing for not being able to answer the question
    hint_words = [
        "sorry",
        "apologize",
        "unable",
        "can't",
        "cannot",
        "don't know",
        "don't understand",
        "do not know",
        "do not understand",
    ]
    filtered_questions = set()
    for q in questions:
        if not any(word in q.lower() for word in hint_words) and not is_none_or_empty(q):
            filtered_questions.add(q)

    return list(filtered_questions)
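
# For example, filter_questions(["Sorry, I cannot answer that.", "What did I do last week?"])
# returns ["What did I do last week?"], dropping the apologetic non-answer.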


def converse_offline(
    user_query,
    references=[],
    online_results={},
    code_results={},
    conversation_log={},
    model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    loaded_model: Union[Any, None] = None,
    completion_func=None,
    conversation_commands=[ConversationCommand.Default],
    max_prompt_size=None,
    tokenizer_name=None,
    location_data: LocationData = None,
    user_name: str = None,
    agent: Agent = None,
    query_files: str = None,
    generated_files: List[FileAttachment] = None,
    additional_context: List[str] = None,
    tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
    """
    Converse with user using Llama
    """
    # Initialize Variables
    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
    tracer["chat_model"] = model
    current_date = datetime.now()

    if agent and agent.personality:
        system_prompt = prompts.custom_system_prompt_offline_chat.format(
            name=agent.name,
            bio=agent.personality,
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )
    else:
        system_prompt = prompts.system_prompt_offline_chat.format(
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )

    if location_data:
        location_prompt = prompts.user_location.format(location=f"{location_data}")
        system_prompt = f"{system_prompt}\n{location_prompt}"

    if user_name:
        user_name_prompt = prompts.user_name.format(name=user_name)
        system_prompt = f"{system_prompt}\n{user_name_prompt}"

    # Get Conversation Primer appropriate to Conversation Type
    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
        return iter([prompts.no_notes_found.format()])
    elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
        if completion_func:
            completion_func(chat_response=prompts.no_online_results_found.format())
        return iter([prompts.no_online_results_found.format()])

    context_message = ""
    if not is_none_or_empty(references):
        context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
        # Only keep the webpage content from each online result to simplify the context
        simplified_online_results = online_results.copy()
        for result in online_results:
            if online_results[result].get("webpages"):
                simplified_online_results[result] = online_results[result]["webpages"]

        context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
    if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
        context_message += (
            f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
        )
    context_message = context_message.strip()

    # Setup Prompt with Primer or Conversation History
    messages = generate_chatml_messages_with_context(
        user_query,
        system_prompt,
        conversation_log,
        context_message=context_message,
        model_name=model,
        loaded_model=offline_chat_model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
        model_type=ChatModelOptions.ModelType.OFFLINE,
        query_files=query_files,
        generated_files=generated_files,
        additional_program_context=additional_context,
    )

    logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")

    # Stream the response from a worker thread via a ThreadedGenerator
    g = ThreadedGenerator(references, online_results, completion_func=completion_func)
    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
    t.start()
    return g
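
# A minimal consumption sketch (hypothetical caller code; `my_notes` stands in for
# retrieved note references): converse_offline returns a generator, so the reply can
# be streamed as it is produced:
#
#   for chunk in converse_offline("What is my next deadline?", references=my_notes):
#       print(chunk, end="")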


def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
    stop_phrases = ["<s>", "INST]", "Notes:"]
    aggregated_response = ""

    state.chat_lock.acquire()
    try:
        response_iterator = send_message_to_model_offline(
            messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
        )
        for response in response_iterator:
            response_delta = response["choices"][0]["delta"].get("content", "")
            aggregated_response += response_delta
            g.send(response_delta)

        # Save conversation trace
        if is_promptrace_enabled():
            commit_conversation_trace(messages, aggregated_response, tracer)

    finally:
        state.chat_lock.release()
        g.close()
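
# Note: g.close() runs in the finally block so the consuming ThreadedGenerator
# terminates even if inference raises, and state.chat_lock serializes access to
# the shared llama.cpp model across threads.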


def send_message_to_model_offline(
    messages: List[ChatMessage],
    loaded_model=None,
    model="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    temperature: float = 0.2,
    streaming=False,
    stop=[],
    max_prompt_size: int = None,
    response_type: str = "text",
    tracer: dict = {},
):
    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
    messages_dict = [{"role": message.role, "content": message.content} for message in messages]
    seed = int(os.getenv("KHOJ_LLM_SEED")) if os.getenv("KHOJ_LLM_SEED") else None
    response = offline_chat_model.create_chat_completion(
        messages_dict,
        stop=stop,
        stream=streaming,
        temperature=temperature,
        response_format={"type": response_type},
        seed=seed,
    )

    if streaming:
        return response

    response_text = response["choices"][0]["message"].get("content", "")

    # Save conversation trace for non-streaming responses
    # Streamed responses need to be saved by the calling function
    tracer["chat_model"] = model
    tracer["temperature"] = temperature
    if is_promptrace_enabled():
        commit_conversation_trace(messages, response_text, tracer)

    return response_text
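
# A minimal sketch of a deterministic one-shot call (hypothetical values; assumes the
# default model can be downloaded or is already cached):
#
#   os.environ["KHOJ_LLM_SEED"] = "42"  # pin the sampling seed for reproducibility
#   reply = send_message_to_model_offline([ChatMessage(role="user", content="Say hi")])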