mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-05 05:39:11 +00:00
Include agent personality through subtasks and support custom agents (#916)
Currently, the personality of the agent is only included in the final response that it returns to the user. Historically, this was because models were quite bad at navigating the additional context of personality, and there was a bias towards having more control over certain operations (e.g., tool selection, question extraction). Going forward, it should be more approachable to have prompts included in the sub tasks that Khoj runs in order to response to a given query. Make this possible in this PR. This also sets us up for agent creation becoming available soon. Create custom agents in #928 Agents are useful insofar as you can personalize them to fulfill specific subtasks you need to accomplish. In this PR, we add support for using custom agents that can be configured with a custom system prompt (aka persona) and knowledge base (from your own indexed documents). Once created, private agents can be accessible only to the creator, and protected agents can be accessible via a direct link. Custom tool selection for agents in #930 Expose the functionality to select which tools a given agent has access to. By default, they have all. Can limit both information sources and output modes. Add new tools to the agent modification form
This commit is contained in:
@@ -3,6 +3,7 @@ import uuid
|
||||
from random import choice
|
||||
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db.models.signals import pre_save
|
||||
@@ -10,6 +11,8 @@ from django.dispatch import receiver
|
||||
from pgvector.django import VectorField
|
||||
from phonenumber_field.modelfields import PhoneNumberField
|
||||
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
|
||||
|
||||
class BaseModel(models.Model):
|
||||
created_at = models.DateTimeField(auto_now_add=True)
|
||||
@@ -125,7 +128,7 @@ class Agent(BaseModel):
|
||||
EMERALD = "emerald"
|
||||
|
||||
class StyleIconTypes(models.TextChoices):
|
||||
LIGHBULB = "Lightbulb"
|
||||
LIGHTBULB = "Lightbulb"
|
||||
HEALTH = "Health"
|
||||
ROBOT = "Robot"
|
||||
APERTURE = "Aperture"
|
||||
@@ -140,20 +143,64 @@ class Agent(BaseModel):
|
||||
CLOCK_COUNTER_CLOCKWISE = "ClockCounterClockwise"
|
||||
PENCIL_LINE = "PencilLine"
|
||||
CHALKBOARD = "Chalkboard"
|
||||
CIGARETTE = "Cigarette"
|
||||
CRANE_TOWER = "CraneTower"
|
||||
HEART = "Heart"
|
||||
LEAF = "Leaf"
|
||||
NEWSPAPER_CLIPPING = "NewspaperClipping"
|
||||
ORANGE_SLICE = "OrangeSlice"
|
||||
SMILEY_MELTING = "SmileyMelting"
|
||||
YIN_YANG = "YinYang"
|
||||
SNEAKER_MOVE = "SneakerMove"
|
||||
STUDENT = "Student"
|
||||
OVEN = "Oven"
|
||||
GAVEL = "Gavel"
|
||||
BROADCAST = "Broadcast"
|
||||
|
||||
class PrivacyLevel(models.TextChoices):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
PROTECTED = "protected"
|
||||
|
||||
class InputToolOptions(models.TextChoices):
|
||||
# These map to various ConversationCommand types
|
||||
GENERAL = "general"
|
||||
ONLINE = "online"
|
||||
NOTES = "notes"
|
||||
SUMMARIZE = "summarize"
|
||||
WEBPAGE = "webpage"
|
||||
|
||||
class OutputModeOptions(models.TextChoices):
|
||||
# These map to various ConversationCommand types
|
||||
TEXT = "text"
|
||||
IMAGE = "image"
|
||||
|
||||
creator = models.ForeignKey(
|
||||
KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
|
||||
) # Creator will only be null when the agents are managed by admin
|
||||
name = models.CharField(max_length=200)
|
||||
personality = models.TextField()
|
||||
avatar = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||
tools = models.JSONField(default=list) # List of tools the agent has access to, like online search or notes search
|
||||
public = models.BooleanField(default=False)
|
||||
input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
|
||||
output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
|
||||
managed_by_admin = models.BooleanField(default=False)
|
||||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
||||
slug = models.CharField(max_length=200)
|
||||
slug = models.CharField(max_length=200, unique=True)
|
||||
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
|
||||
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHBULB)
|
||||
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
|
||||
privacy_level = models.CharField(max_length=30, choices=PrivacyLevel.choices, default=PrivacyLevel.PRIVATE)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
is_new = self._state.adding
|
||||
|
||||
if self.creator is None:
|
||||
self.managed_by_admin = True
|
||||
|
||||
if is_new:
|
||||
random_sequence = "".join(choice("0123456789") for i in range(6))
|
||||
slug = f"{self.name.lower().replace(' ', '-')}-{random_sequence}"
|
||||
self.slug = slug
|
||||
|
||||
super().save(*args, **kwargs)
|
||||
|
||||
|
||||
class ProcessLock(BaseModel):
|
||||
@@ -173,22 +220,11 @@ class ProcessLock(BaseModel):
|
||||
def verify_agent(sender, instance, **kwargs):
|
||||
# check if this is a new instance
|
||||
if instance._state.adding:
|
||||
if Agent.objects.filter(name=instance.name, public=True).exists():
|
||||
if Agent.objects.filter(name=instance.name, privacy_level=Agent.PrivacyLevel.PUBLIC).exists():
|
||||
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
|
||||
if Agent.objects.filter(name=instance.name, creator=instance.creator).exists():
|
||||
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
|
||||
|
||||
slug = instance.name.lower().replace(" ", "-")
|
||||
observed_random_numbers = set()
|
||||
while Agent.objects.filter(slug=slug).exists():
|
||||
try:
|
||||
random_number = choice([i for i in range(0, 1000) if i not in observed_random_numbers])
|
||||
except IndexError:
|
||||
raise ValidationError("Unable to generate a unique slug for the Agent. Please try again later.")
|
||||
observed_random_numbers.add(random_number)
|
||||
slug = f"{slug}-{random_number}"
|
||||
instance.slug = slug
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
token = models.CharField(max_length=200)
|
||||
@@ -406,6 +442,7 @@ class Entry(BaseModel):
|
||||
GITHUB = "github"
|
||||
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
embeddings = VectorField(dimensions=None)
|
||||
raw = models.TextField()
|
||||
compiled = models.TextField()
|
||||
@@ -418,12 +455,17 @@ class Entry(BaseModel):
|
||||
hashed_value = models.CharField(max_length=100)
|
||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if self.user and self.agent:
|
||||
raise ValidationError("An Entry cannot be associated with both a user and an agent.")
|
||||
|
||||
|
||||
class FileObject(BaseModel):
|
||||
# Same as Entry but raw will be a much larger string
|
||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
raw_text = models.TextField()
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class EntryDates(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user