Part 1: Server-side changes to support agents integrated with Conversations (#671)

* Initial pass at backend changes to support agents
- Add a db model for Agents, attaching them to conversations
- When an agent is added to a conversation, override the system prompt to tweak the instructions
- Agents can be configured with prompt modification, model specification, a profile picture, and other things
- Admin-configured models will not be editable by individual users
- Add unit tests to verify agent behavior. Unit tests demonstrate imperfect adherence to prompt specifications

* Customize default behaviors for conversations without agents or with default agents

* Use agent_id for getting correct agent

* Merge migrations

* Simplify some variable definitions, add additional security checks for agents

* Rename agent.tuning -> agent.personality
This commit is contained in:
sabaimran
2024-03-23 09:39:38 -07:00
committed by GitHub
parent 7416ca9ae1
commit 8abc8ded82
18 changed files with 527 additions and 60 deletions

View File

@@ -21,6 +21,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import HTTPConnection
from khoj.database.adapters import (
AgentAdapters,
ClientApplicationAdapters,
ConversationAdapters,
SubscriptionState,
@@ -229,11 +230,16 @@ def configure_server(
state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent()
initialize_content(regenerate, search_type, init, user)
except Exception as e:
raise e
def setup_default_agent():
AgentAdapters.create_default_agent()
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
# Initialize Content from Config
if state.search_models:

View File

@@ -16,6 +16,7 @@ from pgvector.django import CosineDistance
from torch import Tensor
from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,
@@ -37,6 +38,7 @@ from khoj.database.models import (
UserRequests,
UserSearchModelConfig,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
@@ -391,6 +393,58 @@ class ClientApplicationAdapters:
return await ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret).afirst()
class AgentAdapters:
DEFAULT_AGENT_NAME = "khoj"
DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png"
@staticmethod
async def aget_agent_by_id(agent_id: int, user: KhojUser):
agent = await Agent.objects.filter(id=agent_id).afirst()
# Check if it's accessible to the user
if agent and (agent.public or agent.creator == user):
return agent
return None
@staticmethod
def get_all_accessible_agents(user: KhojUser = None):
return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct()
@staticmethod
def get_conversation_agent_by_id(agent_id: int):
agent = Agent.objects.filter(id=agent_id).first()
if agent == AgentAdapters.get_default_agent():
# If the agent is set to the default agent, then return None and let the default application code be used
return None
return agent
@staticmethod
def get_default_agent():
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
@staticmethod
def create_default_agent():
# First delete the existing default
Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete()
default_conversation_config = ConversationAdapters.get_default_conversation_config()
default_personality = prompts.personality.format(current_date="placeholder")
# The default agent is public and managed by the admin. It's handled a little differently than other agents.
return Agent.objects.create(
name=AgentAdapters.DEFAULT_AGENT_NAME,
public=True,
managed_by_admin=True,
chat_model=default_conversation_config,
personality=default_personality,
tools=["*"],
avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
)
@staticmethod
async def aget_default_agent():
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(
@@ -431,7 +485,12 @@ class ConversationAdapters:
return Conversation.objects.filter(id=conversation_id).first()
@staticmethod
async def acreate_conversation_session(user: KhojUser, client_application: ClientApplication = None):
async def acreate_conversation_session(
user: KhojUser, client_application: ClientApplication = None, agent_id: int = None
):
if agent_id:
agent = await AgentAdapters.aget_agent_by_id(agent_id, user)
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)
return await Conversation.objects.acreate(user=user, client=client_application)
@staticmethod
@@ -443,9 +502,14 @@ class ConversationAdapters:
elif title:
return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst()
else:
return await (
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
) or await Conversation.objects.acreate(user=user, client=client_application)
conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
if await conversation.aexists():
return await conversation.prefetch_related("agent").afirst()
return await (
Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst()
) or await Conversation.objects.acreate(user=user, client=client_application)
@staticmethod
async def adelete_conversation_by_user(
@@ -603,9 +667,14 @@ class ConversationAdapters:
return random.sample(all_questions, max_results)
@staticmethod
def get_valid_conversation_config(user: KhojUser):
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
conversation_config = ConversationAdapters.get_conversation_config(user)
if conversation.agent and conversation.agent.chat_model:
conversation_config = conversation.agent.chat_model
else:
conversation_config = ConversationAdapters.get_conversation_config(user)
if conversation_config is None:
conversation_config = ConversationAdapters.get_default_conversation_config()

View File

@@ -6,6 +6,7 @@ from django.contrib.auth.admin import UserAdmin
from django.http import HttpResponse
from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,
@@ -50,6 +51,7 @@ admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(TextToImageModelConfig)
admin.site.register(ClientApplication)
admin.site.register(Agent)
@admin.register(Entry)

View File

@@ -0,0 +1,52 @@
# Generated by Django 4.2.10 on 2024-03-11 05:12
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0030_conversation_slug_and_title"),
]
operations = [
migrations.CreateModel(
name="Agent",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("name", models.CharField(max_length=200)),
("tuning", models.TextField()),
("avatar", models.URLField(blank=True, default=None, max_length=400, null=True)),
("tools", models.JSONField(default=list)),
("public", models.BooleanField(default=False)),
("managed_by_admin", models.BooleanField(default=False)),
(
"chat_model",
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodeloptions"),
),
(
"creator",
models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
to=settings.AUTH_USER_MODEL,
),
),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="conversation",
name="agent",
field=models.ForeignKey(
blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, to="database.agent"
),
),
]

View File

@@ -0,0 +1,14 @@
# Generated by Django 4.2.10 on 2024-03-22 04:27
from typing import List
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0031_agent_conversation_agent"),
("database", "0031_alter_googleuser_locale"),
]
operations: List[str] = []

View File

@@ -0,0 +1,17 @@
# Generated by Django 4.2.10 on 2024-03-23 16:01
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0032_merge_20240322_0427"),
]
operations = [
migrations.RenameField(
model_name="agent",
old_name="tuning",
new_name="personality",
),
]

View File

@@ -1,7 +1,10 @@
import uuid
from django.contrib.auth.models import AbstractUser
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.signals import pre_save
from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
@@ -69,6 +72,37 @@ class Subscription(BaseModel):
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class Agent(BaseModel):
creator = models.ForeignKey(
KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
) # Creator will only be null when the agents are managed by admin
name = models.CharField(max_length=200)
personality = models.TextField()
avatar = models.URLField(max_length=400, default=None, null=True, blank=True)
tools = models.JSONField(default=list) # List of tools the agent has access to, like online search or notes search
public = models.BooleanField(default=False)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
@receiver(pre_save, sender=Agent)
def check_public_name(sender, instance, **kwargs):
if instance.public:
if Agent.objects.filter(name=instance.name, public=True).exists():
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
@@ -153,17 +187,6 @@ class SpeechToTextModelOptions(BaseModel):
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class ChatModelOptions(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
OFFLINE = "offline"
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
class UserConversationConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
@@ -180,6 +203,7 @@ class Conversation(BaseModel):
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
title = models.CharField(max_length=200, default=None, null=True, blank=True)
agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)
class ReflectiveQuestion(BaseModel):

View File

@@ -6,6 +6,7 @@ from typing import Any, Iterator, List, Union
from langchain.schema import ChatMessage
from khoj.database.models import Agent
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ThreadedGenerator,
@@ -141,6 +142,7 @@ def converse_offline(
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@@ -156,6 +158,15 @@ def converse_offline(
# Initialize Variables
compiled_references_message = "\n\n".join({f"{item}" for item in references})
current_date = datetime.now().strftime("%Y-%m-%d")
if agent and agent.personality:
system_prompt = prompts.custom_system_prompt_message_gpt4all.format(
name=agent.name, bio=agent.personality, current_date=current_date
)
else:
system_prompt = prompts.system_prompt_message_gpt4all.format(current_date=current_date)
conversation_primer = prompts.query_prompt.format(query=user_query)
if location_data:
@@ -185,10 +196,9 @@ def converse_offline(
conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}"
# Setup Prompt with Primer or Conversation History
current_date = datetime.now().strftime("%Y-%m-%d")
messages = generate_chatml_messages_with_context(
conversation_primer,
prompts.system_prompt_message_gpt4all.format(current_date=current_date),
system_prompt,
conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,

View File

@@ -5,6 +5,7 @@ from typing import Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent
from khoj.processor.conversation import prompts
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
@@ -115,6 +116,7 @@ def converse(
tokenizer_name=None,
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
):
"""
Converse with user using OpenAI's ChatGPT
@@ -125,6 +127,13 @@ def converse(
conversation_primer = prompts.query_prompt.format(query=user_query)
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
name=agent.name, bio=agent.personality, current_date=current_date
)
else:
system_prompt = prompts.personality.format(current_date=current_date)
if location_data:
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
location_prompt = prompts.user_location.format(location=location)
@@ -152,7 +161,7 @@ def converse(
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
prompts.personality.format(current_date=current_date),
system_prompt,
conversation_log,
model,
max_prompt_size,

View File

@@ -21,6 +21,24 @@ Today is {current_date} in UTC.
""".strip()
)
custom_personality = PromptTemplate.from_template(
"""
Your are {name}, a personal agent on Khoj.
Use your general knowledge and past conversation with the user as context to inform your responses.
You were created by Khoj Inc. with the following capabilities:
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
- Users can share files and other information with you using the Khoj Desktop, Obsidian or Emacs app. They can also drag and drop their files into the chat window.
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
Today is {current_date} in UTC.
Instructions:\n{bio}
""".strip()
)
## General Conversation
## --
general_conversation = PromptTemplate.from_template(
@@ -61,6 +79,20 @@ Today is {current_date} in UTC.
""".strip()
)
custom_system_prompt_message_gpt4all = PromptTemplate.from_template(
"""
You are {name}, a personal agent on Khoj.
- Use your general knowledge and past conversation with the user as context to inform your responses.
- If you do not know the answer, say 'I don't know.'
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
- Do not print verbatim Notes unless necessary.
Today is {current_date} in UTC.
Instructions:\n{bio}
""".strip()
)
system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
- Write the question as if you can search for the answer on the user's personal notes.
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".

View File

@@ -13,7 +13,7 @@ from fastapi.requests import Request
from fastapi.responses import Response
from starlette.authentication import requires
from khoj.configure import configure_server, initialize_content
from khoj.configure import initialize_content
from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,

View File

@@ -148,11 +148,12 @@ def chat_sessions(
async def create_chat_session(
request: Request,
common: CommonQueryParams,
agent_id: Optional[int] = None,
):
user = request.user.object
# Create new Conversation Session
conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app)
conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app, agent_id)
response = {"conversation_id": conversation.id}
@@ -341,6 +342,7 @@ async def chat(
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
conversation,
compiled_references,
online_results,
inferred_queries,

View File

@@ -10,10 +10,11 @@ import openai
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
from khoj.database.models import (
ChatModelOptions,
ClientApplication,
Conversation,
KhojUser,
Subscription,
TextToImageModelConfig,
@@ -364,6 +365,7 @@ async def send_message_to_model_wrapper(
def generate_chat_response(
q: str,
meta_log: dict,
conversation: Conversation,
compiled_references: List[str] = [],
online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
@@ -379,6 +381,7 @@ def generate_chat_response(
logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
try:
partial_completion = partial(
@@ -393,7 +396,7 @@ def generate_chat_response(
conversation_id=conversation_id,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
if conversation_config.model_type == "offline":
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
@@ -412,6 +415,7 @@ def generate_chat_response(
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
)
elif conversation_config.model_type == "openai":
@@ -431,6 +435,7 @@ def generate_chat_response(
tokenizer_name=conversation_config.tokenizer,
location_data=location_data,
user_name=user_name,
agent=agent,
)
metadata.update({"chat_model": conversation_config.chat_model})