diff --git a/docker-compose.yml b/docker-compose.yml
index c75aa4fc..365d2572 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -10,7 +10,15 @@ services:
POSTGRES_DB: postgres
volumes:
- khoj_db:/var/lib/postgresql/data/
+ healthcheck:
+ test: ["CMD-SHELL", "pg_isready -U postgres"]
+ interval: 30s
+ timeout: 10s
+ retries: 5
server:
+ depends_on:
+ database:
+ condition: service_healthy
# Use the following line to use the latest version of khoj. Otherwise, it will build from source.
image: ghcr.io/khoj-ai/khoj:latest
# Uncomment the following line to build from source. This will take a few minutes. Comment the next two lines out if you want to use the offiicial image.
@@ -24,20 +32,6 @@ services:
- "42110:42110"
working_dir: /app
volumes:
- - .:/app
- # These mounted volumes hold the raw data that should be indexed for search.
- # The path in your local directory (left hand side)
- # points to the files you want to index.
- # The path of the mounted directory (right hand side),
- # must match the path prefix in your config file.
- - ./tests/data/org/:/data/org/
- - ./tests/data/images/:/data/images/
- - ./tests/data/markdown/:/data/markdown/
- - ./tests/data/pdf/:/data/pdf/
- # Embeddings and models are populated after the first run
- # You can set these volumes to point to empty directories on host
- - ./tests/data/embeddings/:/root/.khoj/content/
- - ./tests/data/models/:/root/.khoj/search/
- khoj_config:/root/.khoj/
- khoj_models:/root/.cache/torch/sentence_transformers
# Use 0.0.0.0 to explicitly set the host ip for the service on the container. https://pythonspeed.com/articles/docker-connection-refused/
@@ -47,9 +41,11 @@ services:
- POSTGRES_PASSWORD=postgres
- POSTGRES_HOST=database
- POSTGRES_PORT=5432
- - GOOGLE_CLIENT_SECRET=bar
- - GOOGLE_CLIENT_ID=foo
- command: --host="0.0.0.0" --port=42110 -vv
+ - KHOJ_DJANGO_SECRET_KEY=secret
+ - KHOJ_DEBUG=True
+ - KHOJ_ADMIN_EMAIL=username@example.com
+ - KHOJ_ADMIN_PASSWORD=password
+ command: --host="0.0.0.0" --port=42110 -vv --anonymous-mode
volumes:
diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py
index 70d94df3..4b9b54ef 100644
--- a/src/database/adapters/__init__.py
+++ b/src/database/adapters/__init__.py
@@ -1,3 +1,4 @@
+import math
from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta
import secrets
@@ -101,6 +102,8 @@ async def create_google_user(token: dict) -> KhojUser:
user=user,
)
+ await Subscription.objects.acreate(user=user, type="trial")
+
return user
@@ -435,12 +438,19 @@ class EntryAdapters:
@staticmethod
def search_with_embeddings(
- user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
+ user: KhojUser,
+ embeddings: Tensor,
+ max_results: int = 10,
+ file_type_filter: str = None,
+ raw_query: str = None,
+ max_distance: float = math.inf,
):
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings)
)
+ relevant_entries = relevant_entries.filter(distance__lte=max_distance)
+
if file_type_filter:
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
relevant_entries = relevant_entries.order_by("distance")
diff --git a/src/database/admin.py b/src/database/admin.py
index 5f41f54a..03c2ca42 100644
--- a/src/database/admin.py
+++ b/src/database/admin.py
@@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
+ Subscription,
)
admin.site.register(KhojUser, UserAdmin)
@@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig)
+admin.site.register(Subscription)
diff --git a/src/database/migrations/0015_alter_subscription_user.py b/src/database/migrations/0015_alter_subscription_user.py
new file mode 100644
index 00000000..e4ba6ab0
--- /dev/null
+++ b/src/database/migrations/0015_alter_subscription_user.py
@@ -0,0 +1,21 @@
+# Generated by Django 4.2.5 on 2023-11-11 05:39
+
+from django.conf import settings
+from django.db import migrations, models
+import django.db.models.deletion
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0014_alter_googleuser_picture"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="subscription",
+ name="user",
+ field=models.OneToOneField(
+ on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL
+ ),
+ ),
+ ]
diff --git a/src/database/migrations/0016_alter_subscription_renewal_date.py b/src/database/migrations/0016_alter_subscription_renewal_date.py
new file mode 100644
index 00000000..bc7c5ada
--- /dev/null
+++ b/src/database/migrations/0016_alter_subscription_renewal_date.py
@@ -0,0 +1,17 @@
+# Generated by Django 4.2.5 on 2023-11-11 06:15
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0015_alter_subscription_user"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="subscription",
+ name="renewal_date",
+ field=models.DateTimeField(blank=True, default=None, null=True),
+ ),
+ ]
diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py
index 73f19c36..437d86ed 100644
--- a/src/database/models/__init__.py
+++ b/src/database/models/__init__.py
@@ -51,10 +51,10 @@ class Subscription(BaseModel):
TRIAL = "trial"
STANDARD = "standard"
- user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
+ user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
is_recurring = models.BooleanField(default=False)
- renewal_date = models.DateTimeField(null=True, default=None)
+ renewal_date = models.DateTimeField(null=True, default=None, blank=True)
class NotionConfig(BaseModel):
diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index 302d4a54..a089939d 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -577,12 +577,12 @@
cursor: pointer;
transition: background 0.2s ease-in-out;
text-align: left;
- max-height: 50px;
+ max-height: 75px;
transition: max-height 0.3s ease-in-out;
overflow: hidden;
}
button.reference-button.expanded {
- max-height: 200px;
+ max-height: none;
}
button.reference-button::before {
diff --git a/src/interface/desktop/renderer.js b/src/interface/desktop/renderer.js
index 7e3dba4c..7d0d906e 100644
--- a/src/interface/desktop/renderer.js
+++ b/src/interface/desktop/renderer.js
@@ -198,12 +198,6 @@ khojKeyInput.addEventListener('blur', async () => {
khojKeyInput.value = token;
});
-const syncButton = document.getElementById('sync-data');
-syncButton.addEventListener('click', async () => {
- loadingBar.style.display = 'block';
- await window.syncDataAPI.syncData(false);
-});
-
const syncForceButton = document.getElementById('sync-force');
syncForceButton.addEventListener('click', async () => {
loadingBar.style.display = 'block';
diff --git a/src/interface/desktop/search.html b/src/interface/desktop/search.html
index 315e6972..aa8aa662 100644
--- a/src/interface/desktop/search.html
+++ b/src/interface/desktop/search.html
@@ -188,7 +188,6 @@
fetch(url, { headers })
.then(response => response.json())
.then(data => {
- console.log(data);
document.getElementById("results").innerHTML = render_results(data, query, type);
});
}
diff --git a/src/interface/obsidian/src/main.ts b/src/interface/obsidian/src/main.ts
index 0285fb6c..9f0560e8 100644
--- a/src/interface/obsidian/src/main.ts
+++ b/src/interface/obsidian/src/main.ts
@@ -1,4 +1,4 @@
-import { Notice, Plugin } from 'obsidian';
+import { Notice, Plugin, request } from 'obsidian';
import { KhojSetting, KhojSettingTab, DEFAULT_SETTINGS } from 'src/settings'
import { KhojSearchModal } from 'src/search_modal'
import { KhojChatModal } from 'src/chat_modal'
@@ -69,6 +69,25 @@ export default class Khoj extends Plugin {
async loadSettings() {
// Load khoj obsidian plugin settings
this.settings = Object.assign({}, DEFAULT_SETTINGS, await this.loadData());
+
+ // Check if khoj backend is configured, note if cannot connect to backend
+ let headers = { "Authorization": `Bearer ${this.settings.khojApiKey}` };
+
+ if (this.settings.khojUrl === "https://app.khoj.dev") {
+ if (this.settings.khojApiKey === "") {
+ new Notice(`โ๏ธKhoj API key is not configured. Please visit https://app.khoj.dev to get an API key.`);
+ return;
+ }
+
+ await request({ url: this.settings.khojUrl ,method: "GET", headers: headers })
+ .then(response => {
+ this.settings.connectedToBackend = true;
+ })
+ .catch(error => {
+ this.settings.connectedToBackend = false;
+ new Notice(`โ๏ธEnsure Khoj backend is running and Khoj URL is pointing to it in the plugin settings.\n\n${error}`);
+ });
+ }
}
async saveSettings() {
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 6f0589a8..9fb1f019 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -28,7 +28,7 @@ from khoj.utils.config import (
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
-from database.models import KhojUser
+from database.models import KhojUser, Subscription
from database.adapters import get_all_users
@@ -54,27 +54,37 @@ class UserAuthenticationBackend(AuthenticationBackend):
def _initialize_default_user(self):
if not self.khojuser_manager.filter(username="default").exists():
- self.khojuser_manager.create_user(
+ default_user = self.khojuser_manager.create_user(
username="default",
email="default@example.com",
password="default",
)
+ Subscription.objects.create(user=default_user, type="standard", renewal_date="2100-04-01")
async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user")
if current_user and current_user.get("email"):
- user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
+ user = (
+ await self.khojuser_manager.filter(email=current_user.get("email"))
+ .prefetch_related("subscription")
+ .afirst()
+ )
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
# Get user owning token
- user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst()
+ user_with_token = (
+ await self.khojapiuser_manager.filter(token=bearer_token)
+ .select_related("user")
+ .prefetch_related("user__subscription")
+ .afirst()
+ )
if user_with_token:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
- user = await self.khojuser_manager.filter(username="default").afirst()
+ user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html
index 309fdba6..d9546249 100644
--- a/src/khoj/interface/web/base_config.html
+++ b/src/khoj/interface/web/base_config.html
@@ -109,7 +109,7 @@
display: grid;
grid-template-rows: repeat(3, 1fr);
gap: 8px;
- padding: 24px 16px;
+ padding: 24px 16px 8px;
width: 320px;
height: 180px;
background: var(--background-color);
@@ -121,7 +121,7 @@
div.finalize-buttons {
display: grid;
gap: 8px;
- padding: 24px 16px;
+ padding: 32px 0px 0px;
width: 320px;
border-radius: 4px;
overflow: hidden;
@@ -162,10 +162,13 @@
color: grey;
font-size: 16px;
}
- .card-button-row {
+ .card-description-row {
+ padding-top: 4px;
+ }
+ .card-action-row {
display: grid;
- grid-template-columns: auto;
- text-align: right;
+ grid-auto-flow: row;
+ justify-content: left;
}
.card-button {
border: none;
@@ -271,7 +274,9 @@
100% { transform: rotate(360deg); }
}
-
+ #status {
+ padding-top: 32px;
+ }
div.finalize-actions {
grid-auto-flow: column;
grid-gap: 24px;
@@ -287,6 +292,7 @@
select#chat-models {
margin-bottom: 0;
+ padding: 8px;
}
@@ -343,6 +349,12 @@
width: auto;
}
+ #status {
+ padding-top: 12px;
+ }
+ div.finalize-actions {
+ padding: 12px 0 0;
+ }
div.finalize-buttons {
padding: 0;
}
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index 919350aa..bd4870e4 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -43,7 +43,7 @@ To get started, just start typing below. You can also type / to see a list of co
let escaped_ref = reference.replaceAll('"', '"');
// Generate HTML for Chat Reference
- let short_ref = escaped_ref.slice(0, 100);
+ let short_ref = escaped_ref.slice(0, 140);
short_ref = short_ref.length < escaped_ref.length ? short_ref + "..." : short_ref;
let referenceButton = document.createElement('button');
referenceButton.innerHTML = short_ref;
@@ -205,8 +205,11 @@ To get started, just start typing below. You can also type / to see a list of co
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
const currentHTML = newResponseText.innerHTML;
newResponseText.innerHTML = formatHTMLMessage(currentHTML);
- newResponseText.appendChild(references);
+ if (references != null) {
+ newResponseText.appendChild(references);
+ }
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ document.getElementById("chat-input").removeAttribute("disabled");
return;
}
@@ -265,7 +268,6 @@ To get started, just start typing below. You can also type / to see a list of co
});
}
readStream();
- document.getElementById("chat-input").removeAttribute("disabled");
});
}
@@ -417,6 +419,9 @@ To get started, just start typing below. You can also type / to see a list of co
display: block;
}
+ div.references {
+ padding-top: 8px;
+ }
div.reference {
display: grid;
grid-template-rows: auto;
@@ -447,12 +452,12 @@ To get started, just start typing below. You can also type / to see a list of co
cursor: pointer;
transition: background 0.2s ease-in-out;
text-align: left;
- max-height: 50px;
+ max-height: 75px;
transition: max-height 0.3s ease-in-out;
overflow: hidden;
}
button.reference-button.expanded {
- max-height: 200px;
+ max-height: none;
}
button.reference-button::before {
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index b2b7fbb3..497dd31a 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -29,12 +29,12 @@
{% endif %}
-
-
-
+
+
+
@@ -61,13 +61,13 @@
{% endif %}
-
-
-
+
+
+
@@ -94,13 +94,26 @@
{% endif %}
+
+
+
-
-
+
+
+
+
+
+
+
+
+
+
@@ -221,23 +234,7 @@
{% endif %}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
{% endblock %}
diff --git a/src/khoj/interface/web/content_source_notion_input.html b/src/khoj/interface/web/content_source_notion_input.html
index 18eb5a7f..d5427ab3 100644
--- a/src/khoj/interface/web/content_source_notion_input.html
+++ b/src/khoj/interface/web/content_source_notion_input.html
@@ -41,6 +41,11 @@
return;
}
+ const submitButton = document.getElementById("submit");
+ submitButton.disabled = true;
+ submitButton.innerHTML = "Saving...";
+
+ // Save Notion config on server
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/content-source/notion', {
method: 'POST',
@@ -53,15 +58,39 @@
})
})
.then(response => response.json())
+ .then(data => { data["status"] === "ok" ? data : Promise.reject(data) })
+ .catch(error => {
+ document.getElementById("success").innerHTML = "โ ๏ธ Failed to save Notion settings.";
+ document.getElementById("success").style.display = "block";
+ submitButton.innerHTML = "โ ๏ธ Failed to save settings";
+ setTimeout(function() {
+ submitButton.innerHTML = "Save";
+ submitButton.disabled = false;
+ }, 2000);
+ return;
+ });
+
+ // Index Notion content on server
+ fetch('/api/update?t=notion')
+ .then(response => response.json())
+ .then(data => { data["status"] == "ok" ? data : Promise.reject(data) })
.then(data => {
- if (data["status"] == "ok") {
- document.getElementById("success").innerHTML = "โ
Successfully updated. Go to your settings page to complete setup.";
- document.getElementById("success").style.display = "block";
- } else {
- document.getElementById("success").innerHTML = "โ ๏ธ Failed to update settings.";
- document.getElementById("success").style.display = "block";
- }
+ document.getElementById("success").style.display = "none";
+ submitButton.innerHTML = "โ
Successfully updated";
+ setTimeout(function() {
+ submitButton.innerHTML = "Save";
+ submitButton.disabled = false;
+ }, 2000);
})
+ .catch(error => {
+ document.getElementById("success").innerHTML = "โ ๏ธ Failed to save Notion content.";
+ document.getElementById("success").style.display = "block";
+ submitButton.innerHTML = "โ ๏ธ Failed to save content";
+ setTimeout(function() {
+ submitButton.innerHTML = "Save";
+ submitButton.disabled = false;
+ }, 2000);
+ });
});
{% endblock %}
diff --git a/src/khoj/interface/web/search.html b/src/khoj/interface/web/search.html
index dcd98ede..5331ea92 100644
--- a/src/khoj/interface/web/search.html
+++ b/src/khoj/interface/web/search.html
@@ -189,7 +189,6 @@
})
.then(response => response.json())
.then(data => {
- console.log(data);
document.getElementById("results").innerHTML = render_results(data, query, type);
});
}
diff --git a/src/khoj/main.py b/src/khoj/main.py
index f92e8cbe..9fa65fc3 100644
--- a/src/khoj/main.py
+++ b/src/khoj/main.py
@@ -56,6 +56,7 @@ locale.setlocale(locale.LC_ALL, "")
from khoj.configure import configure_routes, initialize_server, configure_middleware
from khoj.utils import state
from khoj.utils.cli import cli
+from khoj.utils.initialization import initialization
# Setup Logger
rich_handler = RichHandler(rich_tracebacks=True)
@@ -74,8 +75,7 @@ def run(should_start_server=True):
args = cli(state.cli_args)
set_state(args)
- # Create app directory, if it doesn't exist
- state.config_file.parent.mkdir(parents=True, exist_ok=True)
+ logger.info(f"๐ Initializing Khoj v{state.khoj_version}")
# Set Logging Level
if args.verbose == 0:
@@ -83,6 +83,11 @@ def run(should_start_server=True):
elif args.verbose >= 1:
logger.setLevel(logging.DEBUG)
+ initialization()
+
+ # Create app directory, if it doesn't exist
+ state.config_file.parent.mkdir(parents=True, exist_ok=True)
+
# Set Log File
fh = logging.FileHandler(state.config_file.parent / "khoj.log", encoding="utf-8")
fh.setLevel(logging.DEBUG)
@@ -97,7 +102,7 @@ def run(should_start_server=True):
configure_routes(app)
# Mount Django and Static Files
- app.mount("/django", django_app, name="django")
+ app.mount("/server", django_app, name="server")
static_dir = "static"
if not os.path.exists(static_dir):
os.mkdir(static_dir)
diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 04a004f0..d3eaa01a 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -55,10 +55,10 @@ def extract_questions_offline(
last_year = datetime.now().year - 1
last_christmas_date = f"{last_year}-12-25"
next_christmas_date = f"{datetime.now().year}-12-25"
- system_prompt = prompts.extract_questions_system_prompt_llamav2.format(
- message=(prompts.system_prompt_message_extract_questions_llamav2)
+ system_prompt = prompts.system_prompt_extract_questions_gpt4all.format(
+ message=(prompts.system_prompt_message_extract_questions_gpt4all)
)
- example_questions = prompts.extract_questions_llamav2_sample.format(
+ example_questions = prompts.extract_questions_gpt4all_sample.format(
query=text,
chat_history=chat_history,
current_date=current_date,
@@ -150,14 +150,14 @@ def converse_offline(
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references_message):
conversation_primer = user_query
else:
- conversation_primer = prompts.notes_conversation_llamav2.format(
+ conversation_primer = prompts.notes_conversation_gpt4all.format(
query=user_query, references=compiled_references_message
)
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
- prompts.system_prompt_message_llamav2,
+ prompts.system_prompt_message_gpt4all,
conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
@@ -183,16 +183,16 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
conversation_history = messages[1:-1]
formatted_messages = [
- prompts.chat_history_llamav2_from_assistant.format(message=message.content)
+ prompts.khoj_message_gpt4all.format(message=message.content)
if message.role == "assistant"
- else prompts.chat_history_llamav2_from_user.format(message=message.content)
+ else prompts.user_message_gpt4all.format(message=message.content)
for message in conversation_history
]
stop_words = [""]
chat_history = "".join(formatted_messages)
- templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
- templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
+ templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
+ templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message
state.chat_lock.acquire()
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 73b4f176..b86ebc6b 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -20,27 +20,6 @@ from khoj.utils.helpers import ConversationCommand, is_none_or_empty
logger = logging.getLogger(__name__)
-def summarize(session, model, api_key=None, temperature=0.5, max_tokens=200):
- """
- Summarize conversation session using the specified OpenAI chat model
- """
- messages = [ChatMessage(content=prompts.summarize_chat.format(), role="system")] + session
-
- # Get Response from GPT
- logger.debug(f"Prompt for GPT: {messages}")
- response = completion_with_backoff(
- messages=messages,
- model_name=model,
- temperature=temperature,
- max_tokens=max_tokens,
- model_kwargs={"stop": ['"""'], "frequency_penalty": 0.2},
- openai_api_key=api_key,
- )
-
- # Extract, Clean Message from GPT's Response
- return str(response.content).replace("\n\n", "")
-
-
def extract_questions(
text,
model: Optional[str] = "gpt-4",
@@ -131,16 +110,14 @@ def converse(
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_command == ConversationCommand.General or is_none_or_empty(compiled_references):
- conversation_primer = prompts.general_conversation.format(current_date=current_date, query=user_query)
+ conversation_primer = prompts.general_conversation.format(query=user_query)
else:
- conversation_primer = prompts.notes_conversation.format(
- current_date=current_date, query=user_query, references=compiled_references
- )
+ conversation_primer = prompts.notes_conversation.format(query=user_query, references=compiled_references)
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
- prompts.personality.format(),
+ prompts.personality.format(current_date=current_date),
conversation_log,
model,
max_prompt_size,
@@ -157,4 +134,5 @@ def converse(
temperature=temperature,
openai_api_key=api_key,
completion_func=completion_func,
+ model_kwargs={"stop": ["Notes:\n["]},
)
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 130532e0..dce72e1f 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -69,15 +69,15 @@ def completion_with_backoff(**kwargs):
reraise=True,
)
def chat_completion_with_backoff(
- messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
+ messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None, model_kwargs=None
):
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
- t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+ t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, model_kwargs))
t.start()
return g
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_kwargs=None):
callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI(
streaming=True,
@@ -86,6 +86,7 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
model_name=model_name, # type: ignore
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+ model_kwargs=model_kwargs,
request_timeout=20,
max_retries=1,
client=None,
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index ef9100e0..fa9f9d91 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -4,30 +4,44 @@ from langchain.prompts import PromptTemplate
## Personality
## --
-personality = PromptTemplate.from_template("You are Khoj, a smart, inquisitive and helpful personal assistant.")
+personality = PromptTemplate.from_template(
+ """
+You are Khoj, a smart, inquisitive and helpful personal assistant.
+Use your general knowledge and the past conversation with the user as context to inform your responses.
+You were created by Khoj Inc. with the following capabilities:
+- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
+- You cannot set reminders.
+- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
+- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
+- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
+
+Note: More information about you, the company or other Khoj apps can be found at https://khoj.dev.
+Today is {current_date} in UTC.
+""".strip()
+)
## General Conversation
## --
general_conversation = PromptTemplate.from_template(
"""
-Using your general knowledge and our past conversations as context, answer the following question.
-Current Date: {current_date}
-
-Question: {query}
+{query}
""".strip()
)
+
no_notes_found = PromptTemplate.from_template(
"""
I'm sorry, I couldn't find any relevant notes to respond to your message.
""".strip()
)
-system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
+## Conversation Prompts for GPT4All Models
+## --
+system_prompt_message_gpt4all = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
Using your general knowledge and our past conversations as context, answer the following question.
If you do not know the answer, say 'I don't know.'"""
-system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
+system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
- Write the question as if you can search for the answer on the user's personal notes.
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
- Add as much context from the previous questions and notes as required into your search queries.
@@ -35,61 +49,47 @@ system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and i
What follow-up questions, if any, will you need to ask to answer the user's question?
"""
-system_prompt_llamav2 = PromptTemplate.from_template(
+system_prompt_gpt4all = PromptTemplate.from_template(
"""
[INST] <>
{message}
<>Hi there! [/INST] Hello! How can I help you today? """
)
-extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
+system_prompt_extract_questions_gpt4all = PromptTemplate.from_template(
"""
[INST] <>
{message}
<>[/INST]"""
)
-general_conversation_llamav2 = PromptTemplate.from_template(
- """
-[INST] {query} [/INST]
-""".strip()
-)
-
-chat_history_llamav2_from_user = PromptTemplate.from_template(
+user_message_gpt4all = PromptTemplate.from_template(
"""
[INST] {message} [/INST]
""".strip()
)
-chat_history_llamav2_from_assistant = PromptTemplate.from_template(
+khoj_message_gpt4all = PromptTemplate.from_template(
"""
{message}
""".strip()
)
-conversation_llamav2 = PromptTemplate.from_template(
- """
-[INST] {query} [/INST]
-""".strip()
-)
-
## Notes Conversation
## --
notes_conversation = PromptTemplate.from_template(
"""
-Using my personal notes and our past conversations as context, answer the following question.
-Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
-These questions should end with a question mark.
-Current Date: {current_date}
+Use my personal notes and our past conversations to inform your response.
+Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided notes or past conversations.
Notes:
{references}
-Question: {query}
+Query: {query}
""".strip()
)
-notes_conversation_llamav2 = PromptTemplate.from_template(
+notes_conversation_gpt4all = PromptTemplate.from_template(
"""
User's Notes:
{references}
@@ -98,13 +98,6 @@ Question: {query}
)
-## Summarize Chat
-## --
-summarize_chat = PromptTemplate.from_template(
- f"{personality.format()} Summarize the conversation from your first person perspective"
-)
-
-
## Summarize Notes
## --
summarize_notes = PromptTemplate.from_template(
@@ -132,7 +125,10 @@ Question: {user_query}
Answer (in second person):"""
)
-extract_questions_llamav2_sample = PromptTemplate.from_template(
+
+## Extract Questions
+## --
+extract_questions_gpt4all_sample = PromptTemplate.from_template(
"""
[INST] <>Current Date: {current_date}<> [/INST]
[INST] How was my trip to Cambodia? [/INST]
@@ -157,8 +153,6 @@ Use these notes from the user's previous conversations to provide a response:
)
-## Extract Questions
-## --
extract_questions = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index 1e92f27d..a4daa24f 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -27,5 +27,5 @@ class CrossEncoderModel:
def predict(self, query, hits: List[SearchResponse]):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
- cross_scores = self.cross_encoder_model.predict(cross__inp)
+ cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
return cross_scores
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 5228a6fb..190fc260 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -7,7 +7,7 @@ import json
from typing import List, Optional, Union, Any
# External Packages
-from fastapi import APIRouter, HTTPException, Header, Request
+from fastapi import APIRouter, Depends, HTTPException, Header, Request
from starlette.authentication import requires
from asgiref.sync import sync_to_async
@@ -36,6 +36,7 @@ from khoj.routers.helpers import (
agenerate_chat_response,
update_telemetry_state,
is_ready_to_chat,
+ ApiUserRateLimiter,
)
from khoj.processor.conversation.prompts import help_message
from khoj.processor.conversation.openai.gpt import extract_questions
@@ -177,11 +178,15 @@ async def set_content_config_github_data(
user = request.user.object
- await adapters.set_user_github_config(
- user=user,
- pat_token=updated_config.pat_token,
- repos=updated_config.repos,
- )
+ try:
+ await adapters.set_user_github_config(
+ user=user,
+ pat_token=updated_config.pat_token,
+ repos=updated_config.repos,
+ )
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to set Github config")
update_telemetry_state(
request=request,
@@ -205,10 +210,14 @@ async def set_content_config_notion_data(
user = request.user.object
- await adapters.set_notion_config(
- user=user,
- token=updated_config.token,
- )
+ try:
+ await adapters.set_notion_config(
+ user=user,
+ token=updated_config.token,
+ )
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ raise HTTPException(status_code=500, detail="Failed to set Github config")
update_telemetry_state(
request=request,
@@ -348,7 +357,7 @@ async def search(
n: Optional[int] = 5,
t: Optional[SearchType] = SearchType.All,
r: Optional[bool] = False,
- score_threshold: Optional[Union[float, None]] = None,
+ max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
client: Optional[str] = None,
user_agent: Optional[str] = Header(None),
@@ -367,12 +376,12 @@ async def search(
# initialize variables
user_query = q.strip()
results_count = n or 5
- score_threshold = score_threshold if score_threshold is not None else -math.inf
+ max_distance = max_distance if max_distance is not None else math.inf
search_futures: List[concurrent.futures.Future] = []
# return cached results, if available
if user:
- query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
+ query_cache_key = f"{user_query}-{n}-{t}-{r}-{max_distance}-{dedupe}"
if query_cache_key in state.query_cache[user.uuid]:
logger.debug(f"Return response from query cache")
return state.query_cache[user.uuid][query_cache_key]
@@ -409,8 +418,7 @@ async def search(
user_query,
t,
question_embedding=encoded_asymmetric_query,
- rank_results=r or False,
- score_threshold=score_threshold,
+ max_distance=max_distance,
)
]
@@ -423,7 +431,6 @@ async def search(
results_count,
state.search_models.image_search,
state.content_index.image,
- score_threshold=score_threshold,
)
]
@@ -446,11 +453,10 @@ async def search(
# Collate results
results += text_search.collate_results(hits, dedupe=dedupe)
- if r:
- results = text_search.rerank_and_sort_results(results, query=defiltered_query)[:results_count]
- else:
# Sort results across all content types and take top results
- results = sorted(results, key=lambda x: float(x.score))[:results_count]
+ results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
+ :results_count
+ ]
# Cache results
if user:
@@ -575,11 +581,14 @@ async def chat(
request: Request,
q: str,
n: Optional[int] = 5,
+ d: Optional[float] = 0.15,
client: Optional[str] = None,
stream: Optional[bool] = False,
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
+ rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
+ rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
) -> Response:
user = request.user.object
@@ -591,7 +600,7 @@ async def chat(
meta_log = (await ConversationAdapters.aget_conversation_by_user(user)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
- request, meta_log, q, (n or 5), conversation_command
+ request, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
@@ -606,7 +615,7 @@ async def chat(
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
# Get the (streamed) chat response from the LLM of choice.
- llm_response = await agenerate_chat_response(
+ llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
compiled_references,
@@ -615,6 +624,19 @@ async def chat(
user,
)
+ chat_metadata.update({"conversation_command": conversation_command.value})
+
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="chat",
+ client=client,
+ user_agent=user_agent,
+ referer=referer,
+ host=host,
+ metadata=chat_metadata,
+ )
+
if llm_response is None:
return Response(content=llm_response, media_type="text/plain", status_code=500)
@@ -634,16 +656,6 @@ async def chat(
response_obj = {"response": actual_response, "context": compiled_references}
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="chat",
- client=client,
- user_agent=user_agent,
- referer=referer,
- host=host,
- )
-
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
@@ -652,6 +664,7 @@ async def extract_references_and_questions(
meta_log: dict,
q: str,
n: int,
+ d: float,
conversation_type: ConversationCommand = ConversationCommand.Default,
):
user = request.user.object if request.user.is_authenticated else None
@@ -663,7 +676,7 @@ async def extract_references_and_questions(
if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries, q
- if not sync_to_async(EntryAdapters.user_has_entries)(user=user):
+ if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
@@ -712,7 +725,7 @@ async def extract_references_and_questions(
request=request,
n=n_items,
r=True,
- score_threshold=-5.0,
+ max_distance=d,
dedupe=False,
)
)
diff --git a/src/khoj/routers/auth.py b/src/khoj/routers/auth.py
index 4a3cbcef..2c013bc8 100644
--- a/src/khoj/routers/auth.py
+++ b/src/khoj/routers/auth.py
@@ -16,6 +16,7 @@ from google.auth.transport import requests as google_requests
# Internal Packages
from database.adapters import get_khoj_tokens, get_or_create_user, create_khoj_token, delete_khoj_token
+from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state
@@ -95,6 +96,16 @@ async def auth(request: Request):
if khoj_user:
request.session["user"] = dict(idinfo)
+ if not khoj_user.last_login:
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="create_user",
+ metadata={"user_id": str(khoj_user.uuid)},
+ )
+ logger.log(logging.INFO, f"New User Created: {khoj_user.uuid}")
+ RedirectResponse(url="/?status=welcome")
+
return RedirectResponse(url="/")
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index ebb661ef..a1af10f2 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1,21 +1,27 @@
-import logging
+# Standard Packages
import asyncio
+from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial
-from typing import Iterator, List, Optional, Union
-from concurrent.futures import ThreadPoolExecutor
+import logging
+from time import time
+from typing import Iterator, List, Optional, Union, Tuple, Dict
+# External Packages
from fastapi import HTTPException, Request
+# Internal Packages
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
-from database.models import KhojUser
+from database.models import KhojUser, Subscription
from database.adapters import ConversationAdapters
+
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
@@ -61,12 +67,15 @@ def update_telemetry_state(
metadata: Optional[dict] = None,
):
user: KhojUser = request.user.object if request.user.is_authenticated else None
+ subscription: Subscription = user.subscription if user and user.subscription else None
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
"server_id": str(user.uuid) if user else None,
+ "subscription_type": subscription.type if subscription else None,
+ "is_recurring": subscription.is_recurring if subscription else None,
}
if metadata:
@@ -109,7 +118,7 @@ def generate_chat_response(
inferred_queries: List[str] = [],
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
-) -> Union[ThreadedGenerator, Iterator[str]]:
+) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
def _save_to_conversation_log(
q: str,
chat_response: str,
@@ -132,6 +141,8 @@ def generate_chat_response(
chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}")
+ metadata = {}
+
try:
partial_completion = partial(
_save_to_conversation_log,
@@ -148,8 +159,8 @@ def generate_chat_response(
conversation_config = ConversationAdapters.get_default_conversation_config()
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
- if state.gpt4all_processor_config.loaded_model is None:
- state.gpt4all_processor_config = GPT4AllProcessorModel(offline_chat_config.chat_model)
+ if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
+ state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model
chat_response = converse_offline(
@@ -179,8 +190,33 @@ def generate_chat_response(
tokenizer_name=conversation_config.tokenizer,
)
+ metadata.update({"chat_model": conversation_config.chat_model})
+
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
- return chat_response
+ return chat_response, metadata
+
+
+class ApiUserRateLimiter:
+ def __init__(self, requests: int, window: int):
+ self.requests = requests
+ self.window = window
+ self.cache: dict[str, list[float]] = defaultdict(list)
+
+ def __call__(self, request: Request):
+ user: KhojUser = request.user.object
+ user_requests = self.cache[user.uuid]
+
+ # Remove requests outside of the time window
+ cutoff = time() - self.window
+ while user_requests and user_requests[0] < cutoff:
+ user_requests.pop(0)
+
+ # Check if the user has exceeded the rate limit
+ if len(user_requests) >= self.requests:
+ raise HTTPException(status_code=429, detail="Too Many Requests")
+
+ # Add the current request to the cache
+ user_requests.append(time())
diff --git a/src/khoj/search_type/image_search.py b/src/khoj/search_type/image_search.py
index d7f486af..214118fc 100644
--- a/src/khoj/search_type/image_search.py
+++ b/src/khoj/search_type/image_search.py
@@ -146,7 +146,7 @@ def extract_metadata(image_name):
async def query(
- raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = -math.inf
+ raw_query, count, search_model: ImageSearchModel, content: ImageContent, score_threshold: float = math.inf
):
# Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
@@ -167,7 +167,8 @@ async def query(
# Compute top_k ranked images based on cosine-similarity b/w query and all image embeddings.
with timer("Search Time", logger):
image_hits = {
- result["corpus_id"]: {"image_score": result["score"], "score": result["score"]}
+ # Map scores to distance metric by multiplying by -1
+ result["corpus_id"]: {"image_score": -1 * result["score"], "score": -1 * result["score"]}
for result in util.semantic_search(query_embedding, content.image_embeddings, top_k=count)[0]
}
@@ -204,7 +205,7 @@ async def query(
]
# Filter results by score threshold
- hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
+ hits = [hit for hit in hits if hit["image_score"] <= score_threshold]
# Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True)
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index ba2fc9ec..2b99ed66 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -104,8 +104,7 @@ async def query(
raw_query: str,
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
- rank_results: bool = False,
- score_threshold: float = -math.inf,
+ max_distance: float = math.inf,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
@@ -127,6 +126,7 @@ async def query(
max_results=top_k,
file_type_filter=file_type,
raw_query=raw_query,
+ max_distance=max_distance,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
@@ -177,12 +177,16 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
)
-def rerank_and_sort_results(hits, query):
+def rerank_and_sort_results(hits, query, rank_results):
+ # If we have more than one result and reranking is enabled
+ rank_results = rank_results and len(list(hits)) > 1
+
# Score all retrieved entries using the cross-encoder
- hits = cross_encoder_score(query, hits)
+ if rank_results:
+ hits = cross_encoder_score(query, hits)
# Sort results by cross-encoder score followed by bi-encoder score
- hits = sort_results(rank_results=True, hits=hits)
+ hits = sort_results(rank_results=rank_results, hits=hits)
return hits
@@ -217,9 +221,9 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model.predict(query, hits)
- # Store cross-encoder scores in results dictionary for ranking
+ # Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):
- hits[idx]["cross_score"] = cross_scores[idx]
+ hits[idx]["cross_score"] = 1 - cross_scores[idx]
return hits
@@ -227,7 +231,7 @@ def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchRe
def sort_results(rank_results: bool, hits: List[dict]) -> List[dict]:
"""Order results by cross-encoder score followed by bi-encoder score"""
with timer("Rank Time", logger, state.device):
- hits.sort(key=lambda x: x["score"], reverse=True) # sort by bi-encoder score
+ hits.sort(key=lambda x: x["score"]) # sort by bi-encoder score
if rank_results:
- hits.sort(key=lambda x: x["cross_score"], reverse=True) # sort by cross-encoder score
+ hits.sort(key=lambda x: x["cross_score"]) # sort by cross-encoder score
return hits
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index e9d431c6..8a106153 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -6,6 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
+default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
empty_config = {
"search-type": {
diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py
new file mode 100644
index 00000000..c797f848
--- /dev/null
+++ b/src/khoj/utils/initialization.py
@@ -0,0 +1,98 @@
+import logging
+import os
+
+from database.models import (
+ KhojUser,
+ OfflineChatProcessorConversationConfig,
+ OpenAIProcessorConversationConfig,
+ ChatModelOptions,
+)
+
+from khoj.utils.constants import default_offline_chat_model
+
+from database.adapters import ConversationAdapters
+
+
+logger = logging.getLogger(__name__)
+
+
+def initialization():
+ def _create_admin_user():
+ logger.info(
+ "๐ฉโโ๏ธ Setting up admin user. These credentials will allow you to configure your server at /server/admin."
+ )
+ email_addr = os.getenv("KHOJ_ADMIN_EMAIL") or input("Email: ")
+ password = os.getenv("KHOJ_ADMIN_PASSWORD") or input("Password: ")
+ admin_user = KhojUser.objects.create_superuser(email=email_addr, username=email_addr, password=password)
+ logger.info(f"๐ฉโโ๏ธ Created admin user: {admin_user.email}")
+
+ def _create_chat_configuration():
+ logger.info(
+ "๐ฃ๏ธ Configure chat models available to your server. You can always update these at /server/admin using the credentials of your admin account"
+ )
+ try:
+ # Some environments don't support interactive input. We catch the exception and return if that's the case. The admin can still configure their settings from the admin page.
+ input()
+ except EOFError:
+ return
+
+ try:
+ # Note: gpt4all package is not available on all devices.
+ # So ensure gpt4all package is installed before continuing this step.
+ import gpt4all
+
+ use_offline_model = input("Use offline chat model? (y/n): ")
+ if use_offline_model == "y":
+ logger.info("๐ฃ๏ธ Setting up offline chat model")
+ OfflineChatProcessorConversationConfig.objects.create(enabled=True)
+
+ offline_chat_model = input(
+ f"Enter the name of the offline chat model you want to use, based on the models in HuggingFace (press enter to use the default: {default_offline_chat_model}): "
+ )
+ if offline_chat_model == "":
+ ChatModelOptions.objects.create(
+ chat_model=default_offline_chat_model, model_type=ChatModelOptions.ModelType.OFFLINE
+ )
+ else:
+ max_tokens = input("Enter the maximum number of tokens to use for the offline chat model:")
+ tokenizer = input("Enter the tokenizer to use for the offline chat model:")
+ ChatModelOptions.objects.create(
+ chat_model=offline_chat_model,
+ model_type=ChatModelOptions.ModelType.OFFLINE,
+ max_prompt_size=max_tokens,
+ tokenizer=tokenizer,
+ )
+ except ModuleNotFoundError as e:
+ logger.warning("Offline models are not supported on this device.")
+
+ use_openai_model = input("Use OpenAI chat model? (y/n): ")
+
+ if use_openai_model == "y":
+ logger.info("๐ฃ๏ธ Setting up OpenAI chat model")
+ api_key = input("Enter your OpenAI API key: ")
+ OpenAIProcessorConversationConfig.objects.create(api_key=api_key)
+ openai_chat_model = input("Enter the name of the OpenAI chat model you want to use: ")
+ max_tokens = input("Enter the maximum number of tokens to use for the OpenAI chat model:")
+ ChatModelOptions.objects.create(
+ chat_model=openai_chat_model, model_type=ChatModelOptions.ModelType.OPENAI, max_tokens=max_tokens
+ )
+
+ logger.info("๐ฃ๏ธ Chat model configuration complete")
+
+ admin_user = KhojUser.objects.filter(is_staff=True).first()
+ if admin_user is None:
+ while True:
+ try:
+ _create_admin_user()
+ break
+ except Exception as e:
+ logger.error(f"๐จ Failed to create admin user: {e}", exc_info=True)
+
+ chat_config = ConversationAdapters.get_default_conversation_config()
+ if admin_user is None and chat_config is None:
+ while True:
+ try:
+ _create_chat_configuration()
+ break
+ except Exception as e:
+ logger.error(f"๐จ Failed to create chat configuration: {e}", exc_info=True)
diff --git a/tests/conftest.py b/tests/conftest.py
index 59104123..95fa9a99 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -43,6 +43,7 @@ from tests.helpers import (
OpenAIProcessorConversationConfigFactory,
OfflineChatProcessorConversationConfigFactory,
UserConversationProcessorConfigFactory,
+ SubscriptionFactory,
)
@@ -69,7 +70,9 @@ def search_config() -> SearchConfig:
@pytest.mark.django_db
@pytest.fixture
def default_user():
- return UserFactory()
+ user = UserFactory()
+ SubscriptionFactory(user=user)
+ return user
@pytest.mark.django_db
@@ -78,11 +81,31 @@ def default_user2():
if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default")
- return KhojUser.objects.create(
+ user = KhojUser.objects.create(
username="default",
email="default@example.com",
password="default",
)
+ SubscriptionFactory(user=user)
+ return user
+
+
+@pytest.mark.django_db
+@pytest.fixture
+def default_user3():
+ """
+ This user should not have any data associated with it
+ """
+ if KhojUser.objects.filter(username="default3").exists():
+ return KhojUser.objects.get(username="default3")
+
+ user = KhojUser.objects.create(
+ username="default3",
+ email="default3@example.com",
+ password="default3",
+ )
+ SubscriptionFactory(user=user)
+ return user
@pytest.mark.django_db
@@ -111,6 +134,19 @@ def api_user2(default_user2):
)
+@pytest.mark.django_db
+@pytest.fixture
+def api_user3(default_user3):
+ if KhojApiUser.objects.filter(user=default_user3).exists():
+ return KhojApiUser.objects.get(user=default_user3)
+
+ return KhojApiUser.objects.create(
+ user=default_user3,
+ name="api-key",
+ token="kk-diff-secret-3",
+ )
+
+
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()
@@ -206,7 +242,7 @@ def chat_client(search_config: SearchConfig, default_user2: KhojUser):
OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
- state.anonymous_mode = False
+ state.anonymous_mode = True
app = FastAPI()
@@ -224,7 +260,9 @@ def chat_client_no_background(search_config: SearchConfig, default_user2: KhojUs
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
+ chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
OpenAIProcessorConversationConfigFactory()
+ UserConversationProcessorConfigFactory(user=default_user2, setting=chat_model)
state.anonymous_mode = True
diff --git a/tests/helpers.py b/tests/helpers.py
index 3aa7c435..03f3f9c7 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -9,6 +9,7 @@ from database.models import (
OpenAIProcessorConversationConfig,
UserConversationConfig,
Conversation,
+ Subscription,
)
@@ -68,3 +69,13 @@ class ConversationFactory(factory.django.DjangoModelFactory):
model = Conversation
user = factory.SubFactory(UserFactory)
+
+
+class SubscriptionFactory(factory.django.DjangoModelFactory):
+ class Meta:
+ model = Subscription
+
+ user = factory.SubFactory(UserFactory)
+ type = "standard"
+ is_recurring = False
+ renewal_date = "2100-04-01"
diff --git a/tests/test_client.py b/tests/test_client.py
index c105c605..1894577c 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -16,7 +16,7 @@ from khoj.utils.state import search_models, content_index, config
from khoj.search_type import text_search, image_search
from khoj.utils.rawconfig import ContentConfig, SearchConfig
from khoj.processor.org_mode.org_to_entries import OrgToEntries
-from database.models import KhojUser
+from database.models import KhojUser, KhojApiUser
from database.adapters import EntryAdapters
@@ -351,6 +351,24 @@ def test_different_user_data_not_accessed(client, sample_org_data, default_user:
assert len(response.json()) == 1 and response.json()["detail"] == "Forbidden"
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.django_db(transaction=True)
+def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiUser):
+ # Arrange
+ token = api_user3.token
+ headers = {"Authorization": "Bearer " + token}
+ user_query = quote("How to git install application?")
+
+ # Act
+ response = client.get(f"/api/search?q={user_query}&n=1&t=org", headers=headers)
+
+ # Assert
+ assert response.status_code == 200
+ # assert actual response has no data as the default_user3, though other users have data
+ assert len(response.json()) == 0
+ assert response.json() == []
+
+
def get_sample_files_data():
return [
("files", ("path/to/filename.org", "* practicing piano", "text/org")),
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 14a73f15..a8c85787 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -307,6 +307,8 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_
"which one is",
"which of namita's sons",
"the birth order",
+ "provide more context",
+ "provide me with more context",
]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (