[Multi-User Part 3]: Separate chat sessions based on authenticated users (#511)

- Add a data model which allows us to store Conversations with users. This does a minimal lift over the current setup, where the underlying data is stored in a JSON file. This maintains parity with that configuration.
- There does _seem_ to be some regression in chat quality, which is most likely attributable to search results.

This will help us with #275. It should become much easier to maintain multiple Conversations in a given table in the backend now. We will have to do some thinking on the UI.
This commit is contained in:
sabaimran
2023-10-26 11:37:41 -07:00
committed by GitHub
parent a8a82d274a
commit 4b6ec248a6
24 changed files with 719 additions and 626 deletions

View File

@@ -1,5 +1,4 @@
from typing import Type, TypeVar, List
import uuid
from datetime import date
from django.db import models
@@ -21,6 +20,13 @@ from database.models import (
GithubConfig,
Embeddings,
GithubRepoConfig,
Conversation,
ConversationProcessorConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
)
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser:
async def create_google_user(token: dict) -> KhojUser:
user_info = token.get("userinfo")
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
await user.asave()
await GoogleUser.objects.acreate(
sub=user_info.get("sub"),
azp=user_info.get("azp"),
email=user_info.get("email"),
name=user_info.get("name"),
given_name=user_info.get("given_name"),
family_name=user_info.get("family_name"),
picture=user_info.get("picture"),
locale=user_info.get("locale"),
sub=token.get("sub"),
azp=token.get("azp"),
email=token.get("email"),
name=token.get("name"),
given_name=token.get("given_name"),
family_name=token.get("family_name"),
picture=token.get("picture"),
locale=token.get("locale"),
user=user,
)
@@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
class ConversationAdapters:
    """Database accessors for per-user Conversations and chat processor configs.

    All methods are stateless (static); sync and async ("a"-prefixed) variants
    are provided where callers need them.
    """

    @staticmethod
    def get_conversation_by_user(user: KhojUser):
        """Return the user's Conversation, creating an empty one if none exists."""
        conversation = Conversation.objects.filter(user=user)
        if conversation.exists():
            return conversation.first()
        return Conversation.objects.create(user=user)

    @staticmethod
    async def aget_conversation_by_user(user: KhojUser):
        """Async variant of get_conversation_by_user."""
        conversation = Conversation.objects.filter(user=user)
        if await conversation.aexists():
            return await conversation.afirst()
        return await Conversation.objects.acreate(user=user)

    @staticmethod
    def has_any_conversation_config(user: KhojUser):
        """True when the user has a ConversationProcessorConfig row."""
        return ConversationProcessorConfig.objects.filter(user=user).exists()

    @staticmethod
    def get_openai_conversation_config(user: KhojUser):
        """Return the user's OpenAI processor config, or None."""
        return OpenAIProcessorConversationConfig.objects.filter(user=user).first()

    @staticmethod
    def get_offline_chat_conversation_config(user: KhojUser):
        """Return the user's offline chat processor config, or None."""
        return OfflineChatProcessorConversationConfig.objects.filter(user=user).first()

    @staticmethod
    def has_valid_offline_conversation_config(user: KhojUser):
        """True when the user has offline chat configured AND enabled."""
        return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()

    @staticmethod
    def has_valid_openai_conversation_config(user: KhojUser):
        """True when the user has an OpenAI processor config row."""
        return OpenAIProcessorConversationConfig.objects.filter(user=user).exists()

    @staticmethod
    def get_conversation_config(user: KhojUser):
        """Return the user's ConversationProcessorConfig, or None."""
        return ConversationProcessorConfig.objects.filter(user=user).first()

    @staticmethod
    def save_conversation(user: KhojUser, conversation_log: dict):
        """Persist the conversation log, updating the existing row when present."""
        conversation = Conversation.objects.filter(user=user)
        if conversation.exists():
            conversation.update(conversation_log=conversation_log)
        else:
            Conversation.objects.create(user=user, conversation_log=conversation_log)

    @staticmethod
    def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig):
        """Upsert the processor config plus the optional OpenAI / offline sub-configs."""
        conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user)
        conversation_config.max_prompt_size = new_config.max_prompt_size
        conversation_config.tokenizer = new_config.tokenizer
        conversation_config.save()

        if new_config.openai:
            default_values = {
                "api_key": new_config.openai.api_key,
            }
            if new_config.openai.chat_model:
                default_values["chat_model"] = new_config.openai.chat_model
            OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)

        if new_config.offline_chat:
            # BUGFIX: pass the boolean through unchanged. The previous code wrapped
            # it in str(), and str(False) == "False" is a truthy string, so saving
            # a disabled flag would flip the BooleanField to True.
            default_values = {
                "enable_offline_chat": new_config.offline_chat.enable_offline_chat,
            }
            if new_config.offline_chat.chat_model:
                default_values["chat_model"] = new_config.offline_chat.chat_model
            OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=default_values)

    @staticmethod
    def get_enabled_conversation_settings(user: KhojUser):
        """Return which chat backends ({"openai", "offline_chat"}) are usable for the user."""
        openai_config = ConversationAdapters.get_openai_conversation_config(user)
        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)

        return {
            "openai": True if openai_config is not None else False,
            "offline_chat": True
            if (offline_chat_config is not None and offline_chat_config.enable_offline_chat)
            else False,
        }

    @staticmethod
    def clear_conversation_config(user: KhojUser):
        """Delete the user's processor config and both backend sub-configs."""
        ConversationProcessorConfig.objects.filter(user=user).delete()
        ConversationAdapters.clear_openai_conversation_config(user)
        ConversationAdapters.clear_offline_chat_conversation_config(user)

    @staticmethod
    def clear_openai_conversation_config(user: KhojUser):
        """Delete the user's OpenAI processor config."""
        OpenAIProcessorConversationConfig.objects.filter(user=user).delete()

    @staticmethod
    def clear_offline_chat_conversation_config(user: KhojUser):
        """Delete the user's offline chat processor config."""
        OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()

    @staticmethod
    async def has_offline_chat(user: KhojUser):
        """Async: True when offline chat is configured and enabled for the user."""
        return await OfflineChatProcessorConversationConfig.objects.filter(
            user=user, enable_offline_chat=True
        ).aexists()

    @staticmethod
    async def get_offline_chat(user: KhojUser):
        """Async: return the user's offline chat config, or None."""
        return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst()

    @staticmethod
    async def has_openai_chat(user: KhojUser):
        """Async: True when the user has an OpenAI processor config."""
        return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists()

    @staticmethod
    async def get_openai_chat(user: KhojUser):
        """Async: return the user's OpenAI config, or None."""
        return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()
class EmbeddingsAdapters:
word_filer = WordFilter()
file_filter = FileFilter()

View File

@@ -0,0 +1,81 @@
# Generated by Django 4.2.5 on 2023-10-18 05:31
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
    """Move chat configuration to per-user rows and introduce the Conversation model.

    Generated by Django; splits the old single ConversationProcessorConfig into
    per-backend (OpenAI / offline) config models and a Conversation model that
    stores each user's chat log as JSON.
    """

    dependencies = [
        ("database", "0006_embeddingsdates"),
    ]

    operations = [
        # Drop the old single-blob conversation JSON and the offline-chat flag;
        # both move to the dedicated models created below.
        migrations.RemoveField(
            model_name="conversationprocessorconfig",
            name="conversation",
        ),
        migrations.RemoveField(
            model_name="conversationprocessorconfig",
            name="enable_offline_chat",
        ),
        migrations.AddField(
            model_name="conversationprocessorconfig",
            name="max_prompt_size",
            field=models.IntegerField(blank=True, default=None, null=True),
        ),
        migrations.AddField(
            model_name="conversationprocessorconfig",
            name="tokenizer",
            field=models.CharField(blank=True, default=None, max_length=200, null=True),
        ),
        migrations.AddField(
            model_name="conversationprocessorconfig",
            name="user",
            # default=1 backfills existing rows onto the first user;
            # preserve_default=False drops that default after the migration runs.
            field=models.ForeignKey(
                default=1, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
            ),
            preserve_default=False,
        ),
        migrations.CreateModel(
            name="OpenAIProcessorConversationConfig",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("api_key", models.CharField(max_length=200)),
                ("chat_model", models.CharField(max_length=200)),
                ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.CreateModel(
            name="OfflineChatProcessorConversationConfig",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                ("enable_offline_chat", models.BooleanField(default=False)),
                ("chat_model", models.CharField(default="llama-2-7b-chat.ggmlv3.q4_0.bin", max_length=200)),
                ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.CreateModel(
            name="Conversation",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                # NOTE(review): no default here; migration 0008 adds default=dict.
                ("conversation_log", models.JSONField()),
                ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View File

@@ -0,0 +1,17 @@
# Generated by Django 4.2.5 on 2023-10-18 16:46
from django.db import migrations, models
class Migration(migrations.Migration):
    """Follow-up to 0007: give Conversation.conversation_log a default of {}
    so new Conversation rows can be created without an explicit log."""

    dependencies = [
        ("database", "0007_remove_conversationprocessorconfig_conversation_and_more"),
    ]

    operations = [
        migrations.AlterField(
            model_name="conversation",
            name="conversation_log",
            field=models.JSONField(default=dict),
        ),
    ]

View File

@@ -82,9 +82,27 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
# NOTE(review): this span is diff residue — ConversationProcessorConfig appears
# twice (old schema first, new schema second). In Python the later definition
# shadows the earlier one; in the actual diff the first was the removed line.
class ConversationProcessorConfig(BaseModel):
    # Old schema: the whole conversation log lived as JSON on the config row.
    conversation = models.JSONField()


class OpenAIProcessorConversationConfig(BaseModel):
    # Per-user OpenAI chat credentials and model selection.
    api_key = models.CharField(max_length=200)
    chat_model = models.CharField(max_length=200)
    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


class OfflineChatProcessorConversationConfig(BaseModel):
    # Per-user offline (gpt4all) chat toggle and model selection.
    enable_offline_chat = models.BooleanField(default=False)
    chat_model = models.CharField(max_length=200, default="llama-2-7b-chat.ggmlv3.q4_0.bin")
    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


class ConversationProcessorConfig(BaseModel):
    # New schema: only tuning knobs; the log itself moves to Conversation.
    max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
    tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


class Conversation(BaseModel):
    # One chat history per user, stored as a JSON blob.
    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
    conversation_log = models.JSONField(default=dict)
class Embeddings(BaseModel):

View File

@@ -23,12 +23,10 @@ from starlette.authentication import (
from khoj.utils import constants, state
from khoj.utils.config import (
SearchType,
ProcessorConfigModel,
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
from khoj.utils.helpers import merge_dicts
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser
from database.adapters import get_all_users
@@ -98,13 +96,6 @@ def configure_server(
# Update Config
state.config = config
# Initialize Processor from Config
try:
state.processor_config = configure_processor(state.config.processor)
except Exception as e:
logger.error(f"🚨 Failed to configure processor", exc_info=True)
raise e
# Initialize Search Models from Config and initialize content
try:
state.config_lock.acquire()
@@ -190,103 +181,6 @@ def configure_search_types(config: FullConfig):
return Enum("SearchType", merge_dicts(core_search_types, {}))
def configure_processor(
    processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
):
    """Build the in-memory ProcessorConfigModel from raw config.

    Returns None (with a warning) when no processor configuration is available;
    otherwise attaches the conversation processor built by
    configure_conversation_processor.
    """
    if not processor_config:
        logger.warning("🚨 No Processor configuration available.")
        return None

    processor = ProcessorConfigModel()

    # Initialize Conversation Processor
    logger.info("💬 Setting up conversation processor")
    processor.conversation = configure_conversation_processor(processor_config, state_processor_config)

    return processor
def configure_conversation_processor(
    processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
):
    """Construct the conversation processor model and load its chat history.

    Falls back to the packaged default conversation-logfile path when the raw
    config is missing or has no logfile set. History is taken, in order of
    preference, from prior in-memory state, then the logfile on disk, then
    initialized empty.
    """
    if (
        not processor_config
        or not processor_config.conversation
        or not processor_config.conversation.conversation_logfile
    ):
        # No usable conversation config: build one from packaged defaults,
        # carrying over whatever partial settings the user did provide.
        default_config = constants.default_config
        default_conversation_logfile = resolve_absolute_path(
            default_config["processor"]["conversation"]["conversation-logfile"]  # type: ignore
        )
        conversation_logfile = resolve_absolute_path(default_conversation_logfile)
        conversation_config = processor_config.conversation if processor_config else None
        conversation_processor = ConversationProcessorConfigModel(
            conversation_config=ConversationProcessorConfig(
                conversation_logfile=conversation_logfile,
                openai=(conversation_config.openai if (conversation_config is not None) else None),
                offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
                max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
                tokenizer=conversation_config.tokenizer if conversation_config else None,
            )
        )
    else:
        conversation_processor = ConversationProcessorConfigModel(
            conversation_config=processor_config.conversation,
        )
        conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)

    # Load Conversation Logs from Disk — but prefer logs already held in state,
    # which may include an unsaved chat session.
    if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log:
        conversation_processor.meta_log = state_processor_config.conversation.meta_log
        conversation_processor.chat_session = state_processor_config.conversation.chat_session
        logger.debug(f"Loaded conversation logs from state")
        return conversation_processor

    if conversation_logfile.is_file():
        # Load Metadata Logs from Conversation Logfile
        with conversation_logfile.open("r") as f:
            conversation_processor.meta_log = json.load(f)
        logger.debug(f"Loaded conversation logs from {conversation_logfile}")
    else:
        # Initialize Conversation Logs
        conversation_processor.meta_log = {}
        conversation_processor.chat_session = []

    return conversation_processor
@schedule.repeat(schedule.every(17).minutes)
def save_chat_session():
    """Periodically flush the in-memory chat session to the conversation logfile.

    Appends a session marker (start/end indices into the chat list) to the
    conversation log, writes the log to disk as JSON, then clears the
    in-memory chat session.
    """
    # No need to create empty log file
    if not (
        state.processor_config
        and state.processor_config.conversation
        and state.processor_config.conversation.meta_log
        and state.processor_config.conversation.chat_session
    ):
        return

    # Summarize Conversation Logs for this Session.
    # session-start continues from the previous session's end (0 if none).
    conversation_log = state.processor_config.conversation.meta_log
    session = {
        "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
        "session-end": len(conversation_log["chat"]),
    }
    if "session" in conversation_log:
        conversation_log["session"].append(session)
    else:
        conversation_log["session"] = [session]

    # Save Conversation Metadata Logs to Disk
    conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
    conversation_logfile.parent.mkdir(parents=True, exist_ok=True)  # create conversation directory if doesn't exist
    with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
        json.dump(conversation_log, logfile, indent=2)

    state.processor_config.conversation.chat_session = []
    logger.info("📩 Saved current chat session to conversation logs")
@schedule.repeat(schedule.every(59).minutes)
def upload_telemetry():
if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:

View File

@@ -3,6 +3,11 @@
<div class="page">
<div class="section">
{% if anonymous_mode == False %}
<div>
Logged in as {{ username }}
</div>
{% endif %}
<h2 class="section-title">Plugins</h2>
<div class="section-cards">
<div class="card">
@@ -257,8 +262,8 @@
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
Offline Chat
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_model_state.enable_offline_model and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% if current_model_state.enable_offline_model and not current_model_state.conversation_gpt4all %}
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
{% endif %}
</h3>
@@ -266,12 +271,12 @@
<div class="card-description-row">
<p class="card-description">Setup offline chat</p>
</div>
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
<div id="clear-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}enabled{% else %}disabled{% endif %}">
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
Disable
</button>
</div>
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
<div id="set-enable-offline-chat" class="card-action-row {% if current_model_state.enable_offline_model %}disabled{% else %}enabled{% endif %}">
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
Enable
</button>

View File

@@ -8,21 +8,20 @@ from typing import List, Optional, Union, Any
import asyncio
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request, Depends
from fastapi import APIRouter, HTTPException, Header, Request
from starlette.authentication import requires
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.configure import configure_processor, configure_server
from khoj.configure import configure_server
from khoj.search_type import image_search, text_search
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils.config import TextSearchModel
from khoj.utils.config import TextSearchModel, GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, is_none_or_empty, timer, command_descriptions
from khoj.utils.rawconfig import (
FullConfig,
ProcessorConfig,
SearchConfig,
SearchResponse,
TextContentConfig,
@@ -32,16 +31,16 @@ from khoj.utils.rawconfig import (
ConversationProcessorConfig,
OfflineChatProcessorConfig,
)
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.state import SearchType
from khoj.utils import state, constants
from khoj.utils.yaml import save_config_to_file_updated_state
from khoj.utils.helpers import AsyncIteratorWrapper
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import (
get_conversation_command,
perform_chat_checks,
generate_chat_response,
agenerate_chat_response,
update_telemetry_state,
is_ready_to_chat,
)
from khoj.processor.conversation.prompts import help_message
from khoj.processor.conversation.openai.gpt import extract_questions
@@ -49,7 +48,7 @@ from khoj.processor.conversation.gpt4all.chat_model import extract_questions_off
from fastapi.requests import Request
from database import adapters
from database.adapters import EmbeddingsAdapters
from database.adapters import EmbeddingsAdapters, ConversationAdapters
from database.models import LocalMarkdownConfig, LocalOrgConfig, LocalPdfConfig, LocalPlaintextConfig, KhojUser
@@ -114,6 +113,8 @@ async def map_config_to_db(config: FullConfig, user: KhojUser):
user=user,
token=config.content_type.notion.token,
)
if config.processor and config.processor.conversation:
ConversationAdapters.set_conversation_processor_config(user, config.processor.conversation)
# If it's a demo instance, prevent updating any of the configuration.
@@ -123,8 +124,6 @@ if not state.demo:
if state.config is None:
state.config = FullConfig()
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
if state.processor_config is None:
state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig)
@requires(["authenticated"], redirect="login_page")
@@ -238,28 +237,24 @@ if not state.demo:
)
content_object = map_config_to_object(content_type)
if content_object is None:
raise ValueError(f"Invalid content type: {content_type}")
await content_object.objects.filter(user=user).adelete()
await sync_to_async(EmbeddingsAdapters.delete_all_embeddings)(user, content_type)
enabled_content = await sync_to_async(EmbeddingsAdapters.get_unique_file_types)(user)
return {"status": "ok"}
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
@requires(["authenticated"], redirect="login_page")
async def remove_processor_conversation_config_data(
request: Request,
client: Optional[str] = None,
):
if (
not state.config
or not state.config.processor
or not state.config.processor.conversation
or not state.config.processor.conversation.openai
):
return {"status": "ok"}
user = request.user.object
state.config.processor.conversation.openai = None
state.processor_config = configure_processor(state.config.processor, state.processor_config)
await sync_to_async(ConversationAdapters.clear_openai_conversation_config)(user)
update_telemetry_state(
request=request,
@@ -269,11 +264,7 @@ if not state.demo:
metadata={"processor_conversation_type": "openai"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
return {"status": "ok"}
@api.post("/config/data/content_type/{content_type}", status_code=200)
@requires(["authenticated"], redirect="login_page")
@@ -301,24 +292,17 @@ if not state.demo:
return {"status": "ok"}
@api.post("/config/data/processor/conversation/openai", status_code=200)
@requires(["authenticated"], redirect="login_page")
async def set_processor_openai_config_data(
request: Request,
updated_config: Union[OpenAIProcessorConfig, None],
client: Optional[str] = None,
):
_initialize_config()
user = request.user.object
if not state.config.processor or not state.config.processor.conversation:
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
conversation_config = ConversationProcessorConfig(openai=updated_config)
assert state.config.processor.conversation is not None
state.config.processor.conversation.openai = updated_config
state.processor_config = configure_processor(state.config.processor, state.processor_config)
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
update_telemetry_state(
request=request,
@@ -328,11 +312,7 @@ if not state.demo:
metadata={"processor_conversation_type": "conversation"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
return {"status": "ok"}
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
async def set_processor_enable_offline_chat_config_data(
@@ -341,24 +321,26 @@ if not state.demo:
offline_chat_model: Optional[str] = None,
client: Optional[str] = None,
):
_initialize_config()
user = request.user.object
if not state.config.processor or not state.config.processor.conversation:
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
if enable_offline_chat:
conversation_config = ConversationProcessorConfig(
offline_chat=OfflineChatProcessorConfig(
enable_offline_chat=enable_offline_chat,
chat_model=offline_chat_model,
)
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
assert state.config.processor.conversation is not None
if state.config.processor.conversation.offline_chat is None:
state.config.processor.conversation.offline_chat = OfflineChatProcessorConfig()
await sync_to_async(ConversationAdapters.set_conversation_processor_config)(user, conversation_config)
state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
if offline_chat_model is not None:
state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
state.processor_config = configure_processor(state.config.processor, state.processor_config)
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
else:
await sync_to_async(ConversationAdapters.clear_offline_chat_conversation_config)(user)
state.gpt4all_processor_config = None
update_telemetry_state(
request=request,
@@ -368,11 +350,7 @@ if not state.demo:
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
return {"status": "ok"}
# Create Routes
@@ -426,9 +404,6 @@ async def search(
if q is None or q == "":
logger.warning(f"No query param (q) passed in API call to initiate search")
return results
if not state.search_models or not any(state.search_models.__dict__.values()):
logger.warning(f"No search models loaded. Configure a search model before initiating search")
return results
# initialize variables
user_query = q.strip()
@@ -565,8 +540,6 @@ def update(
components.append("Search models")
if state.content_index:
components.append("Content index")
if state.processor_config:
components.append("Conversation processor")
components_msg = ", ".join(components)
logger.info(f"📪 {components_msg} updated via API")
@@ -592,12 +565,11 @@ def chat_history(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
):
perform_chat_checks()
user = request.user.object
perform_chat_checks(user)
# Load Conversation History
meta_log = {}
if state.processor_config.conversation:
meta_log = state.processor_config.conversation.meta_log
meta_log = ConversationAdapters.get_conversation_by_user(user=user).conversation_log
update_telemetry_state(
request=request,
@@ -649,30 +621,35 @@ async def chat(
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
) -> Response:
perform_chat_checks()
user = request.user.object
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, q, (n or 5), conversation_command
request, meta_log, q, (n or 5), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
if conversation_command == ConversationCommand.Help:
model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
model_type = "offline" if await ConversationAdapters.has_offline_chat(user) else "openai"
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
# Get the (streamed) chat response from the LLM of choice.
llm_response = generate_chat_response(
llm_response = await agenerate_chat_response(
defiltered_query,
meta_log=state.processor_config.conversation.meta_log,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
conversation_command=conversation_command,
meta_log,
compiled_references,
inferred_queries,
conversation_command,
user,
)
if llm_response is None:
@@ -681,13 +658,14 @@ async def chat(
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
iterator = AsyncIteratorWrapper(llm_response)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
while True:
try:
aggregated_gpt_response += next(llm_response)
except StopIteration:
async for item in iterator:
if item is None:
break
aggregated_gpt_response += item
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
@@ -708,44 +686,53 @@ async def chat(
async def extract_references_and_questions(
request: Request,
meta_log: dict,
q: str,
n: int,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
compiled_references: List[Any] = []
inferred_queries: List[str] = []
if not EmbeddingsAdapters.user_has_embeddings(user=user):
if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q
if not await EmbeddingsAdapters.user_has_embeddings(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
return compiled_references, inferred_queries, q
if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q
# Extract filter terms from user message
defiltered_query = q
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(defiltered_query)
filters_in_query = q.replace(defiltered_query, "").strip()
using_offline_chat = False
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
if await ConversationAdapters.has_offline_chat(user):
using_offline_chat = True
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
elif await ConversationAdapters.has_openai_chat(user):
openai_chat = await ConversationAdapters.get_openai_chat(user)
api_key = openai_chat.api_key
chat_model = openai_chat.chat_model
inferred_queries = extract_questions(
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
)
@@ -754,7 +741,7 @@ async def extract_references_and_questions(
with timer("Searching knowledge base took", logger):
result_list = []
for query in inferred_queries:
n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
n_items = min(n, 3) if using_offline_chat else n
result_list.extend(
await search(
f"{query} {filters_in_query}",
@@ -765,6 +752,8 @@ async def extract_references_and_questions(
dedupe=False,
)
)
# Dedupe the results again, as duplicates may be returned across queries.
result_list = text_search.deduplicated_search_responses(result_list)
compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries, defiltered_query

View File

@@ -1,34 +1,50 @@
import logging
import asyncio
from datetime import datetime
from functools import partial
from typing import Iterator, List, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from fastapi import HTTPException, Request
from khoj.utils import state
from khoj.utils.helpers import ConversationCommand, timer, log_telemetry
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser
from database.adapters import ConversationAdapters
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
def perform_chat_checks():
if (
state.processor_config
and state.processor_config.conversation
and (
state.processor_config.conversation.openai_model
or state.processor_config.conversation.gpt4all_model.loaded_model
)
):
def perform_chat_checks(user: KhojUser):
if ConversationAdapters.has_valid_offline_conversation_config(
user
) or ConversationAdapters.has_valid_openai_conversation_config(user):
return
raise HTTPException(
status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
)
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
async def is_ready_to_chat(user: KhojUser):
has_offline_config = await ConversationAdapters.has_offline_chat(user=user)
has_openai_config = await ConversationAdapters.has_openai_chat(user=user)
if has_offline_config:
offline_chat = await ConversationAdapters.get_offline_chat(user)
chat_model = offline_chat.chat_model
if state.gpt4all_processor_config is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model=chat_model)
return True
ready = has_openai_config or has_offline_config
if not ready:
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
def update_telemetry_state(
@@ -74,12 +90,22 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Default
async def construct_conversation_logs(user: KhojUser):
return (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
async def agenerate_chat_response(*args):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, generate_chat_response, *args)
def generate_chat_response(
q: str,
meta_log: dict,
compiled_references: List[str] = [],
inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
) -> Union[ThreadedGenerator, Iterator[str]]:
def _save_to_conversation_log(
q: str,
@@ -89,17 +115,14 @@ def generate_chat_response(
inferred_queries: List[str],
meta_log,
):
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log(
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
)
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
# Initialize Variables
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
@@ -116,8 +139,14 @@ def generate_chat_response(
meta_log=meta_log,
)
if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user=user)
conversation_config = ConversationAdapters.get_conversation_config(user)
openai_chat_config = ConversationAdapters.get_openai_conversation_config(user)
if offline_chat_config:
if state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model
chat_response = converse_offline(
references=compiled_references,
user_query=q,
@@ -125,14 +154,14 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_command=conversation_command,
model=state.processor_config.conversation.offline_chat.chat_model,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
model=offline_chat_config.chat_model,
max_prompt_size=conversation_config.max_prompt_size,
tokenizer_name=conversation_config.tokenizer,
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
elif openai_chat_config:
api_key = openai_chat_config.api_key
chat_model = openai_chat_config.chat_model
chat_response = converse(
compiled_references,
q,
@@ -141,8 +170,8 @@ def generate_chat_response(
api_key=api_key,
completion_func=partial_completion,
conversation_command=conversation_command,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
max_prompt_size=conversation_config.max_prompt_size if conversation_config else None,
tokenizer_name=conversation_config.tokenizer if conversation_config else None,
)
except Exception as e:

View File

@@ -92,7 +92,7 @@ async def update(
if dict_to_update is not None:
dict_to_update[file.filename] = (
file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read()
file.file.read().decode("utf-8") if encoding == "utf-8" else file.file.read() # type: ignore
)
else:
logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
@@ -181,24 +181,25 @@ def configure_content(
files: Optional[dict[str, dict[str, str]]],
search_models: SearchModels,
regenerate: bool = False,
t: Optional[Union[state.SearchType, str]] = None,
t: Optional[state.SearchType] = None,
full_corpus: bool = True,
user: KhojUser = None,
) -> Optional[ContentIndex]:
content_index = ContentIndex()
if t in [type.value for type in state.SearchType]:
t = state.SearchType(t).value
if t is not None and not t.value in [type.value for type in state.SearchType]:
logger.warning(f"🚨 Invalid search type: {t}")
return None
assert type(t) == str or t == None, f"Invalid search type: {t}"
search_type = t.value if t else None
if files is None:
logger.warning(f"🚨 No files to process for {t} search.")
logger.warning(f"🚨 No files to process for {search_type} search.")
return None
try:
# Initialize Org Notes Search
if (t == None or t == state.SearchType.Org.value) and files["org"]:
if (search_type == None or search_type == state.SearchType.Org.value) and files["org"]:
logger.info("🦄 Setting up search for orgmode notes")
# Extract Entries, Generate Notes Embeddings
text_search.setup(
@@ -213,7 +214,7 @@ def configure_content(
try:
# Initialize Markdown Search
if (t == None or t == state.SearchType.Markdown.value) and files["markdown"]:
if (search_type == None or search_type == state.SearchType.Markdown.value) and files["markdown"]:
logger.info("💎 Setting up search for markdown notes")
# Extract Entries, Generate Markdown Embeddings
text_search.setup(
@@ -229,7 +230,7 @@ def configure_content(
try:
# Initialize PDF Search
if (t == None or t == state.SearchType.Pdf.value) and files["pdf"]:
if (search_type == None or search_type == state.SearchType.Pdf.value) and files["pdf"]:
logger.info("🖨️ Setting up search for pdf")
# Extract Entries, Generate PDF Embeddings
text_search.setup(
@@ -245,7 +246,7 @@ def configure_content(
try:
# Initialize Plaintext Search
if (t == None or t == state.SearchType.Plaintext.value) and files["plaintext"]:
if (search_type == None or search_type == state.SearchType.Plaintext.value) and files["plaintext"]:
logger.info("📄 Setting up search for plaintext")
# Extract Entries, Generate Plaintext Embeddings
text_search.setup(
@@ -262,7 +263,7 @@ def configure_content(
try:
# Initialize Image Search
if (
(t == None or t == state.SearchType.Image.value)
(search_type == None or search_type == state.SearchType.Image.value)
and content_config
and content_config.image
and search_models.image_search
@@ -278,7 +279,7 @@ def configure_content(
try:
github_config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if (t == None or t == state.SearchType.Github.value) and github_config is not None:
if (search_type == None or search_type == state.SearchType.Github.value) and github_config is not None:
logger.info("🐙 Setting up search for github")
# Extract Entries, Generate Github Embeddings
text_search.setup(
@@ -296,7 +297,7 @@ def configure_content(
try:
# Initialize Notion Search
notion_config = NotionConfig.objects.filter(user=user).first()
if (t == None or t in state.SearchType.Notion.value) and notion_config:
if (search_type == None or search_type in state.SearchType.Notion.value) and notion_config:
logger.info("🔌 Setting up search for notion")
text_search.setup(
NotionToJsonl,

View File

@@ -19,7 +19,7 @@ from khoj.utils.rawconfig import (
# Internal Packages
from khoj.utils import constants, state
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config
from database.adapters import EmbeddingsAdapters, get_user_github_config, get_user_notion_config, ConversationAdapters
from database.models import LocalOrgConfig, LocalMarkdownConfig, LocalPdfConfig, LocalPlaintextConfig
@@ -83,7 +83,7 @@ if not state.demo:
@web_client.get("/config", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def config_page(request: Request):
user = request.user.object if request.user.is_authenticated else None
user = request.user.object
enabled_content = set(EmbeddingsAdapters.get_unique_file_types(user).all())
default_full_config = FullConfig(
content_type=None,
@@ -100,9 +100,6 @@ if not state.demo:
"github": ("github" in enabled_content),
"notion": ("notion" in enabled_content),
"plaintext": ("plaintext" in enabled_content),
"enable_offline_model": False,
"conversation_openai": False,
"conversation_gpt4all": False,
}
if state.content_index:
@@ -112,13 +109,17 @@ if not state.demo:
}
)
if state.processor_config and state.processor_config.conversation:
successfully_configured.update(
{
"conversation_openai": state.processor_config.conversation.openai_model is not None,
"conversation_gpt4all": state.processor_config.conversation.gpt4all_model.loaded_model is not None,
}
)
enabled_chat_config = ConversationAdapters.get_enabled_conversation_settings(user)
successfully_configured.update(
{
"conversation_openai": enabled_chat_config["openai"],
"enable_offline_model": enabled_chat_config["offline_chat"],
"conversation_gpt4all": state.gpt4all_processor_config.loaded_model is not None
if state.gpt4all_processor_config
else False,
}
)
return templates.TemplateResponse(
"config.html",
@@ -127,6 +128,7 @@ if not state.demo:
"current_config": current_config,
"current_model_state": successfully_configured,
"anonymous_mode": state.anonymous_mode,
"username": user.username if user else None,
},
)
@@ -204,22 +206,22 @@ if not state.demo:
)
@web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
@requires(["authenticated"], redirect="login_page")
def conversation_processor_config_page(request: Request):
default_copy = constants.default_config.copy()
default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore
default_openai_config = OpenAIProcessorConfig(
api_key="",
chat_model=default_processor_config["chat-model"],
)
user = request.user.object
openai_config = ConversationAdapters.get_openai_conversation_config(user)
if openai_config:
current_processor_openai_config = OpenAIProcessorConfig(
api_key=openai_config.api_key,
chat_model=openai_config.chat_model,
)
else:
current_processor_openai_config = OpenAIProcessorConfig(
api_key="",
chat_model="gpt-3.5-turbo",
)
current_processor_openai_config = (
state.config.processor.conversation.openai
if state.config
and state.config.processor
and state.config.processor.conversation
and state.config.processor.conversation.openai
else default_openai_config
)
current_processor_openai_config = json.loads(current_processor_openai_config.json())
return templates.TemplateResponse(

View File

@@ -236,6 +236,7 @@ def collate_results(hits, image_names, output_directory, image_files_url, count=
"image_score": f"{hit['image_score']:.9f}",
"metadata_score": f"{hit['metadata_score']:.9f}",
},
"corpus_id": hit["corpus_id"],
}
)
]

View File

@@ -14,10 +14,9 @@ from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import state
from khoj.utils.helpers import get_absolute_path, resolve_absolute_path, load_model, timer
from khoj.utils.config import TextSearchModel
from khoj.utils.models import BaseEncoder
from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, TextSearchConfig, Entry
from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl
from khoj.processor.text_to_jsonl import TextEmbeddings
from database.adapters import EmbeddingsAdapters
@@ -36,36 +35,6 @@ search_type_to_embeddings_type = {
}
def initialize_model(search_config: TextSearchConfig):
"Initialize model for semantic search on text"
torch.set_num_threads(4)
# If model directory is configured
if search_config.model_directory:
# Convert model directory to absolute path
search_config.model_directory = resolve_absolute_path(search_config.model_directory)
# Create model directory if it doesn't exist
search_config.model_directory.parent.mkdir(parents=True, exist_ok=True)
# The bi-encoder encodes all entries to use for semantic search
bi_encoder = load_model(
model_dir=search_config.model_directory,
model_name=search_config.encoder,
model_type=search_config.encoder_type or SentenceTransformer,
device=f"{state.device}",
)
# The cross-encoder re-ranks the results to improve quality
cross_encoder = load_model(
model_dir=search_config.model_directory,
model_name=search_config.cross_encoder,
model_type=CrossEncoder,
device=f"{state.device}",
)
return TextSearchModel(bi_encoder, cross_encoder)
def extract_entries(jsonl_file) -> List[Entry]:
"Load entries from compressed jsonl"
return list(map(Entry.from_dict, load_jsonl(jsonl_file)))
@@ -176,6 +145,7 @@ def collate_results(hits, dedupe=True):
{
"entry": hit.raw,
"score": hit.distance,
"corpus_id": str(hit.corpus_id),
"additional": {
"file": hit.file_path,
"compiled": hit.compiled,
@@ -185,6 +155,28 @@ def collate_results(hits, dedupe=True):
)
def deduplicated_search_responses(hits: List[SearchResponse]):
hit_ids = set()
for hit in hits:
if hit.corpus_id in hit_ids:
continue
else:
hit_ids.add(hit.corpus_id)
yield SearchResponse.parse_obj(
{
"entry": hit.entry,
"score": hit.score,
"corpus_id": hit.corpus_id,
"additional": {
"file": hit.additional["file"],
"compiled": hit.additional["compiled"],
"heading": hit.additional["heading"],
},
}
)
def rerank_and_sort_results(hits, query):
# Score all retrieved entries using the cross-encoder
hits = cross_encoder_score(query, hits)

View File

@@ -5,8 +5,7 @@ from enum import Enum
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
from typing import TYPE_CHECKING, List, Optional, Union, Any
from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
@@ -19,9 +18,7 @@ logger = logging.getLogger(__name__)
# Internal Packages
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder
from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
class SearchType(str, Enum):
@@ -79,31 +76,15 @@ class GPT4AllProcessorConfig:
loaded_model: Union[Any, None] = None
class ConversationProcessorConfigModel:
class GPT4AllProcessorModel:
def __init__(
self,
conversation_config: ConversationProcessorConfig,
chat_model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
):
self.openai_model = conversation_config.openai
self.gpt4all_model = GPT4AllProcessorConfig()
self.offline_chat = conversation_config.offline_chat or OfflineChatProcessorConfig()
self.max_prompt_size = conversation_config.max_prompt_size
self.tokenizer = conversation_config.tokenizer
self.conversation_logfile = Path(conversation_config.conversation_logfile)
self.chat_session: List[str] = []
self.meta_log: dict = {}
if self.offline_chat.enable_offline_chat:
try:
self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
except Exception as e:
self.offline_chat.enable_offline_chat = False
self.gpt4all_model.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
else:
self.gpt4all_model.loaded_model = None
@dataclass
class ProcessorConfigModel:
conversation: Union[ConversationProcessorConfigModel, None] = None
self.chat_model = chat_model
self.loaded_model = None
try:
self.loaded_model = download_model(self.chat_model)
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)

View File

@@ -8,136 +8,14 @@ telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
empty_config = {
"content-type": {
"org": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
"index-heading-entries": False,
},
"markdown": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
},
"pdf": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
"embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
},
"plaintext": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
"embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
},
},
"search-type": {
"symmetric": {
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/symmetric/",
},
"asymmetric": {
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/asymmetric/",
},
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
"processor": {
"conversation": {
"openai": {
"api-key": None,
"chat-model": "gpt-3.5-turbo",
},
"offline-chat": {
"enable-offline-chat": False,
"chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
},
"tokenizer": None,
"max-prompt-size": None,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
}
},
}
# default app config to use
default_config = {
"content-type": {
"org": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/org/org.jsonl.gz",
"embeddings-file": "~/.khoj/content/org/org_embeddings.pt",
"index-heading-entries": False,
},
"markdown": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/markdown/markdown.jsonl.gz",
"embeddings-file": "~/.khoj/content/markdown/markdown_embeddings.pt",
},
"pdf": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/pdf/pdf.jsonl.gz",
"embeddings-file": "~/.khoj/content/pdf/pdf_embeddings.pt",
},
"image": {
"input-directories": None,
"input-filter": None,
"embeddings-file": "~/.khoj/content/image/image_embeddings.pt",
"batch-size": 50,
"use-xmp-metadata": False,
},
"github": {
"pat-token": None,
"repos": [],
"compressed-jsonl": "~/.khoj/content/github/github.jsonl.gz",
"embeddings-file": "~/.khoj/content/github/github_embeddings.pt",
},
"notion": {
"token": None,
"compressed-jsonl": "~/.khoj/content/notion/notion.jsonl.gz",
"embeddings-file": "~/.khoj/content/notion/notion_embeddings.pt",
},
"plaintext": {
"input-files": None,
"input-filter": None,
"compressed-jsonl": "~/.khoj/content/plaintext/plaintext.jsonl.gz",
"embeddings-file": "~/.khoj/content/plaintext/plaintext_embeddings.pt",
},
},
"search-type": {
"symmetric": {
"encoder": "sentence-transformers/all-MiniLM-L6-v2",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/symmetric/",
},
"asymmetric": {
"encoder": "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",
"cross-encoder": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"model_directory": "~/.khoj/search/asymmetric/",
},
"image": {"encoder": "sentence-transformers/clip-ViT-B-32", "model_directory": "~/.khoj/search/image/"},
},
"processor": {
"conversation": {
"openai": {
"api-key": None,
"chat-model": "gpt-3.5-turbo",
},
"offline-chat": {
"enable-offline-chat": False,
"chat-model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
},
"tokenizer": None,
"max-prompt-size": None,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
}
},
}

View File

@@ -15,6 +15,7 @@ from time import perf_counter
import torch
from typing import Optional, Union, TYPE_CHECKING
import uuid
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.utils import constants
@@ -29,6 +30,28 @@ if TYPE_CHECKING:
from khoj.utils.rawconfig import AppConfig
class AsyncIteratorWrapper:
def __init__(self, obj):
self._it = iter(obj)
def __aiter__(self):
return self
async def __anext__(self):
try:
value = await self.next_async()
except StopAsyncIteration:
return
return value
@sync_to_async
def next_async(self):
try:
return next(self._it)
except StopIteration:
raise StopAsyncIteration
def is_none_or_empty(item):
return item == None or (hasattr(item, "__iter__") and len(item) == 0) or item == ""

View File

@@ -67,13 +67,6 @@ class ContentConfig(ConfigBase):
notion: Optional[NotionContentConfig]
class TextSearchConfig(ConfigBase):
encoder: str
cross_encoder: str
encoder_type: Optional[str]
model_directory: Optional[Path]
class ImageSearchConfig(ConfigBase):
encoder: str
encoder_type: Optional[str]
@@ -81,8 +74,6 @@ class ImageSearchConfig(ConfigBase):
class SearchConfig(ConfigBase):
asymmetric: Optional[TextSearchConfig]
symmetric: Optional[TextSearchConfig]
image: Optional[ImageSearchConfig]
@@ -97,11 +88,10 @@ class OfflineChatProcessorConfig(ConfigBase):
class ConversationProcessorConfig(ConfigBase):
conversation_logfile: Path
openai: Optional[OpenAIProcessorConfig]
offline_chat: Optional[OfflineChatProcessorConfig]
max_prompt_size: Optional[int]
tokenizer: Optional[str]
openai: Optional[OpenAIProcessorConfig] = None
offline_chat: Optional[OfflineChatProcessorConfig] = None
max_prompt_size: Optional[int] = None
tokenizer: Optional[str] = None
class ProcessorConfig(ConfigBase):
@@ -125,6 +115,7 @@ class SearchResponse(ConfigBase):
score: float
cross_score: Optional[float]
additional: Optional[dict]
corpus_id: str
class Entry:

View File

@@ -10,7 +10,7 @@ from pathlib import Path
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, ProcessorConfigModel
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
@@ -21,7 +21,7 @@ search_models = SearchModels()
embeddings_model = EmbeddingsModel()
cross_encoder_model = CrossEncoderModel()
content_index = ContentIndex()
processor_config = ProcessorConfigModel()
gpt4all_processor_config: GPT4AllProcessorModel = None
config_file: Path = None
verbose: int = 0
host: str = None