[Multi-User Part 1]: Enable storage of settings for plaintext files based on user account (#498)

- Partition configuration for indexing local data based on user accounts
- Store indexed data in an underlying postgres db using the `pgvector` extension
- Add migrations for all relevant user data and embeddings generation. Very little performance optimization has been done for the lookup time
- Apply filters using SQL queries
- Start removing many server-level configuration settings
- Configure GitHub test actions to run during any PR. Update the test action to run in a containerized environment with a DB.
- Update the Docker image and docker-compose.yml to work with the new application design
This commit is contained in:
sabaimran
2023-10-26 09:42:29 -07:00
committed by GitHub
parent 963cd165eb
commit 216acf545f
60 changed files with 1827 additions and 1792 deletions

View File

@@ -1,15 +1,30 @@
from typing import Type, TypeVar
from typing import Type, TypeVar, List
import uuid
from datetime import date
from django.db import models
from django.contrib.sessions.backends.db import SessionStore
from pgvector.django import CosineDistance
from django.db.models.manager import BaseManager
from django.db.models import Q
from torch import Tensor
# Import sync_to_async from Django Channels
from asgiref.sync import sync_to_async
from fastapi import HTTPException
from database.models import KhojUser, GoogleUser, NotionConfig
from database.models import (
KhojUser,
GoogleUser,
NotionConfig,
GithubConfig,
Embeddings,
GithubRepoConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter
ModelType = TypeVar("ModelType", bound=models.Model)
@@ -40,9 +55,7 @@ async def get_or_create_user(token: dict) -> KhojUser:
async def create_google_user(token: dict) -> KhojUser:
user_info = token.get("userinfo")
user = await KhojUser.objects.acreate(
username=user_info.get("email"), email=user_info.get("email"), uuid=uuid.uuid4()
)
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
await user.asave()
await GoogleUser.objects.acreate(
sub=user_info.get("sub"),
@@ -76,3 +89,149 @@ async def retrieve_user(session_id: str) -> KhojUser:
if not user:
raise HTTPException(status_code=401, detail="Invalid user")
return user
def get_all_users() -> BaseManager[KhojUser]:
return KhojUser.objects.all()
def get_user_github_config(user: KhojUser):
config = GithubConfig.objects.filter(user=user).prefetch_related("githubrepoconfig").first()
if not config:
return None
return config
def get_user_notion_config(user: KhojUser):
config = NotionConfig.objects.filter(user=user).first()
if not config:
return None
return config
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
await object.objects.filter(user=user).adelete()
await object.objects.acreate(
input_files=deduped_files,
input_filter=deduped_filters,
index_heading_entries=updated_config.index_heading_entries,
user=user,
)
async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
config = await GithubConfig.objects.filter(user=user).afirst()
if not config:
config = await GithubConfig.objects.acreate(pat_token=pat_token, user=user)
else:
config.pat_token = pat_token
await config.asave()
await config.githubrepoconfig.all().adelete()
for repo in repos:
await GithubRepoConfig.objects.acreate(
name=repo["name"], owner=repo["owner"], branch=repo["branch"], github_config=config
)
return config
class EmbeddingsAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
date_filter = DateFilter()
@staticmethod
def does_embedding_exist(user: KhojUser, hashed_value: str) -> bool:
return Embeddings.objects.filter(user=user, hashed_value=hashed_value).exists()
@staticmethod
def delete_embedding_by_file(user: KhojUser, file_path: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_path=file_path).delete()
return deleted_count
@staticmethod
def delete_all_embeddings(user: KhojUser, file_type: str):
deleted_count, _ = Embeddings.objects.filter(user=user, file_type=file_type).delete()
return deleted_count
@staticmethod
def get_existing_entry_hashes_by_file(user: KhojUser, file_path: str):
return Embeddings.objects.filter(user=user, file_path=file_path).values_list("hashed_value", flat=True)
@staticmethod
def delete_embedding_by_hash(user: KhojUser, hashed_values: List[str]):
Embeddings.objects.filter(user=user, hashed_value__in=hashed_values).delete()
@staticmethod
def get_embeddings_by_date_filter(embeddings: BaseManager[Embeddings], start_date: date, end_date: date):
return embeddings.filter(
embeddingsdates__date__gte=start_date,
embeddingsdates__date__lte=end_date,
)
@staticmethod
async def user_has_embeddings(user: KhojUser):
return await Embeddings.objects.filter(user=user).aexists()
@staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q()
explicit_word_terms = EmbeddingsAdapters.word_filer.get_filter_terms(query)
file_filters = EmbeddingsAdapters.file_filter.get_filter_terms(query)
date_filters = EmbeddingsAdapters.date_filter.get_query_date_range(query)
if len(explicit_word_terms) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Embeddings.objects.filter(user=user)
for term in explicit_word_terms:
if term.startswith("+"):
q_filter_terms &= Q(raw__icontains=term[1:])
elif term.startswith("-"):
q_filter_terms &= ~Q(raw__icontains=term[1:])
q_file_filter_terms = Q()
if len(file_filters) > 0:
for term in file_filters:
q_file_filter_terms |= Q(file_path__regex=term)
q_filter_terms &= q_file_filter_terms
if len(date_filters) > 0:
min_date, max_date = date_filters
if min_date is not None:
# Convert the min_date timestamp to yyyy-mm-dd format
formatted_min_date = date.fromtimestamp(min_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__gte=formatted_min_date)
if max_date is not None:
# Convert the max_date timestamp to yyyy-mm-dd format
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
relevant_embeddings = Embeddings.objects.filter(user=user).filter(
q_filter_terms,
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
return relevant_embeddings
@staticmethod
def search_with_embeddings(
user: KhojUser, embeddings: Tensor, max_results: int = 10, file_type_filter: str = None, raw_query: str = None
):
relevant_embeddings = EmbeddingsAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_embeddings = relevant_embeddings.filter(user=user).annotate(
distance=CosineDistance("embeddings", embeddings)
)
if file_type_filter:
relevant_embeddings = relevant_embeddings.filter(file_type=file_type_filter)
relevant_embeddings = relevant_embeddings.order_by("distance")
return relevant_embeddings[:max_results]
@staticmethod
def get_unique_file_types(user: KhojUser):
return Embeddings.objects.filter(user=user).values_list("file_type", flat=True).distinct()

View File

@@ -1,79 +0,0 @@
# Generated by Django 4.2.5 on 2023-09-27 17:52
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0002_googleuser"),
]
operations = [
migrations.CreateModel(
name="Configuration",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
],
),
migrations.CreateModel(
name="ConversationProcessorConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("conversation", models.JSONField()),
("enable_offline_chat", models.BooleanField(default=False)),
],
),
migrations.CreateModel(
name="GithubConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("pat_token", models.CharField(max_length=200)),
("compressed_jsonl", models.CharField(max_length=300)),
("embeddings_file", models.CharField(max_length=300)),
(
"config",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
),
],
),
migrations.AddField(
model_name="khojuser",
name="uuid",
field=models.UUIDField(verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
preserve_default=False,
),
migrations.CreateModel(
name="NotionConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("token", models.CharField(max_length=200)),
("compressed_jsonl", models.CharField(max_length=300)),
("embeddings_file", models.CharField(max_length=300)),
(
"config",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to="database.configuration"),
),
],
),
migrations.CreateModel(
name="GithubRepoConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("name", models.CharField(max_length=200)),
("owner", models.CharField(max_length=200)),
("branch", models.CharField(max_length=200)),
(
"github_config",
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.githubconfig"),
),
],
),
migrations.AddField(
model_name="configuration",
name="user",
field=models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
]

View File

@@ -0,0 +1,10 @@
from django.db import migrations
from pgvector.django import VectorExtension
class Migration(migrations.Migration):
dependencies = [
("database", "0002_googleuser"),
]
operations = [VectorExtension()]

View File

@@ -0,0 +1,193 @@
# Generated by Django 4.2.5 on 2023-10-11 22:24
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
import pgvector.django
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0003_vector_extension"),
]
operations = [
migrations.CreateModel(
name="ConversationProcessorConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("conversation", models.JSONField()),
("enable_offline_chat", models.BooleanField(default=False)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="GithubConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("pat_token", models.CharField(max_length=200)),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="khojuser",
name="uuid",
field=models.UUIDField(default=1234, verbose_name=models.UUIDField(default=uuid.uuid4, editable=False)),
preserve_default=False,
),
migrations.CreateModel(
name="NotionConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("token", models.CharField(max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalPlaintextConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalPdfConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalOrgConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="LocalMarkdownConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("input_files", models.JSONField(default=list, null=True)),
("input_filter", models.JSONField(default=list, null=True)),
("index_heading_entries", models.BooleanField(default=False)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
migrations.CreateModel(
name="GithubRepoConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(max_length=200)),
("owner", models.CharField(max_length=200)),
("branch", models.CharField(max_length=200)),
(
"github_config",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="githubrepoconfig",
to="database.githubconfig",
),
),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="githubconfig",
name="user",
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
migrations.CreateModel(
name="Embeddings",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("embeddings", pgvector.django.VectorField(dimensions=384)),
("raw", models.TextField()),
("compiled", models.TextField()),
("heading", models.CharField(blank=True, default=None, max_length=1000, null=True)),
(
"file_type",
models.CharField(
choices=[
("image", "Image"),
("pdf", "Pdf"),
("plaintext", "Plaintext"),
("markdown", "Markdown"),
("org", "Org"),
("notion", "Notion"),
("github", "Github"),
("conversation", "Conversation"),
],
default="plaintext",
max_length=30,
),
),
("file_path", models.CharField(blank=True, default=None, max_length=400, null=True)),
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
("url", models.URLField(blank=True, default=None, max_length=400, null=True)),
("hashed_value", models.CharField(max_length=100)),
(
"user",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
]

View File

@@ -0,0 +1,18 @@
# Generated by Django 4.2.5 on 2023-10-13 02:39
from django.db import migrations, models
import uuid
class Migration(migrations.Migration):
dependencies = [
("database", "0004_conversationprocessorconfig_githubconfig_and_more"),
]
operations = [
migrations.AddField(
model_name="embeddings",
name="corpus_id",
field=models.UUIDField(default=uuid.uuid4, editable=False),
),
]

View File

@@ -0,0 +1,33 @@
# Generated by Django 4.2.5 on 2023-10-13 19:28
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0005_embeddings_corpus_id"),
]
operations = [
migrations.CreateModel(
name="EmbeddingsDates",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("date", models.DateField()),
(
"embeddings",
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE,
related_name="embeddings_dates",
to="database.embeddings",
),
),
],
options={
"indexes": [models.Index(fields=["date"], name="database_em_date_a1ba47_idx")],
},
),
]

View File

@@ -2,11 +2,25 @@ import uuid
from django.db import models
from django.contrib.auth.models import AbstractUser
from pgvector.django import VectorField
class BaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
class Meta:
abstract = True
class KhojUser(AbstractUser):
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
def save(self, *args, **kwargs):
if not self.uuid:
self.uuid = uuid.uuid4()
super().save(*args, **kwargs)
class GoogleUser(models.Model):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
@@ -23,31 +37,85 @@ class GoogleUser(models.Model):
return self.name
class Configuration(models.Model):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
class NotionConfig(models.Model):
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
compressed_jsonl = models.CharField(max_length=300)
embeddings_file = models.CharField(max_length=300)
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class GithubConfig(models.Model):
class GithubConfig(BaseModel):
pat_token = models.CharField(max_length=200)
compressed_jsonl = models.CharField(max_length=300)
embeddings_file = models.CharField(max_length=300)
config = models.OneToOneField(Configuration, on_delete=models.CASCADE)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class GithubRepoConfig(models.Model):
class GithubRepoConfig(BaseModel):
name = models.CharField(max_length=200)
owner = models.CharField(max_length=200)
branch = models.CharField(max_length=200)
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE)
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
class ConversationProcessorConfig(models.Model):
class LocalOrgConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalMarkdownConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalPdfConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class LocalPlaintextConfig(BaseModel):
input_files = models.JSONField(default=list, null=True)
input_filter = models.JSONField(default=list, null=True)
index_heading_entries = models.BooleanField(default=False)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class ConversationProcessorConfig(BaseModel):
conversation = models.JSONField()
enable_offline_chat = models.BooleanField(default=False)
class Embeddings(BaseModel):
class EmbeddingsType(models.TextChoices):
IMAGE = "image"
PDF = "pdf"
PLAINTEXT = "plaintext"
MARKDOWN = "markdown"
ORG = "org"
NOTION = "notion"
GITHUB = "github"
CONVERSATION = "conversation"
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
embeddings = VectorField(dimensions=384)
raw = models.TextField()
compiled = models.TextField()
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
file_type = models.CharField(max_length=30, choices=EmbeddingsType.choices, default=EmbeddingsType.PLAINTEXT)
file_path = models.CharField(max_length=400, default=None, null=True, blank=True)
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
url = models.URLField(max_length=400, default=None, null=True, blank=True)
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
class EmbeddingsDates(BaseModel):
date = models.DateField()
embeddings = models.ForeignKey(Embeddings, on_delete=models.CASCADE, related_name="embeddings_dates")
class Meta:
indexes = [
models.Index(fields=["date"]),
]

3
src/database/tests.py Normal file
View File

@@ -0,0 +1,3 @@
from django.test import TestCase
# Create your tests here.