mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Drop native offline chat support with llama-cpp-python
It is recommended to chat with open-source models by running an open-source server like Ollama or Llama.cpp on your GPU-powered machine, or by using a commercial provider of open-source models like DeepInfra or OpenRouter. These chat model serving options provide a mature OpenAI-compatible API that already works with Khoj. Directly using offline chat models only worked reasonably with a pip install on a machine with a GPU. The Docker setup of Khoj had trouble accessing the GPU, and without GPU access offline chat is too slow. Deprecating support for an offline chat provider directly within Khoj will reduce code complexity and increase development velocity. Offline models are subsumed under the existing OpenAI model provider.
This commit is contained in:
@@ -72,7 +72,6 @@ from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import OfflineChatProcessorModel
|
||||
from khoj.utils.helpers import (
|
||||
clean_object_for_db,
|
||||
clean_text_for_db,
|
||||
@@ -1553,14 +1552,6 @@ class ConversationAdapters:
|
||||
if chat_model is None:
|
||||
chat_model = await ConversationAdapters.aget_default_chat_model()
|
||||
|
||||
if chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
chat_model_name = chat_model.name
|
||||
max_tokens = chat_model.max_prompt_size
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
|
||||
return chat_model
|
||||
|
||||
if (
|
||||
chat_model.model_type
|
||||
in [
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
# Generated by Django 5.1.10 on 2025-07-19 21:33
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Alter the choices and defaults of the chat model and speech-to-text model fields.

    The new choice lists contain no "offline" entry: chat models default to
    Google's "gemini-2.5-flash" and speech-to-text models to OpenAI's "whisper-1".
    """

    dependencies = [("database", "0091_chatmodel_friendly_name_and_more")]

    operations = [
        # ChatModel.model_type: choices are openai, anthropic and google; default is google
        migrations.AlterField(
            model_name="chatmodel",
            name="model_type",
            field=models.CharField(
                max_length=200,
                default="google",
                choices=[("openai", "Openai"), ("anthropic", "Anthropic"), ("google", "Google")],
            ),
        ),
        # ChatModel.name: default chat model name
        migrations.AlterField(
            model_name="chatmodel",
            name="name",
            field=models.CharField(max_length=200, default="gemini-2.5-flash"),
        ),
        # SpeechToTextModelOptions.model_name: default transcription model name
        migrations.AlterField(
            model_name="speechtotextmodeloptions",
            name="model_name",
            field=models.CharField(max_length=200, default="whisper-1"),
        ),
        # SpeechToTextModelOptions.model_type: openai is the only choice and the default
        migrations.AlterField(
            model_name="speechtotextmodeloptions",
            name="model_type",
            field=models.CharField(max_length=200, default="openai", choices=[("openai", "Openai")]),
        ),
    ]
|
||||
@@ -220,16 +220,15 @@ class PriceTier(models.TextChoices):
|
||||
class ChatModel(DbBaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
OFFLINE = "offline"
|
||||
ANTHROPIC = "anthropic"
|
||||
GOOGLE = "google"
|
||||
|
||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
name = models.CharField(max_length=200, default="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF")
|
||||
name = models.CharField(max_length=200, default="gemini-2.5-flash")
|
||||
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.GOOGLE)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
vision_enabled = models.BooleanField(default=False)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
@@ -605,11 +604,10 @@ class TextToImageModelConfig(DbBaseModel):
|
||||
class SpeechToTextModelOptions(DbBaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
OFFLINE = "offline"
|
||||
|
||||
model_name = models.CharField(max_length=200, default="base")
|
||||
model_name = models.CharField(max_length=200, default="whisper-1")
|
||||
friendly_name = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
price_tier = models.CharField(max_length=20, choices=PriceTier.choices, default=PriceTier.FREE)
|
||||
ai_model_api = models.ForeignKey(AiModelApi, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
@@ -214,7 +214,6 @@ def set_state(args):
|
||||
)
|
||||
state.anonymous_mode = args.anonymous_mode
|
||||
state.khoj_version = version("khoj")
|
||||
state.chat_on_gpu = args.chat_on_gpu
|
||||
|
||||
|
||||
def start_server(app, host=None, port=None, socket=None):
|
||||
|
||||
@@ -1,224 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
from time import perf_counter
|
||||
from typing import Any, AsyncGenerator, Dict, List, Union
|
||||
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from llama_cpp import Llama
|
||||
|
||||
from khoj.database.models import Agent, ChatMessageModel, ChatModel
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ResponseWithThought,
|
||||
commit_conversation_trace,
|
||||
generate_chatml_messages_with_context,
|
||||
messages_to_print,
|
||||
)
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
is_none_or_empty,
|
||||
is_promptrace_enabled,
|
||||
truncate_code_context,
|
||||
)
|
||||
from khoj.utils.rawconfig import FileAttachment, LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def converse_offline(
    # Query
    user_query: str,
    # Context
    references: list[dict] = None,
    online_results=None,
    code_results=None,
    query_files: str = None,
    generated_files: List[FileAttachment] = None,
    additional_context: List[str] = None,
    generated_asset_results: Dict[str, Dict] = None,
    location_data: LocationData = None,
    user_name: str = None,
    chat_history: list[ChatMessageModel] = None,
    # Model
    model_name: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    loaded_model: Union[Any, None] = None,
    max_prompt_size=None,
    tokenizer_name=None,
    agent: Agent = None,
    tracer: dict = None,
) -> AsyncGenerator[ResponseWithThought, None]:
    """
    Converse with user using Llama (Async Version)

    Builds the system prompt and context message, assembles the chat messages,
    then streams the model response. Generation runs in a worker thread and its
    chunks are handed to this coroutine through an asyncio queue.

    Args:
        user_query: The user's message for this turn.
        references: Note references to add to the context message.
        online_results: Online search results keyed by query. Entries with a
            "webpages" value are reduced to that value before being rendered.
        code_results: Code execution results to add to the context message.
        query_files, generated_files, additional_context, generated_asset_results:
            Passed through to generate_chatml_messages_with_context.
        location_data: If set, appended to the system prompt as the user's location.
        user_name: If set, appended to the system prompt as the user's name.
        chat_history: Previous conversation messages.
        model_name: Model to load when loaded_model is not provided.
        loaded_model: An already loaded Llama model, if any.
        max_prompt_size: Configured maximum prompt size, if any.
        tokenizer_name: Passed through to generate_chatml_messages_with_context.
        agent: If set and it has a personality, the custom agent system prompt is used.
        tracer: Dict the chat model name is recorded on. A fresh dict is used when omitted.

    Yields:
        ResponseWithThought chunks, each carrying one streamed text delta.
    """
    # Normalize omitted containers here instead of using mutable defaults,
    # which are shared across calls (tracer in particular is mutated below).
    references = [] if references is None else references
    online_results = {} if online_results is None else online_results
    code_results = {} if code_results is None else code_results
    generated_asset_results = {} if generated_asset_results is None else generated_asset_results
    chat_history = [] if chat_history is None else chat_history
    tracer = {} if tracer is None else tracer

    # Initialize Variables
    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
    tracer["chat_model"] = model_name
    current_date = datetime.now()

    if agent and agent.personality:
        system_prompt = prompts.custom_system_prompt_offline_chat.format(
            name=agent.name,
            bio=agent.personality,
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )
    else:
        system_prompt = prompts.system_prompt_offline_chat.format(
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )

    if location_data:
        location_prompt = prompts.user_location.format(location=f"{location_data}")
        system_prompt = f"{system_prompt}\n{location_prompt}"

    if user_name:
        user_name_prompt = prompts.user_name.format(name=user_name)
        system_prompt = f"{system_prompt}\n{user_name_prompt}"

    # Get Conversation Primer appropriate to Conversation Type
    context_message = ""
    if not is_none_or_empty(references):
        context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
    if not is_none_or_empty(online_results):
        # Keep only the webpages of each result, when it has any
        simplified_online_results = online_results.copy()
        for result in online_results:
            if online_results[result].get("webpages"):
                simplified_online_results[result] = online_results[result]["webpages"]

        context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
    if not is_none_or_empty(code_results):
        context_message += (
            f"{prompts.code_executed_context.format(code_results=truncate_code_context(code_results))}\n\n"
        )
    context_message = context_message.strip()

    # Setup Prompt with Primer or Conversation History
    messages = generate_chatml_messages_with_context(
        user_query,
        system_prompt,
        chat_history,
        context_message=context_message,
        model_name=model_name,
        loaded_model=offline_chat_model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
        model_type=ChatModel.ModelType.OFFLINE,
        query_files=query_files,
        generated_files=generated_files,
        generated_asset_results=generated_asset_results,
        program_execution_context=additional_context,
    )

    logger.debug(f"Conversation Context for {model_name}: {messages_to_print(messages)}")

    # Use asyncio.Queue and a thread to bridge sync iterator
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    stop_phrases = ["<s>", "INST]", "Notes:"]

    def _enqueue(item):
        """Hand an item to the consumer from the worker thread.

        asyncio.Queue is not thread-safe, so the put is scheduled on the event
        loop. The queue is unbounded, so put_nowait cannot raise QueueFull.
        """
        loop.call_soon_threadsafe(queue.put_nowait, item)

    def _sync_llm_thread():
        """Synchronous function to run in a separate thread."""
        aggregated_response = ""
        start_time = perf_counter()
        state.chat_lock.acquire()
        try:
            response_iterator = send_message_to_model_offline(
                messages,
                loaded_model=offline_chat_model,
                stop=stop_phrases,
                max_prompt_size=max_prompt_size,
                streaming=True,
                tracer=tracer,
            )
            for response in response_iterator:
                response_delta: str = response["choices"][0]["delta"].get("content", "")
                # Log the time taken to start response
                if aggregated_response == "" and response_delta != "":
                    logger.info(f"First response took: {perf_counter() - start_time:.3f} seconds")
                # Handle response chunk
                aggregated_response += response_delta
                _enqueue(ResponseWithThought(text=response_delta))

            # Log the time taken to stream the entire response
            logger.info(f"Chat streaming took: {perf_counter() - start_time:.3f} seconds")

            # Save conversation trace
            tracer["chat_model"] = model_name
            if is_promptrace_enabled():
                commit_conversation_trace(messages, aggregated_response, tracer)

        except Exception as e:
            logger.error(f"Error in offline LLM thread: {e}", exc_info=True)
        finally:
            state.chat_lock.release()
            # Signal end of stream, also after an error, so the consumer never waits forever
            _enqueue(None)

    # Start the synchronous thread
    thread = Thread(target=_sync_llm_thread)
    thread.start()

    # Asynchronously consume from the queue
    while True:
        chunk = await queue.get()
        if chunk is None:  # End of stream signal
            queue.task_done()
            break
        yield chunk
        queue.task_done()

    # Wait for the thread to finish without blocking the event loop
    await loop.run_in_executor(None, thread.join)
|
||||
|
||||
|
||||
def send_message_to_model_offline(
    messages: List[ChatMessage],
    loaded_model=None,
    model_name="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
    temperature: float = 0.2,
    streaming=False,
    stop=None,
    max_prompt_size: int = None,
    response_type: str = "text",
    tracer: dict = None,
):
    """
    Send chat messages to the offline Llama model and return its response.

    Args:
        messages: Chat messages; the role and content of each are sent to the model.
        loaded_model: An already loaded Llama model. When not provided, the model
            is loaded via download_model(model_name, max_tokens=max_prompt_size).
        model_name: Model to load when loaded_model is not provided.
        temperature: Sampling temperature passed to the model.
        streaming: When True, return the model's streaming response as is.
        stop: Stop phrases passed to the model. Defaults to no stop phrases.
        max_prompt_size: Configured maximum prompt size, used when loading the model.
        response_type: Value of the "type" key in the response_format sent to the model.
        tracer: Dict the model name and temperature are recorded on for
            non-streaming responses. A fresh dict is used when omitted.

    Returns:
        The raw response of create_chat_completion when streaming, otherwise a
        ResponseWithThought holding the response text.
    """
    # Normalize omitted containers here instead of using mutable defaults,
    # which are shared across calls (tracer is mutated below).
    stop = [] if stop is None else stop
    tracer = {} if tracer is None else tracer

    assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
    offline_chat_model = loaded_model or download_model(model_name, max_tokens=max_prompt_size)
    messages_dict = [{"role": message.role, "content": message.content} for message in messages]

    # Optional seed, read once from the environment
    seed_env = os.getenv("KHOJ_LLM_SEED")
    seed = int(seed_env) if seed_env else None

    response = offline_chat_model.create_chat_completion(
        messages_dict,
        stop=stop,
        stream=streaming,
        temperature=temperature,
        response_format={"type": response_type},
        seed=seed,
    )

    if streaming:
        return response

    response_text: str = response["choices"][0]["message"].get("content", "")

    # Save conversation trace for non-streaming responses
    # Streamed responses need to be saved by the calling function
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature
    if is_promptrace_enabled():
        commit_conversation_trace(messages, response_text, tracer)

    return ResponseWithThought(text=response_text)
|
||||
@@ -1,80 +0,0 @@
|
||||
import glob
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
from huggingface_hub.constants import HF_HUB_CACHE
|
||||
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import get_device_memory
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
    """
    Load the chat model for repo_id and return it.

    The model is first loaded with n_ctx=0 to read the context size from the
    model. It is then loaded again with the context size chosen by
    infer_max_tokens from that size and max_tokens.

    Args:
        repo_id: Model repository id. Also used to pick a known chat format.
        filename: File name pattern of the model file to load.
        max_tokens: Configured maximum context size, if any.

    Returns:
        The loaded chat model.
    """
    # Initialize Model Parameters
    # Use n_ctx=0 to get context size from the model
    kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False}

    # Decide whether to load model to GPU or CPU
    device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
    kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0

    # Add chat format if known
    repo_id_lower = repo_id.lower()
    if "llama-3" in repo_id_lower:
        kwargs["chat_format"] = "llama-3"
    elif "gemma-2" in repo_id_lower:
        kwargs["chat_format"] = "gemma"

    # Check if the model is already downloaded
    model_path = load_model_from_cache(repo_id, filename)
    chat_model = None
    try:
        chat_model = load_model(model_path, repo_id, filename, kwargs)
    except Exception as e:
        # Retry on CPU. Catch Exception rather than everything, so that
        # KeyboardInterrupt and SystemExit are not swallowed by the retry.
        logger.warning(f"Failed to load chat model to {device.upper()}, retrying on CPU: {e}")
        kwargs["n_gpu_layers"], device = 0, "cpu"
        chat_model = load_model(model_path, repo_id, filename, kwargs)

    # Now load the model with context size set based on:
    # 1. context size supported by model and
    # 2. configured size or machine (V)RAM
    kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
    chat_model = load_model(model_path, repo_id, filename, kwargs)

    logger.debug(
        f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
    )
    return chat_model
|
||||
|
||||
|
||||
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
    """
    Build a Llama model from a local file when model_path is set, else from repo_id.

    kwargs are forwarded unchanged to the Llama constructor in both cases.
    """
    # Imported here, as in the original, rather than at module level
    from llama_cpp.llama import Llama

    if not model_path:
        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
    return Llama(model_path, **kwargs)
|
||||
|
||||
|
||||
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
    """
    Return the path of the first file matching filename in the model cache, or None.

    repo_id must be of the form "<org>/<name>". Files are searched for under
    HF_HUB_CACHE/<repo_type>--<org>--<name>/snapshots/<any directory>/.
    """
    # Build the glob pattern for the model file inside the cache directory
    repo_org, repo_name = repo_id.split("/")
    cache_dir_name = "--".join((repo_type, repo_org, repo_name))
    path_parts = (HF_HUB_CACHE, cache_dir_name, "snapshots", "**", filename)
    search_pattern = os.path.sep.join(path_parts)

    # Use the first match, if any
    matches = glob.glob(search_pattern)
    return matches[0] if matches else None
|
||||
|
||||
|
||||
def infer_max_tokens(model_context_window: int, configured_max_tokens=None) -> int:
    """Infer max prompt size based on device memory and max context window supported by the model"""
    # A configured size of None or 0 places no limit on the result
    token_limit = configured_max_tokens or math.inf
    # Heuristic limit derived from device memory
    # NOTE(review): presumably get_device_memory returns bytes - confirm
    memory_limit = int(get_device_memory() / 1e6)
    return min(token_limit, memory_limit, model_context_window)
|
||||
@@ -1,15 +0,0 @@
|
||||
import whisper
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from khoj.utils import state
|
||||
|
||||
|
||||
async def transcribe_audio_offline(audio_filename: str, model: str) -> str:
    """
    Transcribe an audio file using a Whisper model held on shared state
    """
    # Load the Whisper model on first use and keep it on state for later calls
    if not state.whisper_model:
        state.whisper_model = whisper.load_model(model)

    # Wrap the synchronous transcribe call so it can be awaited
    transcribe = sync_to_async(state.whisper_model.transcribe)
    transcription = await transcribe(audio_filename)
    return transcription["text"]
|
||||
@@ -78,38 +78,6 @@ no_entries_found = PromptTemplate.from_template(
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## Conversation Prompts for Offline Chat Models
|
||||
## --
|
||||
system_prompt_offline_chat = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||
- Use your general knowledge and past conversation with the user as context to inform your responses.
|
||||
- If you do not know the answer, say 'I don't know.'
|
||||
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
|
||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
|
||||
- Do not print verbatim Notes unless necessary.
|
||||
|
||||
Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
|
||||
Today is {day_of_week}, {current_date} in UTC.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
custom_system_prompt_offline_chat = PromptTemplate.from_template(
|
||||
"""
|
||||
You are {name}, a personal agent on Khoj.
|
||||
- Use your general knowledge and past conversation with the user as context to inform your responses.
|
||||
- If you do not know the answer, say 'I don't know.'
|
||||
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
|
||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided information or past conversations.
|
||||
- Do not print verbatim Notes unless necessary.
|
||||
|
||||
Note: More information about you, the company or Khoj apps can be found at https://khoj.dev.
|
||||
Today is {day_of_week}, {current_date} in UTC.
|
||||
|
||||
Instructions:\n{bio}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## Notes Conversation
|
||||
## --
|
||||
notes_conversation = PromptTemplate.from_template(
|
||||
|
||||
@@ -18,8 +18,6 @@ import requests
|
||||
import tiktoken
|
||||
import yaml
|
||||
from langchain_core.messages.chat import ChatMessage
|
||||
from llama_cpp import LlamaTokenizer
|
||||
from llama_cpp.llama import Llama
|
||||
from pydantic import BaseModel, ConfigDict, ValidationError, create_model
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
@@ -32,7 +30,6 @@ from khoj.database.models import (
|
||||
KhojUser,
|
||||
)
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -85,12 +82,6 @@ model_to_prompt_size = {
|
||||
"claude-sonnet-4-20250514": 60000,
|
||||
"claude-opus-4-0": 60000,
|
||||
"claude-opus-4-20250514": 60000,
|
||||
# Offline Models
|
||||
"bartowski/Qwen2.5-14B-Instruct-GGUF": 20000,
|
||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
|
||||
"bartowski/Llama-3.2-3B-Instruct-GGUF": 20000,
|
||||
"bartowski/gemma-2-9b-it-GGUF": 6000,
|
||||
"bartowski/gemma-2-2b-it-GGUF": 6000,
|
||||
}
|
||||
model_to_tokenizer: Dict[str, str] = {}
|
||||
|
||||
@@ -573,7 +564,6 @@ def generate_chatml_messages_with_context(
|
||||
system_message: str = None,
|
||||
chat_history: list[ChatMessageModel] = [],
|
||||
model_name="gpt-4o-mini",
|
||||
loaded_model: Optional[Llama] = None,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
query_images=None,
|
||||
@@ -588,10 +578,7 @@ def generate_chatml_messages_with_context(
|
||||
"""Generate chat messages with appropriate context from previous conversation to send to the chat model"""
|
||||
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||
if not max_prompt_size:
|
||||
if loaded_model:
|
||||
max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
|
||||
else:
|
||||
max_prompt_size = model_to_prompt_size.get(model_name, 10000)
|
||||
max_prompt_size = model_to_prompt_size.get(model_name, 10000)
|
||||
|
||||
# Scale lookback turns proportional to max prompt size supported by model
|
||||
lookback_turns = max_prompt_size // 750
|
||||
@@ -735,7 +722,7 @@ def generate_chatml_messages_with_context(
|
||||
message.content = [{"type": "text", "text": message.content}]
|
||||
|
||||
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
||||
messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
|
||||
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
|
||||
|
||||
# Return message in chronological order
|
||||
return messages[::-1]
|
||||
@@ -743,25 +730,20 @@ def generate_chatml_messages_with_context(
|
||||
|
||||
def get_encoder(
|
||||
model_name: str,
|
||||
loaded_model: Optional[Llama] = None,
|
||||
tokenizer_name=None,
|
||||
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer:
|
||||
) -> tiktoken.Encoding | PreTrainedTokenizer | PreTrainedTokenizerFast:
|
||||
default_tokenizer = "gpt-4o"
|
||||
|
||||
try:
|
||||
if loaded_model:
|
||||
encoder = loaded_model.tokenizer()
|
||||
elif model_name.startswith("gpt-") or model_name.startswith("o1"):
|
||||
# as tiktoken doesn't recognize o1 model series yet
|
||||
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
|
||||
elif tokenizer_name:
|
||||
if tokenizer_name:
|
||||
if tokenizer_name in state.pretrained_tokenizers:
|
||||
encoder = state.pretrained_tokenizers[tokenizer_name]
|
||||
else:
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||
state.pretrained_tokenizers[tokenizer_name] = encoder
|
||||
else:
|
||||
encoder = download_model(model_name).tokenizer()
|
||||
# as tiktoken doesn't recognize o1 model series yet
|
||||
encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
|
||||
except:
|
||||
encoder = tiktoken.encoding_for_model(default_tokenizer)
|
||||
if state.verbose > 2:
|
||||
@@ -773,7 +755,7 @@ def get_encoder(
|
||||
|
||||
def count_tokens(
|
||||
message_content: str | list[str | dict],
|
||||
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | LlamaTokenizer | tiktoken.Encoding,
|
||||
encoder: PreTrainedTokenizer | PreTrainedTokenizerFast | tiktoken.Encoding,
|
||||
) -> int:
|
||||
"""
|
||||
Count the total number of tokens in a list of messages.
|
||||
@@ -825,11 +807,10 @@ def truncate_messages(
|
||||
messages: list[ChatMessage],
|
||||
max_prompt_size: int,
|
||||
model_name: str,
|
||||
loaded_model: Optional[Llama] = None,
|
||||
tokenizer_name=None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
encoder = get_encoder(model_name, loaded_model, tokenizer_name)
|
||||
encoder = get_encoder(model_name, tokenizer_name)
|
||||
|
||||
# Extract system message from messages
|
||||
system_message = None
|
||||
|
||||
@@ -235,7 +235,6 @@ def is_operator_model(model: str) -> ChatModel.ModelType | None:
|
||||
"claude-3-7-sonnet": ChatModel.ModelType.ANTHROPIC,
|
||||
"claude-sonnet-4": ChatModel.ModelType.ANTHROPIC,
|
||||
"claude-opus-4": ChatModel.ModelType.ANTHROPIC,
|
||||
"ui-tars-1.5": ChatModel.ModelType.OFFLINE,
|
||||
}
|
||||
for operator_model in operator_models:
|
||||
if model.startswith(operator_model):
|
||||
|
||||
@@ -15,7 +15,6 @@ from khoj.configure import initialize_content
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_photo
|
||||
from khoj.database.models import KhojUser, SpeechToTextModelOptions
|
||||
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||
from khoj.routers.helpers import (
|
||||
ApiUserRateLimiter,
|
||||
@@ -150,9 +149,6 @@ async def transcribe(
|
||||
if not speech_to_text_config:
|
||||
# If the user has not configured a speech to text model, return an unsupported on server error
|
||||
status_code = 501
|
||||
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
|
||||
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
if speech_to_text_config.ai_model_api:
|
||||
|
||||
@@ -89,10 +89,6 @@ from khoj.processor.conversation.google.gemini_chat import (
|
||||
converse_gemini,
|
||||
gemini_send_message_to_model,
|
||||
)
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
converse_offline,
|
||||
send_message_to_model_offline,
|
||||
)
|
||||
from khoj.processor.conversation.openai.gpt import (
|
||||
converse_openai,
|
||||
send_message_to_model,
|
||||
@@ -117,7 +113,6 @@ from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_type import text_search
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import OfflineChatProcessorModel
|
||||
from khoj.utils.helpers import (
|
||||
LRU,
|
||||
ConversationCommand,
|
||||
@@ -168,14 +163,6 @@ async def is_ready_to_chat(user: KhojUser):
|
||||
if user_chat_model == None:
|
||||
user_chat_model = await ConversationAdapters.aget_default_chat_model(user)
|
||||
|
||||
if user_chat_model and user_chat_model.model_type == ChatModel.ModelType.OFFLINE:
|
||||
chat_model_name = user_chat_model.name
|
||||
max_tokens = user_chat_model.max_prompt_size
|
||||
if state.offline_chat_processor_config is None:
|
||||
logger.info("Loading Offline Chat Model...")
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
return True
|
||||
|
||||
if (
|
||||
user_chat_model
|
||||
and (
|
||||
@@ -1470,12 +1457,6 @@ async def send_message_to_model_wrapper(
|
||||
vision_available = chat_model.vision_enabled
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
loaded_model = None
|
||||
|
||||
if model_type == ChatModel.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=query,
|
||||
@@ -1483,7 +1464,6 @@ async def send_message_to_model_wrapper(
|
||||
system_message=system_message,
|
||||
chat_history=chat_history,
|
||||
model_name=chat_model_name,
|
||||
loaded_model=loaded_model,
|
||||
tokenizer_name=tokenizer,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
@@ -1492,18 +1472,7 @@ async def send_message_to_model_wrapper(
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
if model_type == ChatModel.ModelType.OFFLINE:
|
||||
return send_message_to_model_offline(
|
||||
messages=truncated_messages,
|
||||
loaded_model=loaded_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif model_type == ChatModel.ModelType.OPENAI:
|
||||
if model_type == ChatModel.ModelType.OPENAI:
|
||||
return send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
@@ -1565,19 +1534,12 @@ def send_message_to_model_wrapper_sync(
|
||||
vision_available = chat_model.vision_enabled
|
||||
api_key = chat_model.ai_model_api.api_key
|
||||
api_base_url = chat_model.ai_model_api.api_base_url
|
||||
loaded_model = None
|
||||
|
||||
if model_type == ChatModel.ModelType.OFFLINE:
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model_name, max_tokens)
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
system_message=system_message,
|
||||
chat_history=chat_history,
|
||||
model_name=chat_model_name,
|
||||
loaded_model=loaded_model,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=model_type,
|
||||
@@ -1585,18 +1547,7 @@ def send_message_to_model_wrapper_sync(
|
||||
query_files=query_files,
|
||||
)
|
||||
|
||||
if model_type == ChatModel.ModelType.OFFLINE:
|
||||
return send_message_to_model_offline(
|
||||
messages=truncated_messages,
|
||||
loaded_model=loaded_model,
|
||||
model_name=chat_model_name,
|
||||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif model_type == ChatModel.ModelType.OPENAI:
|
||||
if model_type == ChatModel.ModelType.OPENAI:
|
||||
return send_message_to_model(
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
@@ -1678,30 +1629,7 @@ async def agenerate_chat_response(
|
||||
chat_model = vision_enabled_config
|
||||
vision_available = True
|
||||
|
||||
if chat_model.model_type == "offline":
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
chat_response_generator = converse_offline(
|
||||
# Query
|
||||
user_query=query_to_run,
|
||||
# Context
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
generated_files=raw_generated_files,
|
||||
generated_asset_results=generated_asset_results,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
query_files=query_files,
|
||||
chat_history=chat_history,
|
||||
# Model
|
||||
loaded_model=loaded_model,
|
||||
model_name=chat_model.name,
|
||||
max_prompt_size=chat_model.max_prompt_size,
|
||||
tokenizer_name=chat_model.tokenizer,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||
if chat_model.model_type == ChatModel.ModelType.OPENAI:
|
||||
openai_chat_config = chat_model.ai_model_api
|
||||
api_key = openai_chat_config.api_key
|
||||
chat_model_name = chat_model.name
|
||||
|
||||
@@ -33,9 +33,6 @@ def cli(args=None):
|
||||
parser.add_argument("--sslcert", type=str, help="Path to SSL certificate file")
|
||||
parser.add_argument("--sslkey", type=str, help="Path to SSL key file")
|
||||
parser.add_argument("--version", "-V", action="store_true", help="Print the installed Khoj version and exit")
|
||||
parser.add_argument(
|
||||
"--disable-chat-on-gpu", action="store_true", default=False, help="Disable using GPU for the offline chat model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--anonymous-mode",
|
||||
action="store_true",
|
||||
@@ -54,9 +51,6 @@ def cli(args=None):
|
||||
if len(remaining_args) > 0:
|
||||
logger.info(f"⚠️ Ignoring unknown commandline args: {remaining_args}")
|
||||
|
||||
# Set default values for arguments
|
||||
args.chat_on_gpu = not args.disable_chat_on_gpu
|
||||
|
||||
args.version_no = version("khoj")
|
||||
if args.version:
|
||||
# Show version of khoj installed and exit
|
||||
|
||||
@@ -8,8 +8,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from khoj.processor.conversation.offline.utils import download_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -62,20 +60,3 @@ class ImageSearchModel:
|
||||
@dataclass
|
||||
class SearchModels:
|
||||
text_search: Optional[TextSearchModel] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OfflineChatProcessorConfig:
|
||||
loaded_model: Union[Any, None] = None
|
||||
|
||||
|
||||
class OfflineChatProcessorModel:
|
||||
def __init__(self, chat_model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF", max_tokens: int = None):
|
||||
self.chat_model = chat_model
|
||||
self.loaded_model = None
|
||||
try:
|
||||
self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
|
||||
except ValueError as e:
|
||||
self.loaded_model = None
|
||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
@@ -10,13 +10,6 @@ empty_escape_sequences = "\n|\r|\t| "
|
||||
app_env_filepath = "~/.khoj/env"
|
||||
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
|
||||
content_directory = "~/.khoj/content/"
|
||||
default_offline_chat_models = [
|
||||
"bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
"bartowski/Llama-3.2-3B-Instruct-GGUF",
|
||||
"bartowski/gemma-2-9b-it-GGUF",
|
||||
"bartowski/gemma-2-2b-it-GGUF",
|
||||
"bartowski/Qwen2.5-14B-Instruct-GGUF",
|
||||
]
|
||||
default_openai_chat_models = ["gpt-4o-mini", "gpt-4.1", "o3", "o4-mini"]
|
||||
default_gemini_chat_models = ["gemini-2.0-flash", "gemini-2.5-flash-preview-05-20", "gemini-2.5-pro-preview-06-05"]
|
||||
default_anthropic_chat_models = ["claude-sonnet-4-0", "claude-3-5-haiku-latest"]
|
||||
|
||||
@@ -16,7 +16,6 @@ from khoj.processor.conversation.utils import model_to_prompt_size, model_to_tok
|
||||
from khoj.utils.constants import (
|
||||
default_anthropic_chat_models,
|
||||
default_gemini_chat_models,
|
||||
default_offline_chat_models,
|
||||
default_openai_chat_models,
|
||||
)
|
||||
|
||||
@@ -72,7 +71,6 @@ def initialization(interactive: bool = True):
|
||||
default_api_key=openai_api_key,
|
||||
api_base_url=openai_base_url,
|
||||
vision_enabled=True,
|
||||
is_offline=False,
|
||||
interactive=interactive,
|
||||
provider_name=provider,
|
||||
)
|
||||
@@ -118,7 +116,6 @@ def initialization(interactive: bool = True):
|
||||
default_gemini_chat_models,
|
||||
default_api_key=os.getenv("GEMINI_API_KEY"),
|
||||
vision_enabled=True,
|
||||
is_offline=False,
|
||||
interactive=interactive,
|
||||
provider_name="Google Gemini",
|
||||
)
|
||||
@@ -145,17 +142,6 @@ def initialization(interactive: bool = True):
|
||||
default_anthropic_chat_models,
|
||||
default_api_key=os.getenv("ANTHROPIC_API_KEY"),
|
||||
vision_enabled=True,
|
||||
is_offline=False,
|
||||
interactive=interactive,
|
||||
)
|
||||
|
||||
# Set up offline chat models
|
||||
_setup_chat_model_provider(
|
||||
ChatModel.ModelType.OFFLINE,
|
||||
default_offline_chat_models,
|
||||
default_api_key=None,
|
||||
vision_enabled=False,
|
||||
is_offline=True,
|
||||
interactive=interactive,
|
||||
)
|
||||
|
||||
@@ -186,7 +172,6 @@ def initialization(interactive: bool = True):
|
||||
interactive: bool,
|
||||
api_base_url: str = None,
|
||||
vision_enabled: bool = False,
|
||||
is_offline: bool = False,
|
||||
provider_name: str = None,
|
||||
) -> Tuple[bool, AiModelApi]:
|
||||
supported_vision_models = (
|
||||
@@ -195,11 +180,6 @@ def initialization(interactive: bool = True):
|
||||
provider_name = provider_name or model_type.name.capitalize()
|
||||
|
||||
default_use_model = default_api_key is not None
|
||||
# If not in interactive mode & in the offline setting, it's most likely that we're running in a containerized environment.
|
||||
# This usually means there's not enough RAM to load offline models directly within the application.
|
||||
# In such cases, we default to not using the model -- it's recommended to use another service like Ollama to host the model locally in that case.
|
||||
if is_offline:
|
||||
default_use_model = False
|
||||
|
||||
use_model_provider = (
|
||||
default_use_model if not interactive else input(f"Add {provider_name} chat models? (y/n): ") == "y"
|
||||
@@ -211,13 +191,12 @@ def initialization(interactive: bool = True):
|
||||
logger.info(f"️💬 Setting up your {provider_name} chat configuration")
|
||||
|
||||
ai_model_api = None
|
||||
if not is_offline:
|
||||
if interactive:
|
||||
user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
|
||||
api_key = user_api_key if user_api_key != "" else default_api_key
|
||||
else:
|
||||
api_key = default_api_key
|
||||
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
|
||||
if interactive:
|
||||
user_api_key = input(f"Enter your {provider_name} API key (default: {default_api_key}): ")
|
||||
api_key = user_api_key if user_api_key != "" else default_api_key
|
||||
else:
|
||||
api_key = default_api_key
|
||||
ai_model_api = AiModelApi.objects.create(api_key=api_key, name=provider_name, api_base_url=api_base_url)
|
||||
|
||||
if interactive:
|
||||
user_chat_models = input(
|
||||
|
||||
@@ -103,13 +103,8 @@ class OpenAIProcessorConfig(ConfigBase):
|
||||
chat_model: Optional[str] = "gpt-4o-mini"
|
||||
|
||||
|
||||
class OfflineChatProcessorConfig(ConfigBase):
|
||||
chat_model: Optional[str] = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
|
||||
|
||||
|
||||
class ConversationProcessorConfig(ConfigBase):
|
||||
openai: Optional[OpenAIProcessorConfig] = None
|
||||
offline_chat: Optional[OfflineChatProcessorConfig] = None
|
||||
max_prompt_size: Optional[int] = None
|
||||
tokenizer: Optional[str] = None
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from whisper import Whisper
|
||||
from khoj.database.models import ProcessLock
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import OfflineChatProcessorModel, SearchModels
|
||||
from khoj.utils.config import SearchModels
|
||||
from khoj.utils.helpers import LRU, get_device, is_env_var_true
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
|
||||
@@ -22,7 +22,6 @@ search_models = SearchModels()
|
||||
embeddings_model: Dict[str, EmbeddingsModel] = None
|
||||
cross_encoder_model: Dict[str, CrossEncoderModel] = None
|
||||
openai_client: OpenAI = None
|
||||
offline_chat_processor_config: OfflineChatProcessorModel = None
|
||||
whisper_model: Whisper = None
|
||||
config_file: Path = None
|
||||
verbose: int = 0
|
||||
@@ -39,7 +38,6 @@ telemetry: List[Dict[str, str]] = []
|
||||
telemetry_disabled: bool = is_env_var_true("KHOJ_TELEMETRY_DISABLE")
|
||||
khoj_version: str = None
|
||||
device = get_device()
|
||||
chat_on_gpu: bool = True
|
||||
anonymous_mode: bool = False
|
||||
pretrained_tokenizers: Dict[str, PreTrainedTokenizer | PreTrainedTokenizerFast] = dict()
|
||||
billing_enabled: bool = (
|
||||
|
||||
Reference in New Issue
Block a user