Initial commit of a functional but not yet elegant prototype for this concept

This commit is contained in:
sabaimran
2024-11-28 17:28:23 -08:00
parent 9368699b2c
commit d91935c880
15 changed files with 455 additions and 150 deletions

View File

@@ -1,7 +1,9 @@
import logging
import os
import re
import uuid
from random import choice
from typing import Dict, List, Optional, Union
from django.contrib.auth.models import AbstractUser
from django.contrib.postgres.fields import ArrayField
@@ -11,9 +13,109 @@ from django.db.models.signals import pre_save
from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
from pydantic import BaseModel as PydanticBaseModel
from pydantic import Field
logger = logging.getLogger(__name__)
class BaseModel(models.Model):
# Pydantic models for validating the structure of chat messages
class Context(PydanticBaseModel):
    """Schema for one item of a chat message's `context` list."""

    # Both fields are required strings.
    compiled: str
    file: str
class CodeContextFile(PydanticBaseModel):
    """Schema for one entry of `CodeContextResult.output_files`."""

    filename: str
    # presumably the file content, base64-encoded -- confirm against the producer
    b64_data: str
class CodeContextResult(PydanticBaseModel):
    """Schema for the `result` field of `CodeContextData`.

    Every field is required.
    """

    success: bool
    output_files: List[CodeContextFile]
    std_out: str
    std_err: str
    # Integer runtime; the unit is not visible from this module.
    code_runtime: int
class CodeContextData(PydanticBaseModel):
    """Schema for the values of a chat message's `codeContext` mapping.

    Pairs a `code` string with its `CodeContextResult`.
    """

    code: str
    result: CodeContextResult
class WebPage(PydanticBaseModel):
    """Schema for a web page entry used by `OnlineContext.webpages`."""

    link: str
    # Optional; absent or null is accepted.
    query: Optional[str] = None
    snippet: str
class AnswerBox(PydanticBaseModel):
    """Schema for the `answerBox` field of `OnlineContext`.

    Every field is required.
    """

    link: str
    snippet: str
    title: str
    snippetHighlighted: List[str]
class PeopleAlsoAsk(PydanticBaseModel):
    """Schema for one item of `OnlineContext.peopleAlsoAsk`.

    Every field is a required string.
    """

    link: str
    question: str
    snippet: str
    title: str
class KnowledgeGraph(PydanticBaseModel):
    """Schema for the `knowledgeGraph` field of `OnlineContext`.

    Every field is required; `attributes` maps string keys to string values.
    """

    attributes: Dict[str, str]
    description: str
    descriptionLink: str
    descriptionSource: str
    imageUrl: str
    title: str
    type: str
class OrganicContext(PydanticBaseModel):
    """Schema for one item of `OnlineContext.organicContext`.

    Every field is a required string.
    """

    snippet: str
    title: str
    link: str
class OnlineContext(PydanticBaseModel):
    """Schema for the values of a chat message's `onlineContext` mapping.

    Every section is optional and defaults to None.
    """

    # Accepts either a single WebPage or a list of them.
    webpages: Optional[Union[WebPage, List[WebPage]]] = None
    answerBox: Optional[AnswerBox] = None
    peopleAlsoAsk: Optional[List[PeopleAlsoAsk]] = None
    knowledgeGraph: Optional[KnowledgeGraph] = None
    organicContext: Optional[List[OrganicContext]] = None
class Intent(PydanticBaseModel):
    """Schema for the `intent` field of a chat message.

    Two fields are read from hyphenated keys via aliases, because the stored
    keys ("memory-type", "inferred-queries") are not valid Python identifiers.
    """

    type: str
    query: str
    # Required; populated from the "memory-type" key.
    memory_type: str = Field(alias="memory-type")
    # Optional; populated from the "inferred-queries" key.
    inferred_queries: Optional[List[str]] = Field(default=None, alias="inferred-queries")
class TrainOfThought(PydanticBaseModel):
    """Schema for one item of a chat message's `trainOfThought` list."""

    # Both fields are required strings.
    type: str
    data: str
class ChatMessage(PydanticBaseModel):
    """Schema for one entry of a conversation log's "chat" list.

    `message`, `created` and `by` are required. Every other field is optional
    and falls back to an empty collection or None when absent.
    """

    message: str
    trainOfThought: List[TrainOfThought] = []
    context: List[Context] = []
    onlineContext: Dict[str, OnlineContext] = {}
    codeContext: Dict[str, CodeContextData] = {}
    created: str
    images: Optional[List[str]] = None
    queryFiles: Optional[List[Dict]] = None
    excalidrawDiagram: Optional[str] = None
    by: str
    # Explicit None default, matching the other Optional fields. Without it,
    # Pydantic v2 treats an `Optional[...]` annotation as a *required* field
    # that merely allows None, so any message lacking "turnId" fails validation.
    turnId: Optional[str] = None
    intent: Optional[Intent] = None
    automationId: Optional[str] = None
class DbBaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
updated_at = models.DateTimeField(auto_now=True)
@@ -21,7 +123,7 @@ class BaseModel(models.Model):
abstract = True
class ClientApplication(BaseModel):
class ClientApplication(DbBaseModel):
name = models.CharField(max_length=200)
client_id = models.CharField(max_length=200)
client_secret = models.CharField(max_length=200)
@@ -67,7 +169,7 @@ class KhojApiUser(models.Model):
accessed_at = models.DateTimeField(null=True, default=None)
class Subscription(BaseModel):
class Subscription(DbBaseModel):
class Type(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
@@ -79,13 +181,13 @@ class Subscription(BaseModel):
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
class OpenAIProcessorConversationConfig(BaseModel):
class OpenAIProcessorConversationConfig(DbBaseModel):
    """Named API key record with an optional base URL."""

    name = models.CharField(max_length=200)
    api_key = models.CharField(max_length=200)
    # Optional: may be left blank in forms and is stored as NULL by default.
    api_base_url = models.URLField(
        max_length=200,
        default=None,
        blank=True,
        null=True,
    )
class ChatModelOptions(BaseModel):
class ChatModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@@ -103,12 +205,12 @@ class ChatModelOptions(BaseModel):
)
class VoiceModelOption(BaseModel):
class VoiceModelOption(DbBaseModel):
    """A selectable voice model, identified by `model_id` and a display `name`."""

    model_id = models.CharField(max_length=200)
    name = models.CharField(max_length=200)
class Agent(BaseModel):
class Agent(DbBaseModel):
class StyleColorTypes(models.TextChoices):
BLUE = "blue"
GREEN = "green"
@@ -208,7 +310,7 @@ class Agent(BaseModel):
super().save(*args, **kwargs)
class ProcessLock(BaseModel):
class ProcessLock(DbBaseModel):
class Operation(models.TextChoices):
INDEX_CONTENT = "index_content"
SCHEDULED_JOB = "scheduled_job"
@@ -231,24 +333,24 @@ def verify_agent(sender, instance, **kwargs):
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
class NotionConfig(BaseModel):
class NotionConfig(DbBaseModel):
    """A Notion token stored for a user."""

    token = models.CharField(max_length=200)
    # CASCADE: the row is deleted together with its user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
    )
class GithubConfig(BaseModel):
class GithubConfig(DbBaseModel):
    """A GitHub token (`pat_token`) stored for a user."""

    pat_token = models.CharField(max_length=200)
    # CASCADE: the row is deleted together with its user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
    )
class GithubRepoConfig(BaseModel):
class GithubRepoConfig(DbBaseModel):
    """A repository (name, owner, branch) attached to a `GithubConfig`."""

    name = models.CharField(max_length=200)
    owner = models.CharField(max_length=200)
    branch = models.CharField(max_length=200)
    # Reverse accessor on GithubConfig is `githubrepoconfig`; rows are removed
    # when the parent GithubConfig is deleted.
    github_config = models.ForeignKey(
        GithubConfig,
        on_delete=models.CASCADE,
        related_name="githubrepoconfig",
    )
class WebScraper(BaseModel):
class WebScraper(DbBaseModel):
class WebScraperType(models.TextChoices):
FIRECRAWL = "Firecrawl"
OLOSTEP = "Olostep"
@@ -321,7 +423,7 @@ class WebScraper(BaseModel):
super().save(*args, **kwargs)
class ServerChatSettings(BaseModel):
class ServerChatSettings(DbBaseModel):
chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
)
@@ -333,35 +435,35 @@ class ServerChatSettings(BaseModel):
)
class LocalOrgConfig(BaseModel):
class LocalOrgConfig(DbBaseModel):
    """Per-user record holding input files, an input filter and a heading-index flag."""

    # JSON columns default to a fresh empty list and may also be NULL.
    input_files = models.JSONField(default=list, null=True)
    input_filter = models.JSONField(default=list, null=True)
    index_heading_entries = models.BooleanField(default=False)
    # CASCADE: the row is deleted together with its user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
    )
class LocalMarkdownConfig(BaseModel):
class LocalMarkdownConfig(DbBaseModel):
    """Per-user record holding input files, an input filter and a heading-index flag."""

    # JSON columns default to a fresh empty list and may also be NULL.
    input_files = models.JSONField(default=list, null=True)
    input_filter = models.JSONField(default=list, null=True)
    index_heading_entries = models.BooleanField(default=False)
    # CASCADE: the row is deleted together with its user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
    )
class LocalPdfConfig(BaseModel):
class LocalPdfConfig(DbBaseModel):
    """Per-user record holding input files, an input filter and a heading-index flag."""

    # JSON columns default to a fresh empty list and may also be NULL.
    input_files = models.JSONField(default=list, null=True)
    input_filter = models.JSONField(default=list, null=True)
    index_heading_entries = models.BooleanField(default=False)
    # CASCADE: the row is deleted together with its user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
    )
class LocalPlaintextConfig(BaseModel):
class LocalPlaintextConfig(DbBaseModel):
    """Per-user record holding input files, an input filter and a heading-index flag."""

    # JSON columns default to a fresh empty list and may also be NULL.
    input_files = models.JSONField(default=list, null=True)
    input_filter = models.JSONField(default=list, null=True)
    index_heading_entries = models.BooleanField(default=False)
    # CASCADE: the row is deleted together with its user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
    )
class SearchModelConfig(BaseModel):
class SearchModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
TEXT = "text"
@@ -393,7 +495,7 @@ class SearchModelConfig(BaseModel):
bi_encoder_confidence_threshold = models.FloatField(default=0.18)
class TextToImageModelConfig(BaseModel):
class TextToImageModelConfig(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
STABILITYAI = "stability-ai"
@@ -430,7 +532,7 @@ class TextToImageModelConfig(BaseModel):
super().save(*args, **kwargs)
class SpeechToTextModelOptions(BaseModel):
class SpeechToTextModelOptions(DbBaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
@@ -439,22 +541,22 @@ class SpeechToTextModelOptions(BaseModel):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class UserConversationConfig(BaseModel):
class UserConversationConfig(DbBaseModel):
    """Links a user to at most one `ChatModelOptions` row (one record per user)."""

    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
    # Optional: may be NULL when no chat model option is set.
    setting = models.ForeignKey(
        ChatModelOptions,
        on_delete=models.CASCADE,
        default=None,
        null=True,
        blank=True,
    )
class UserVoiceModelConfig(BaseModel):
class UserVoiceModelConfig(DbBaseModel):
    """Links a user to at most one `VoiceModelOption` row (one record per user)."""

    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
    # Optional: may be NULL when no voice model option is set.
    setting = models.ForeignKey(
        VoiceModelOption,
        on_delete=models.CASCADE,
        default=None,
        null=True,
        blank=True,
    )
class UserTextToImageModelConfig(BaseModel):
class UserTextToImageModelConfig(DbBaseModel):
    """Links a user to one `TextToImageModelConfig` row (one record per user)."""

    user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
    # Required here, unlike the nullable `setting` on the sibling user configs.
    setting = models.ForeignKey(
        TextToImageModelConfig,
        on_delete=models.CASCADE,
    )
class Conversation(BaseModel):
class Conversation(DbBaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -468,8 +570,39 @@ class Conversation(BaseModel):
file_filters = models.JSONField(default=list)
id = models.UUIDField(default=uuid.uuid4, editable=False, unique=True, primary_key=True, db_index=True)
def clean(self):
    """Validate the structure of `conversation_log`.

    Every entry under the "chat" key must validate as a `ChatMessage`.

    Raises:
        ValidationError: if the log cannot be read or an entry fails
            validation. The original exception is chained as the cause.
    """
    try:
        for msg in self.conversation_log.get("chat", []):
            ChatMessage.model_validate(msg)
    except Exception as e:
        # Deliberately broad: any failure while reading or validating the log
        # is reported uniformly as an invalid-format error.
        raise ValidationError(f"Invalid conversation_log format: {e}") from e
class PublicConversation(BaseModel):
def save(self, *args, **kwargs):
    """Validate the conversation log, then persist through the parent `save()`."""
    # clean() raises on a malformed log, so nothing is written in that case.
    self.clean()
    super().save(*args, **kwargs)
@property
def messages(self) -> List[ChatMessage]:
    """Return the conversation's chat entries as validated `ChatMessage` objects.

    Entries that are malformed or fail validation are logged and skipped
    rather than raised.

    NOTE: normalisation (dropping non-string inferred queries, coercing
    `message` to str) is applied in place to the dicts held in
    `conversation_log`.
    """
    # model_validate() raises Pydantic's ValidationError. Only BaseModel and
    # Field are imported from pydantic at module level, so import the
    # exception here under an unambiguous name.
    from pydantic import ValidationError as PydanticValidationError

    validated_messages = []
    for msg in self.conversation_log.get("chat", []):
        if not isinstance(msg, dict):
            logger.warning(f"Skipping invalid message in conversation: expected a dict, got {type(msg).__name__}")
            continue

        # Clean up inferred queries if they contain None or other non-strings.
        # A malformed intent is left as is for model_validate() to reject.
        intent = msg.get("intent")
        if isinstance(intent, dict) and isinstance(intent.get("inferred-queries"), list):
            intent["inferred-queries"] = [q for q in intent["inferred-queries"] if isinstance(q, str)]
        msg["message"] = str(msg.get("message", ""))

        try:
            validated_messages.append(ChatMessage.model_validate(msg))
        except (PydanticValidationError, ValidationError) as e:
            # The module-level ValidationError is still caught so behaviour is
            # unchanged for anything that relied on it.
            logger.warning(f"Skipping invalid message in conversation: {e}")
    return validated_messages
class PublicConversation(DbBaseModel):
source_owner = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
@@ -499,12 +632,12 @@ def verify_public_conversation(sender, instance, **kwargs):
instance.slug = slug
class ReflectiveQuestion(BaseModel):
class ReflectiveQuestion(DbBaseModel):
    """A question of up to 500 characters, optionally tied to a user."""

    question = models.CharField(max_length=500)
    # Optional: NULL when the question is not associated with a user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
        default=None,
        null=True,
        blank=True,
    )
class Entry(BaseModel):
class Entry(DbBaseModel):
class EntryType(models.TextChoices):
IMAGE = "image"
PDF = "pdf"
@@ -541,7 +674,7 @@ class Entry(BaseModel):
raise ValidationError("An Entry cannot be associated with both a user and an agent.")
class FileObject(BaseModel):
class FileObject(DbBaseModel):
# Same as Entry but raw will be a much larger string
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
raw_text = models.TextField()
@@ -549,7 +682,7 @@ class FileObject(BaseModel):
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
class EntryDates(BaseModel):
class EntryDates(DbBaseModel):
date = models.DateField()
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
@@ -559,12 +692,12 @@ class EntryDates(BaseModel):
]
class UserRequests(BaseModel):
class UserRequests(DbBaseModel):
    """A `slug` recorded against a user."""

    # CASCADE: rows are deleted together with their user.
    user = models.ForeignKey(
        KhojUser,
        on_delete=models.CASCADE,
    )
    slug = models.CharField(max_length=200)
class DataStore(BaseModel):
class DataStore(DbBaseModel):
key = models.CharField(max_length=200, unique=True)
value = models.JSONField(default=dict)
private = models.BooleanField(default=False)