Support multiple chat sessions within the web UI (#638)

* Enable support for multiple chat sessions within the web client

- Allow users to create multiple chat sessions and manage them
- Give chat session slugs based on the most recent message
- Update web UI to have a collapsible menu with active chats
- Move chat routes into a separate file

* Make the collapsible side panel more graceful, improve some styling elements of the new layout

* Support modification of the conversation title

- Add a new field to the conversation object
- Update UI to add a threedotmenu to each conversation

* Get the default conversation if a matching one is not found by id
This commit is contained in:
sabaimran
2024-02-11 02:18:28 -08:00
committed by GitHub
parent 208ccc83ec
commit 1412ed6a00
15 changed files with 981 additions and 301 deletions

View File

@@ -107,7 +107,7 @@ a.khoj-nav-selected {
background-color: var(--primary);
}
img.khoj-logo {
width: min(60vw, 111px);
width: min(60vw, 90px);
max-width: 100%;
justify-self: center;
}
@@ -117,7 +117,7 @@ img.khoj-logo {
display: grid;
grid-auto-flow: column;
gap: 20px;
padding: 16px 10px;
padding: 12px 10px;
margin: 0 0 16px 0;
}

View File

@@ -54,7 +54,6 @@
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
console.log(`Toggling ref-${index}`)
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
@@ -100,7 +99,6 @@
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
console.log(`Toggling ref-${index}`)
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
@@ -586,8 +584,10 @@
return data.response;
})
.then(response => {
const fullChatLog = response.chat;
// Render conversation history, if any
response.forEach(chat_log => {
fullChatLog.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
});
})

View File

@@ -269,7 +269,7 @@ export class KhojChatModal extends Modal {
return false;
} else if (responseJson.response) {
let chatLogs = responseJson.response;
let chatLogs = responseJson.response.chat;
chatLogs.forEach((chatLog: any) => {
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
});

View File

@@ -256,6 +256,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non
def configure_routes(app):
# Import APIs here to setup search types before while configuring server
from khoj.routers.api import api
from khoj.routers.api_chat import api_chat
from khoj.routers.api_config import api_config
from khoj.routers.auth import auth_router
from khoj.routers.indexer import indexer
@@ -266,6 +267,7 @@ def configure_routes(app):
app.include_router(indexer, prefix="/api/v1/index")
app.include_router(web_client)
app.include_router(auth_router, prefix="/auth")
app.include_router(api_chat, prefix="/api/chat")
if state.billing_enabled:
from khoj.routers.subscription import subscription_router

View File

@@ -357,22 +357,61 @@ class ClientApplicationAdapters:
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser, client_application: ClientApplication = None):
conversation = Conversation.objects.filter(user=user, client=client_application)
def get_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
):
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
if not conversation_id or not conversation.exists():
conversation = Conversation.objects.filter(user=user, client=client_application)
if conversation.exists():
return conversation.first()
return Conversation.objects.create(user=user, client=client_application)
@staticmethod
async def aget_conversation_by_user(user: KhojUser, client_application: ClientApplication = None):
conversation = Conversation.objects.filter(user=user, client=client_application)
if await conversation.aexists():
return await conversation.afirst()
def get_conversation_sessions(user: KhojUser, client_application: ClientApplication = None):
return Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
@staticmethod
async def aset_conversation_title(
user: KhojUser, client_application: ClientApplication, conversation_id: int, title: str
):
conversation = await Conversation.objects.filter(
user=user, client=client_application, id=conversation_id
).afirst()
if conversation:
conversation.title = title
await conversation.asave()
return conversation
return None
@staticmethod
def get_conversation_by_id(conversation_id: int):
return Conversation.objects.filter(id=conversation_id).first()
@staticmethod
async def acreate_conversation_session(user: KhojUser, client_application: ClientApplication = None):
return await Conversation.objects.acreate(user=user, client=client_application)
@staticmethod
async def adelete_conversation_by_user(user: KhojUser):
return await Conversation.objects.filter(user=user).adelete()
async def aget_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None, slug: str = None
):
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
else:
conversation = Conversation.objects.filter(user=user, client=client_application, slug=slug)
if await conversation.aexists():
return await conversation.afirst()
return await Conversation.objects.acreate(user=user, client=client_application, slug=slug)
@staticmethod
async def adelete_conversation_by_user(
user: KhojUser, client_application: ClientApplication = None, conversation_id: int = None
):
if conversation_id:
return await Conversation.objects.filter(user=user, client=client_application, id=conversation_id).adelete()
return await Conversation.objects.filter(user=user, client=client_application).adelete()
@staticmethod
def has_any_conversation_config(user: KhojUser):
@@ -433,12 +472,24 @@ class ConversationAdapters:
return await ChatModelOptions.objects.filter().afirst()
@staticmethod
def save_conversation(user: KhojUser, conversation_log: dict, client_application: ClientApplication = None):
conversation = Conversation.objects.filter(user=user, client=client_application)
if conversation.exists():
conversation.update(conversation_log=conversation_log)
def save_conversation(
user: KhojUser,
conversation_log: dict,
client_application: ClientApplication = None,
conversation_id: int = None,
user_message: str = None,
):
slug = user_message.strip()[:200] if not is_none_or_empty(user_message) else None
if conversation_id:
conversation = Conversation.objects.filter(user=user, client=client_application, id=conversation_id)
else:
Conversation.objects.create(user=user, conversation_log=conversation_log, client=client_application)
conversation = Conversation.objects.filter(user=user, client=client_application)
if conversation.exists():
conversation.update(conversation_log=conversation_log, slug=slug)
else:
Conversation.objects.create(
user=user, conversation_log=conversation_log, client=client_application, slug=slug
)
@staticmethod
def get_conversation_processor_options():

View File

@@ -0,0 +1,22 @@
# Generated by Django 4.2.7 on 2024-02-05 04:39
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0029_userrequests"),
]
operations = [
migrations.AddField(
model_name="conversation",
name="slug",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
migrations.AddField(
model_name="conversation",
name="title",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
]

View File

@@ -178,6 +178,8 @@ class Conversation(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
conversation_log = models.JSONField(default=dict)
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
title = models.CharField(max_length=200, default=None, null=True, blank=True)
class ReflectiveQuestion(BaseModel):

View File

@@ -110,7 +110,7 @@ a.khoj-logo {
background-color: var(--primary);
}
img.khoj-logo {
width: min(60vw, 111px);
width: min(60vw, 90px);
max-width: 100%;
justify-self: center;
}
@@ -202,7 +202,7 @@ img.khoj-logo {
grid-auto-flow: column;
gap: 20px;
padding: 16px 10px;
margin: 0 0 16px 0;
margin: 0;
}
nav.khoj-nav {

View File

@@ -66,7 +66,6 @@ To get started, just start typing below. You can also type / to see a list of co
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
console.log(`Toggling ref-${index}`)
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
@@ -112,7 +111,6 @@ To get started, just start typing below. You can also type / to see a list of co
// Add event listener to toggle full reference on click
referenceButton.addEventListener('click', function() {
console.log(`Toggling ref-${index}`)
if (this.classList.contains("collapsed")) {
this.classList.remove("collapsed");
this.classList.add("expanded");
@@ -154,6 +152,9 @@ To get started, just start typing below. You can also type / to see a list of co
// Scroll to bottom of chat-body element
chatBody.scrollTop = chatBody.scrollHeight;
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
}
function processOnlineReferences(referenceSection, onlineContext) {
@@ -332,11 +333,20 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById("chat-input").value = "";
autoResize();
document.getElementById("chat-input").setAttribute("disabled", "disabled");
let chat_body = document.getElementById("chat-body");
let conversationID = chat_body.dataset.conversationId;
if (!conversationID) {
let response = await fetch('/api/chat/sessions', { method: "POST" });
let data = await response.json();
conversationID = data.conversation_id;
chat_body.dataset.conversationId = conversationID;
}
// Generate backend API URL to execute query
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true`;
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${resultsCount}&client=web&stream=true&conversation_id=${conversationID}`;
let chat_body = document.getElementById("chat-body");
let new_response = document.createElement("div");
new_response.classList.add("chat-message", "khoj");
new_response.attributes["data-meta"] = "🏮 Khoj at " + formatDate(new Date());
@@ -537,7 +547,18 @@ To get started, just start typing below. You can also type / to see a list of co
window.onload = loadChat;
function loadChat() {
fetch('/api/chat/history?client=web')
let chatBody = document.getElementById("chat-body");
let conversationId = chatBody.dataset.conversationId;
let chatHistoryUrl = `/api/chat/history?client=web`;
if (conversationId) {
chatHistoryUrl += `&conversation_id=${conversationId}`;
}
if (window.screen.width < 700) {
handleCollapseSidePanel();
}
fetch(chatHistoryUrl, { method: "GET" })
.then(response => response.json())
.then(data => {
if (data.detail) {
@@ -556,9 +577,172 @@ To get started, just start typing below. You can also type / to see a list of co
})
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
conversationId = response.conversation_id;
const conversationTitle = response.slug || `New conversation 🌱`;
let chatBody = document.getElementById("chat-body");
chatBody.dataset.conversationId = conversationId;
chatBody.dataset.conversationTitle = conversationTitle;
const fullChatLog = response.chat || [];
fullChatLog.forEach(chat_log => {
if (chat_log.message != null){
renderMessageWithReference(
chat_log.message,
chat_log.by,
chat_log.context,
new Date(chat_log.created),
chat_log.onlineContext,
chat_log.intent?.type,
chat_log.intent?.["inferred-queries"]);
}
});
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
chatBody.style.height = chatBodyWrapperHeight;
})
.catch(err => {
console.log(err);
return;
});
fetch('/api/chat/sessions', { method: "GET" })
.then(response => response.json())
.then(data => {
let conversationListBody = document.getElementById("conversation-list-body");
conversationListBody.innerHTML = "";
let conversationListBodyHeader = document.getElementById("conversation-list-header");
let chatBody = document.getElementById("chat-body");
conversationId = chatBody.dataset.conversationId;
if (data.length > 0) {
conversationListBodyHeader.style.display = "block";
for (let index in data) {
let conversation = data[index];
let conversationButton = document.createElement('div');
let incomingConversationId = conversation["conversation_id"];
const conversationTitle = conversation["slug"] || `New conversation 🌱`;
conversationButton.innerHTML = conversationTitle;
conversationButton.classList.add("conversation-button");
if (incomingConversationId == conversationId) {
conversationButton.classList.add("selected-conversation");
}
conversationButton.addEventListener('click', function() {
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
chatBody.dataset.conversationId = incomingConversationId;
chatBody.dataset.conversationTitle = conversationTitle;
loadChat();
});
let threeDotMenu = document.createElement('div');
threeDotMenu.classList.add("three-dot-menu");
let threeDotMenuButton = document.createElement('button');
threeDotMenuButton.innerHTML = "⋮";
threeDotMenuButton.classList.add("three-dot-menu-button");
threeDotMenuButton.addEventListener('click', function(event) {
event.stopPropagation();
let existingChildren = threeDotMenu.children;
if (existingChildren.length > 1) {
// Skip deleting the first, since that's the menu button.
for (let i = 1; i < existingChildren.length; i++) {
existingChildren[i].remove();
}
return;
}
let conversationMenu = document.createElement('div');
conversationMenu.classList.add("conversation-menu");
let deleteButton = document.createElement('button');
deleteButton.innerHTML = "Delete";
deleteButton.classList.add("delete-conversation-button");
deleteButton.classList.add("three-dot-menu-button-item");
deleteButton.addEventListener('click', function() {
let deleteURL = `/api/chat/history?client=web&conversation_id=${incomingConversationId}`;
fetch(deleteURL , { method: "DELETE" })
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => {
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
})
.catch(err => {
return;
});
});
conversationMenu.appendChild(deleteButton);
threeDotMenu.appendChild(conversationMenu);
let editTitleButton = document.createElement('button');
editTitleButton.innerHTML = "Rename";
editTitleButton.classList.add("edit-title-button");
editTitleButton.classList.add("three-dot-menu-button-item");
editTitleButton.addEventListener('click', function(event) {
event.stopPropagation();
let conversationMenuChildren = conversationMenu.children;
let totalItems = conversationMenuChildren.length;
for (let i = totalItems - 1; i >= 0; i--) {
conversationMenuChildren[i].remove();
}
// Create a dialog box to get new title for conversation
let conversationTitleInputBox = document.createElement('div');
conversationTitleInputBox.classList.add("conversation-title-input-box");
let conversationTitleInput = document.createElement('input');
conversationTitleInput.classList.add("conversation-title-input");
conversationTitleInput.value = conversationTitle;
conversationTitleInput.addEventListener('click', function(event) {
event.stopPropagation();
if (event.key === "Enter") {
event.preventDefault();
conversationTitleInputButton.click();
}
});
conversationTitleInputBox.appendChild(conversationTitleInput);
let conversationTitleInputButton = document.createElement('button');
conversationTitleInputButton.innerHTML = "Save";
conversationTitleInputButton.classList.add("three-dot-menu-button-item");
conversationTitleInputButton.addEventListener('click', function(event) {
event.stopPropagation();
let newTitle = conversationTitleInput.value;
if (newTitle != null) {
let editURL = `/api/chat/title?client=web&conversation_id=${incomingConversationId}&title=${newTitle}`;
fetch(editURL , { method: "PATCH" })
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => {
conversationButton.innerHTML = newTitle;
})
.catch(err => {
return;
});
conversationTitleInputBox.remove();
}});
conversationTitleInputBox.appendChild(conversationTitleInputButton);
conversationMenu.appendChild(conversationTitleInputBox);
});
conversationMenu.appendChild(editTitleButton);
threeDotMenu.appendChild(conversationMenu);
});
threeDotMenu.appendChild(threeDotMenuButton);
conversationButton.appendChild(threeDotMenu);
conversationListBody.appendChild(conversationButton);
}
}
})
.catch(err => {
console.log(err);
@@ -622,15 +806,32 @@ To get started, just start typing below. You can also type / to see a list of co
}, 2000);
}
function createNewConversation() {
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
flashStatusInChatInput("📝 New conversation started");
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
renderMessage(welcome_message, "khoj");
}
function clearConversationHistory() {
let chatInput = document.getElementById("chat-input");
let originalPlaceholder = chatInput.placeholder;
let chatBody = document.getElementById("chat-body");
let conversationId = chatBody.dataset.conversationId;
fetch(`/api/chat/history?client=web`, { method: "DELETE" })
let deleteURL = `/api/chat/history?client=web`;
if (conversationId) {
deleteURL += `&conversation_id=${conversationId}`;
}
fetch(deleteURL , { method: "DELETE" })
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => {
chatBody.innerHTML = "";
chatBody.dataset.conversationId = "";
chatBody.dataset.conversationTitle = "";
loadChat();
flashStatusInChatInput("🗑 Cleared conversation history");
})
@@ -739,6 +940,14 @@ To get started, just start typing below. You can also type / to see a list of co
// Stop the countdown timer UI
document.getElementById('countdown-circle').style.animation = "none";
};
function handleCollapseSidePanel() {
document.getElementById('side-panel').classList.toggle('collapsed');
document.getElementById('new-conversation').classList.toggle('collapsed');
document.getElementById('existing-conversations').classList.toggle('collapsed');
document.getElementById('chat-section-wrapper').classList.toggle('mobile-friendly');
}
</script>
<body>
<div id="khoj-empty-container" class="khoj-empty-container">
@@ -747,53 +956,89 @@ To get started, just start typing below. You can also type / to see a list of co
<!--Add Header Logo and Nav Pane-->
{% import 'utils.html' as utils %}
{{ utils.heading_pane(user_photo, username, is_active, has_documents) }}
<div id="chat-section-wrapper">
<div id="side-panel-wrapper">
<div id="side-panel">
<div id="new-conversation">
<button class="side-panel-button" id="new-conversation-button" onclick="createNewConversation()">
New Topic
<svg class="new-convo-button" viewBox="0 0 35 35" fill="#000000" viewBox="0 0 32 32" version="1.1" xmlns="http://www.w3.org/2000/svg">
<path d="M16 0c-8.836 0-16 7.163-16 16s7.163 16 16 16c8.837 0 16-7.163 16-16s-7.163-16-16-16zM16 30.032c-7.72 0-14-6.312-14-14.032s6.28-14 14-14 14 6.28 14 14-6.28 14.032-14 14.032zM23 15h-6v-6c0-0.552-0.448-1-1-1s-1 0.448-1 1v6h-6c-0.552 0-1 0.448-1 1s0.448 1 1 1h6v6c0 0.552 0.448 1 1 1s1-0.448 1-1v-6h6c0.552 0 1-0.448 1-1s-0.448-1-1-1z"></path>
</svg>
</button>
</div>
<div id="existing-conversations">
<div id="conversation-list">
<div id="conversation-list-header" style="display: none;">Recent Conversations</div>
<div id="conversation-list-body"></div>
</div>
</div>
</div>
<div id="collapse-side-panel">
<button
class="side-panel-button"
id="collapse-side-panel-button"
onclick="handleCollapseSidePanel()"
>
<svg class="side-panel-collapse" viewBox="0 0 25 25" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7.82054 20.7313C8.21107 21.1218 8.84423 21.1218 9.23476 20.7313L15.8792 14.0868C17.0505 12.9155 17.0508 11.0167 15.88 9.84497L9.3097 3.26958C8.91918 2.87905 8.28601 2.87905 7.89549 3.26958C7.50497 3.6601 7.50497 4.29327 7.89549 4.68379L14.4675 11.2558C14.8581 11.6464 14.8581 12.2795 14.4675 12.67L7.82054 19.317C7.43002 19.7076 7.43002 20.3407 7.82054 20.7313Z" fill="#0F0F0F"/>
</svg>
</button>
</div>
</div>
<div id="chat-body-wrapper">
<!-- Chat Body -->
<div id="chat-body"></div>
<!-- Chat Body -->
<div id="chat-body"></div>
<!-- Chat Suggestions -->
<div id="question-starters" style="display: none;"></div>
<!-- Chat Suggestions -->
<div id="question-starters" style="display: none;"></div>
<!-- Chat Footer -->
<div id="chat-footer">
<div id="chat-tooltip" style="display: none;"></div>
<div id="input-row">
<button id="clear-chat-button" class="input-row-button" onclick="clearConversationHistory()">
<svg class="input-row-button-img" alt="Clear Chat History" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 256 256">
<rect width="128" height="128" fill="none"/>
<line x1="216" y1="56" x2="40" y2="56" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<line x1="104" y1="104" x2="104" y2="168" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<line x1="152" y1="104" x2="152" y2="168" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<path d="M200,56V208a8,8,0,0,1-8,8H64a8,8,0,0,1-8-8V56" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<path d="M168,56V40a16,16,0,0,0-16-16H104A16,16,0,0,0,88,40V56" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
</svg>
</button>
<textarea id="chat-input" class="option" oninput="onChatInput()" onkeydown=incrementalChat(event) autofocus="autofocus" placeholder="Type / to see a list of commands"></textarea>
<button id="speak-button" class="input-row-button"
ontouchstart="speechToText(event)" ontouchend="speechToText(event)" ontouchcancel="speechToText(event)" onmousedown="speechToText(event)">
<svg id="speak-button-img" class="input-row-button-img" alt="Transcribe" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M3.5 6.5A.5.5 0 0 1 4 7v1a4 4 0 0 0 8 0V7a.5.5 0 0 1 1 0v1a5 5 0 0 1-4.5 4.975V15h3a.5.5 0 0 1 0 1h-7a.5.5 0 0 1 0-1h3v-2.025A5 5 0 0 1 3 8V7a.5.5 0 0 1 .5-.5z"/>
<path d="M10 8a2 2 0 1 1-4 0V3a2 2 0 1 1 4 0v5zM8 0a3 3 0 0 0-3 3v5a3 3 0 0 0 6 0V3a3 3 0 0 0-3-3z"/>
</svg>
<svg id="stop-record-button-img" style="display: none" class="input-row-button-img" alt="Stop Transcribing" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M8 15A7 7 0 1 1 8 1a7 7 0 0 1 0 14zm0 1A8 8 0 1 0 8 0a8 8 0 0 0 0 16z"/>
<path d="M5 6.5A1.5 1.5 0 0 1 6.5 5h3A1.5 1.5 0 0 1 11 6.5v3A1.5 1.5 0 0 1 9.5 11h-3A1.5 1.5 0 0 1 5 9.5v-3z"/>
</svg>
</button>
<button id="send-button" class="input-row-button" alt="Send message">
<svg id="send-button-img" onclick="chat()" class="input-row-button-img" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M1 8a7 7 0 1 0 14 0A7 7 0 0 0 1 8zm15 0A8 8 0 1 1 0 8a8 8 0 0 1 16 0zm-7.5 3.5a.5.5 0 0 1-1 0V5.707L5.354 7.854a.5.5 0 1 1-.708-.708l3-3a.5.5 0 0 1 .708 0l3 3a.5.5 0 0 1-.708.708L8.5 5.707V11.5z"/>
</svg>
<svg id="stop-send-button-img" onclick="cancelSendMessage()" style="display: none" class="input-row-button-img" alt="Stop Message Send" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<circle id="countdown-circle" class="countdown-circle" cx="8" cy="8" r="7" />
<path d="M5 6.5A1.5 1.5 0 0 1 6.5 5h3A1.5 1.5 0 0 1 11 6.5v3A1.5 1.5 0 0 1 9.5 11h-3A1.5 1.5 0 0 1 5 9.5v-3z"/>
</svg>
</button>
<!-- Chat Footer -->
<div id="chat-footer">
<div id="chat-tooltip" style="display: none;"></div>
<div id="input-row">
<button id="clear-chat-button" class="input-row-button" onclick="clearConversationHistory()">
<svg class="input-row-button-img" alt="Clear Chat History" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 256 256">
<rect width="128" height="128" fill="none"/>
<line x1="216" y1="56" x2="40" y2="56" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<line x1="104" y1="104" x2="104" y2="168" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<line x1="152" y1="104" x2="152" y2="168" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<path d="M200,56V208a8,8,0,0,1-8,8H64a8,8,0,0,1-8-8V56" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
<path d="M168,56V40a16,16,0,0,0-16-16H104A16,16,0,0,0,88,40V56" fill="none" stroke="#000" stroke-linecap="round" stroke-linejoin="round" stroke-width="12"/>
</svg>
</button>
<textarea id="chat-input" class="option" oninput="onChatInput()" onkeydown=incrementalChat(event) autofocus="autofocus" placeholder="Type / to see a list of commands"></textarea>
<button id="speak-button" class="input-row-button"
ontouchstart="speechToText(event)" ontouchend="speechToText(event)" ontouchcancel="speechToText(event)" onmousedown="speechToText(event)">
<svg id="speak-button-img" class="input-row-button-img" alt="Transcribe" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M3.5 6.5A.5.5 0 0 1 4 7v1a4 4 0 0 0 8 0V7a.5.5 0 0 1 1 0v1a5 5 0 0 1-4.5 4.975V15h3a.5.5 0 0 1 0 1h-7a.5.5 0 0 1 0-1h3v-2.025A5 5 0 0 1 3 8V7a.5.5 0 0 1 .5-.5z"/>
<path d="M10 8a2 2 0 1 1-4 0V3a2 2 0 1 1 4 0v5zM8 0a3 3 0 0 0-3 3v5a3 3 0 0 0 6 0V3a3 3 0 0 0-3-3z"/>
</svg>
<svg id="stop-record-button-img" style="display: none" class="input-row-button-img" alt="Stop Transcribing" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path d="M8 15A7 7 0 1 1 8 1a7 7 0 0 1 0 14zm0 1A8 8 0 1 0 8 0a8 8 0 0 0 0 16z"/>
<path d="M5 6.5A1.5 1.5 0 0 1 6.5 5h3A1.5 1.5 0 0 1 11 6.5v3A1.5 1.5 0 0 1 9.5 11h-3A1.5 1.5 0 0 1 5 9.5v-3z"/>
</svg>
</button>
<button id="send-button" class="input-row-button" alt="Send message">
<svg id="send-button-img" onclick="chat()" class="input-row-button-img" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<path fill-rule="evenodd" d="M1 8a7 7 0 1 0 14 0A7 7 0 0 0 1 8zm15 0A8 8 0 1 1 0 8a8 8 0 0 1 16 0zm-7.5 3.5a.5.5 0 0 1-1 0V5.707L5.354 7.854a.5.5 0 1 1-.708-.708l3-3a.5.5 0 0 1 .708 0l3 3a.5.5 0 0 1-.708.708L8.5 5.707V11.5z"/>
</svg>
<svg id="stop-send-button-img" onclick="cancelSendMessage()" style="display: none" class="input-row-button-img" alt="Stop Message Send" xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="currentColor" viewBox="0 0 16 16">
<circle id="countdown-circle" class="countdown-circle" cx="8" cy="8" r="7" />
<path d="M5 6.5A1.5 1.5 0 0 1 6.5 5h3A1.5 1.5 0 0 1 11 6.5v3A1.5 1.5 0 0 1 9.5 11h-3A1.5 1.5 0 0 1 5 9.5v-3z"/>
</svg>
</button>
</div>
</div>
</div>
</div>
</body>
<script>
document.getElementById("chat-nav").classList.add("khoj-nav-selected");
// Set the active nav pane
let chatNav = document.getElementById("chat-nav");
if (chatNav) {
chatNav.classList.add("khoj-nav-selected");
}
</script>
<style>
html, body {
@@ -811,6 +1056,8 @@ To get started, just start typing below. You can also type / to see a list of co
font-size: 20px;
font-weight: 300;
line-height: 1.5em;
height: 100vh;
margin: 0;
}
body > * {
padding: 10px;
@@ -932,11 +1179,79 @@ To get started, just start typing below. You can also type / to see a list of co
line-height: 1.5em;
}
input.conversation-title-input {
font-family: var(--font-family);
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
padding: 5px;
border: 1px solid var(--main-text-color);
border-radius: 5px;
margin: 4px;
}
input.conversation-title-input:focus {
outline: none;
}
#chat-section-wrapper {
display: grid;
grid-template-columns: auto auto;
grid-column-gap: 10px;
grid-row-gap: 10px;
padding: 10px;
margin: 10px;
overflow-y: scroll;
}
#chat-section-wrapper.mobile-friendly {
grid-template-columns: auto auto;
}
#chat-body-wrapper {
display: flex;
flex-direction: column;
overflow: hidden;
}
#side-panel {
padding: 10px;
background: var(--background-color);
border-radius: 5px;
box-shadow: 0 0 11px #aaa;
overflow-y: scroll;
text-align: left;
transition: width 0.3s ease-in-out;
width: 250px;
}
div#side-panel.collapsed {
width: 1px;
display: block;
overflow: hidden;
}
div#collapse-side-panel {
align-self: center;
padding: 8px;
}
div#conversation-list-body {
display: grid;
grid-template-columns: 1fr;
grid-gap: 8px;
}
div#side-panel-wrapper {
display: flex
}
#chat-body {
font-size: medium;
margin: 0px;
line-height: 20px;
overflow-y: scroll; /* Make chat body scroll to see history */
overflow-y: scroll;
overflow-x: hidden;
}
/* add chat metatdata to bottom of bubble */
.chat-message::after {
@@ -944,7 +1259,7 @@ To get started, just start typing below. You can also type / to see a list of co
display: block;
font-size: x-small;
color: var(--main-text-color);
margin: -8px 4px 0 -5px;
margin: -8px 4px 0px 0px;
}
/* move message by khoj to left */
.chat-message.khoj {
@@ -1070,12 +1385,36 @@ To get started, just start typing below. You can also type / to see a list of co
margin-top: -2px;
margin-left: -5px;
}
.side-panel-button {
background: var(--background-color);
border: none;
box-shadow: none;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.3s ease-in-out;
border-radius: 5%;;
font-family: var(--font-family);
padding: 8px;
font-size: large;
}
svg.side-panel-collapse {
width: 30px;
height: 30px;
}
.side-panel-button:hover,
.input-row-button:hover {
background: var(--primary-hover);
}
.side-panel-button:active,
.input-row-button:active {
background: var(--primary-active);
}
.input-row-button-img {
width: 24px;
height: 24px;
@@ -1193,6 +1532,32 @@ To get started, just start typing below. You can also type / to see a list of co
#clear-chat-button {
margin-left: 0;
}
div#side-panel.collapsed {
width: 0px;
display: block;
overflow: hidden;
padding: 0;
}
svg.side-panel-collapse {
width: 24px;
height: 24px;
}
#chat-body-wrapper {
min-width: 0;
}
div#chat-section-wrapper {
padding: 4px;
margin: 4px;
grid-column-gap: 4px;
}
div#collapse-side-panel {
align-self: center;
padding: 0px;
}
}
@media only screen and (min-width: 700px) {
body {
@@ -1209,6 +1574,110 @@ To get started, just start typing below. You can also type / to see a list of co
font-size: medium;
}
svg.new-convo-button {
width: 20px;
margin-left: 5px;
}
div#new-conversation {
text-align: left;
border-bottom: 1px solid var(--main-text-color);
margin-bottom: 8px;
}
button#new-conversation-button {
display: inline-flex;
align-items: center;
}
div.conversation-button {
background: var(--background-color);
color: var(--main-text-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.2s ease-in-out;
text-align: left;
display: flex;
position: relative;
}
.three-dot-menu {
display: none;
/* background: var(--background-color); */
/* border: 1px solid var(--main-text-color); */
border-radius: 5px;
/* position: relative; */
position: absolute;
right: 4;
top: 4;
}
button.three-dot-menu-button-item {
background: var(--background-color);
color: var(--main-text-color);
border: none;
box-shadow: none;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.3s ease-in-out;
font-family: var(--font-family);
border-radius: 4px;
right: 0;
}
button.three-dot-menu-button-item:hover {
background: var(--primary-hover);
color: var(--primary-inverse);
}
.three-dot-menu-button {
background: var(--background-color);
border: none;
box-shadow: none;
font-size: 14px;
font-weight: 300;
line-height: 1.5em;
cursor: pointer;
transition: background 0.3s ease-in-out;
font-family: var(--font-family);
border-radius: 4px;
right: 0;
}
.conversation-button:hover .three-dot-menu {
display: block;
}
div.conversation-menu {
position: absolute;
z-index: 1;
top: 100%;
right: 0;
text-align: right;
background-color: var(--background-color);
border: 1px solid var(--main-text-color);
border-radius: 5px;
padding: 5px;
box-shadow: 0 0 11px #aaa;
}
div.conversation-button:hover {
background: var(--primary-hover);
color: var(--primary-inverse);
}
div.selected-conversation {
background: var(--primary-hover) !important;
color: var(--primary-inverse) !important;
}
@keyframes gradient {
0% {
background-position: 0% 50%;

View File

@@ -4,8 +4,8 @@
<img class="khoj-logo" src="/static/assets/icons/khoj-logo-sideways-500.png" alt="Khoj"></img>
</a>
<nav class="khoj-nav">
<a id="chat-nav" class="khoj-nav" href="/chat">💬 Chat</a>
{% if has_documents %}
<a id="chat-nav" class="khoj-nav" href="/chat">💬 Chat</a>
<a id="search-nav" class="khoj-nav" href="/search">🔎 Search</a>
{% endif %}
<!-- Dropdown Menu -->

View File

@@ -101,6 +101,7 @@ def save_to_conversation_log(
inferred_queries: List[str] = [],
intent_type: str = "remember",
client_application: ClientApplication = None,
conversation_id: int = None,
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@@ -114,7 +115,13 @@ def save_to_conversation_log(
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation}, client_application=client_application)
ConversationAdapters.save_conversation(
user,
{"chat": updated_conversation},
client_application=client_application,
conversation_id=conversation_id,
user_message=q,
)
def generate_chatml_messages_with_context(

View File

@@ -5,13 +5,12 @@ import math
import os
import time
import uuid
from typing import Any, Dict, List, Optional, Union
from urllib.parse import unquote
from typing import Any, List, Optional, Union
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from fastapi.responses import Response
from starlette.authentication import requires
from khoj.configure import configure_server
@@ -25,19 +24,11 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_off
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
ApiUserRateLimiter,
CommonQueryParams,
ConversationCommandRateLimiter,
agenerate_chat_response,
get_conversation_command,
is_ready_to_chat,
text_to_image,
update_telemetry_state,
validate_conversation_config,
)
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
@@ -45,14 +36,7 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import (
AsyncIteratorWrapper,
ConversationCommand,
command_descriptions,
get_device,
is_none_or_empty,
timer,
)
from khoj.utils.helpers import ConversationCommand, timer
from khoj.utils.rawconfig import SearchResponse
from khoj.utils.state import SearchType
@@ -222,81 +206,6 @@ def update(
return {"status": "ok", "message": "khoj reloaded"}
@api.get("/chat/starters", response_class=Response)
@requires(["authenticated"])
async def chat_starters(
request: Request,
common: CommonQueryParams,
) -> Response:
user: KhojUser = request.user.object
starter_questions = await ConversationAdapters.aget_conversation_starters(user)
return Response(content=json.dumps(starter_questions), media_type="application/json", status_code=200)
@api.get("/chat/history")
@requires(["authenticated"])
def chat_history(
request: Request,
common: CommonQueryParams,
):
user = request.user.object
validate_conversation_config()
# Load Conversation History
meta_log = ConversationAdapters.get_conversation_by_user(
user=user, client_application=request.user.client_app
).conversation_log
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
**common.__dict__,
)
return {"status": "ok", "response": meta_log.get("chat", [])}
@api.delete("/chat/history")
@requires(["authenticated"])
async def clear_chat_history(
request: Request,
common: CommonQueryParams,
):
user = request.user.object
# Clear Conversation History
await ConversationAdapters.adelete_conversation_by_user(user)
update_telemetry_state(
request=request,
telemetry_type="api",
api="clear_chat_history",
**common.__dict__,
)
return {"status": "ok", "message": "Conversation history cleared"}
@api.get("/chat/options", response_class=Response)
@requires(["authenticated"])
async def chat_options(
request: Request,
common: CommonQueryParams,
) -> Response:
cmd_options = {}
for cmd in ConversationCommand:
cmd_options[cmd.value] = command_descriptions[cmd]
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat_options",
**common.__dict__,
)
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
@api.post("/transcribe")
@requires(["authenticated"])
async def transcribe(
@@ -358,139 +267,6 @@ async def transcribe(
return Response(content=content, media_type="application/json", status_code=200)
@api.get("/chat", response_class=Response)
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app)).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
online_results: Dict = dict()
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
elif conversation_command == ConversationCommand.Help:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
elif conversation_command == ConversationCommand.Online:
try:
online_results = await search_with_google(defiltered_query, meta_log)
except ValueError as e:
return StreamingResponse(
iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
media_type="text/event-stream",
status_code=200,
)
elif conversation_command == ConversationCommand.Image:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_command.value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
if image is None:
content_obj = {
"image": image,
"intentType": "text-to-image",
"detail": "Failed to generate image. Make sure your image generation configuration is set.",
}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
intent_type="text-to-image",
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
)
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
compiled_references,
online_results,
inferred_queries,
conversation_command,
user,
request.user.client_app,
)
chat_metadata.update({"conversation_command": conversation_command.value})
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
**common.__dict__,
)
if llm_response is None:
return Response(content=llm_response, media_type="text/plain", status_code=500)
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
iterator = AsyncIteratorWrapper(llm_response)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
async for item in iterator:
if item is None:
break
aggregated_gpt_response += item
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
response_obj = {"response": actual_response, "context": compiled_references}
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
async def extract_references_and_questions(
request: Request,
common: CommonQueryParams,

View File

@@ -0,0 +1,346 @@
import json
import logging
import math
from typing import Dict, Optional
from urllib.parse import unquote
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ApiUserRateLimiter,
CommonQueryParams,
ConversationCommandRateLimiter,
agenerate_chat_response,
get_conversation_command,
is_ready_to_chat,
text_to_image,
update_telemetry_state,
validate_conversation_config,
)
from khoj.utils import state
from khoj.utils.helpers import (
AsyncIteratorWrapper,
ConversationCommand,
command_descriptions,
get_device,
is_none_or_empty,
)
# Initialize Router
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
)
api_chat = APIRouter()
@api_chat.get("/starters", response_class=Response)
@requires(["authenticated"])
async def chat_starters(
request: Request,
common: CommonQueryParams,
) -> Response:
user: KhojUser = request.user.object
starter_questions = await ConversationAdapters.aget_conversation_starters(user)
return Response(content=json.dumps(starter_questions), media_type="application/json", status_code=200)
@api_chat.get("/history")
@requires(["authenticated"])
def chat_history(
request: Request,
common: CommonQueryParams,
conversation_id: Optional[int] = None,
):
user = request.user.object
validate_conversation_config()
# Load Conversation History
conversation = ConversationAdapters.get_conversation_by_user(
user=user, client_application=request.user.client_app, conversation_id=conversation_id
)
meta_log = conversation.conversation_log
meta_log.update(
{"conversation_id": conversation.id, "slug": conversation.title if conversation.title else conversation.slug}
)
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
**common.__dict__,
)
return {"status": "ok", "response": meta_log}
@api_chat.delete("/history")
@requires(["authenticated"])
async def clear_chat_history(
request: Request,
common: CommonQueryParams,
conversation_id: Optional[int] = None,
):
user = request.user.object
# Clear Conversation History
await ConversationAdapters.adelete_conversation_by_user(user, request.user.client_app, conversation_id)
update_telemetry_state(
request=request,
telemetry_type="api",
api="clear_chat_history",
**common.__dict__,
)
return {"status": "ok", "message": "Conversation history cleared"}
@api_chat.get("/sessions")
@requires(["authenticated"])
def chat_sessions(
request: Request,
common: CommonQueryParams,
):
user = request.user.object
# Load Conversation Sessions
sessions = ConversationAdapters.get_conversation_sessions(user, request.user.client_app).values_list(
"id", "slug", "title"
)
session_values = [{"conversation_id": session[0], "slug": session[2] or session[1]} for session in sessions]
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat_sessions",
**common.__dict__,
)
return Response(content=json.dumps(session_values), media_type="application/json", status_code=200)
@api_chat.post("/sessions")
@requires(["authenticated"])
async def create_chat_session(
request: Request,
common: CommonQueryParams,
):
user = request.user.object
# Create new Conversation Session
conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app)
response = {"conversation_id": conversation.id}
update_telemetry_state(
request=request,
telemetry_type="api",
api="create_chat_sessions",
**common.__dict__,
)
return Response(content=json.dumps(response), media_type="application/json", status_code=200)
@api_chat.get("/options", response_class=Response)
@requires(["authenticated"])
async def chat_options(
request: Request,
common: CommonQueryParams,
) -> Response:
cmd_options = {}
for cmd in ConversationCommand:
cmd_options[cmd.value] = command_descriptions[cmd]
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat_options",
**common.__dict__,
)
return Response(content=json.dumps(cmd_options), media_type="application/json", status_code=200)
@api_chat.patch("/title", response_class=Response)
@requires(["authenticated"])
async def set_conversation_title(
request: Request,
common: CommonQueryParams,
title: str,
conversation_id: Optional[int] = None,
) -> Response:
user = request.user.object
title = title.strip()[:200]
# Set Conversation Title
conversation = await ConversationAdapters.aset_conversation_title(
user, request.user.client_app, conversation_id, title
)
success = True if conversation else False
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_conversation_title",
**common.__dict__,
)
return Response(
content=json.dumps({"status": "ok", "success": success}), media_type="application/json", status_code=200
)
@api_chat.get("", response_class=Response)
@requires(["authenticated"])
async def chat(
request: Request,
common: CommonQueryParams,
q: str,
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
slug: Optional[str] = None,
conversation_id: Optional[int] = None,
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip()
meta_log = (
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
).conversation_log
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, common, meta_log, q, (n or 5), (d or math.inf), conversation_command
)
online_results: Dict = dict()
if conversation_command == ConversationCommand.Default and is_none_or_empty(compiled_references):
conversation_command = ConversationCommand.General
elif conversation_command == ConversationCommand.Help:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
elif conversation_command == ConversationCommand.Notes and not await EntryAdapters.auser_has_entries(user):
no_entries_found_format = no_entries_found.format()
if stream:
return StreamingResponse(iter([no_entries_found_format]), media_type="text/event-stream", status_code=200)
else:
response_obj = {"response": no_entries_found_format}
return Response(content=json.dumps(response_obj), media_type="text/plain", status_code=200)
elif conversation_command == ConversationCommand.Online:
try:
online_results = await search_with_google(defiltered_query, meta_log)
except ValueError as e:
return StreamingResponse(
iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
media_type="text/event-stream",
status_code=200,
)
elif conversation_command == ConversationCommand.Image:
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata={"conversation_command": conversation_command.value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(q, meta_log)
if image is None:
content_obj = {
"image": image,
"intentType": "text-to-image",
"detail": "Failed to generate image. Make sure your image generation configuration is set.",
}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
intent_type="text-to-image",
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
)
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response(
defiltered_query,
meta_log,
compiled_references,
online_results,
inferred_queries,
conversation_command,
user,
request.user.client_app,
conversation_id,
)
chat_metadata.update({"conversation_command": conversation_command.value})
update_telemetry_state(
request=request,
telemetry_type="api",
api="chat",
metadata=chat_metadata,
**common.__dict__,
)
if llm_response is None:
return Response(content=llm_response, media_type="text/plain", status_code=500)
if stream:
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
iterator = AsyncIteratorWrapper(llm_response)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
async for item in iterator:
if item is None:
break
aggregated_gpt_response += item
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
response_obj = {"response": actual_response, "context": compiled_references}
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)

View File

@@ -244,6 +244,7 @@ def generate_chat_response(
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
client_application: ClientApplication = None,
conversation_id: int = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@@ -261,6 +262,7 @@ def generate_chat_response(
online_results=online_results,
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=conversation_id,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user)

View File

@@ -157,6 +157,8 @@ def config_page(request: Request):
for search_model_option in search_model_options:
all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
current_search_model_option = adapters.get_user_search_model_or_default(user)
return templates.TemplateResponse(
"config.html",
context={
@@ -166,6 +168,7 @@ def config_page(request: Request):
"username": user.username,
"conversation_options": all_conversation_options,
"search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"user_photo": user_picture,
"billing_enabled": state.billing_enabled,