Improve Chat Page Load Perf, Offline Chat Perf and Miscellaneous Fixes (#703)

### Store Generated Images as WebP 
- 78bac4ae Add migration script to convert PNG to WebP references in database
- c6e84436 Update clients to support rendering webp images inline
- d21f22ff Store Khoj generated images as webp instead of png for faster loading

### Lazy Fetch Chat Messages to Improve Time, Data to First Render
This is especially helpful for long conversations with lots of images
- 128829c4 Render latest msgs on chat session load. Fetch, render rest as they near viewport
- 9e558577 Support getting latest N chat messages via chat history API

### Intelligently set Context Window of Offline Chat to Improve Performance
- 4977b551 Use offline chat prompt config to set context window of loaded chat model

### Fixes
- 148923c1 Fix to raise error on hitting rate limit during Github indexing
- b8bc6bee Always remove loading animation on Desktop app if can't login to server
- 38250705 Fix `get_user_photo` to only return photo, not user name from DB

### Miscellaneous Improvements
- 689202e0 Update recommended CMAKE flag to enable using CUDA on linux in Docs
- b820daf3 Makes logs less noisy
This commit is contained in:
Debanjum
2024-04-15 18:34:29 +05:30
committed by GitHub
23 changed files with 510 additions and 131 deletions

View File

@@ -134,7 +134,7 @@ python -m pip install khoj-assistant
# CPU
python -m pip install khoj-assistant
# NVIDIA (CUDA) GPU
CMAKE_ARGS="DLLAMA_CUBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
CMAKE_ARGS="-DLLAMA_CUDA=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
# AMD (ROCm) GPU
CMAKE_ARGS="-DLLAMA_HIPBLAS=on" FORCE_CMAKE=1 python -m pip install khoj-assistant
# Vulkan GPU

View File

@@ -78,6 +78,7 @@ dependencies = [
"phonenumbers == 8.13.27",
"markdownify ~= 0.11.6",
"websockets == 12.0",
"psutil >= 5.8.0",
]
dynamic = ["version"]
@@ -105,7 +106,6 @@ dev = [
"pytest-asyncio == 0.21.1",
"freezegun >= 1.2.0",
"factory-boy >= 3.2.1",
"psutil >= 5.8.0",
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",

View File

@@ -4,9 +4,9 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
<title>Khoj - Chat</title>
<link rel="stylesheet" href="./assets/khoj.css">
<link rel="icon" type="image/png" sizes="128x128" href="./assets/icons/favicon-128x128.png">
<link rel="manifest" href="/static/khoj.webmanifest">
<link rel="stylesheet" href="./assets/khoj.css">
</head>
<script type="text/javascript" src="./assets/markdown-it.min.js"></script>
<script src="./utils.js"></script>
@@ -130,7 +130,7 @@
return referenceButton;
}
function renderMessage(message, by, dt=null, annotations=null, raw=false) {
function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
let message_time = formatDate(dt ?? new Date());
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
let formattedMessage = formatHTMLMessage(message, raw);
@@ -153,10 +153,15 @@
// Append chat message div to chat body
let chatBody = document.getElementById("chat-body");
chatBody.appendChild(chatMessage);
// Scroll to bottom of chat-body element
chatBody.scrollTop = chatBody.scrollHeight;
if (renderType === "append") {
chatBody.appendChild(chatMessage);
// Scroll to bottom of chat-body element
chatBody.scrollTop = chatBody.scrollHeight;
} else if (renderType === "prepend") {
chatBody.insertBefore(chatMessage, chatBody.firstChild);
} else if (renderType === "return") {
return chatMessage;
}
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
@@ -207,6 +212,7 @@
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
// If no document or online context is provided, render the message as is
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
@@ -214,30 +220,29 @@
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
} else if (intentType === "text-to-image-v3") {
imageMarkdown = `![](data:image/webp;base64,${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
return renderMessage(imageMarkdown, by, dt, null, false, "return");
}
renderMessage(message, by, dt);
return;
return renderMessage(message, by, dt, null, false, "return");
}
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
return renderMessage(message, by, dt, null, false, "return");
}
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
renderMessage(message, by, dt);
return;
return renderMessage(message, by, dt, null, false, "return");
}
// If document or online context is provided, render the message with its references
let references = document.createElement('div');
let referenceExpandButton = document.createElement('button');
@@ -288,16 +293,17 @@
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
} else if (intentType === "text-to-image-v3") {
imageMarkdown = `![](data:image/webp;base64,${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt, references);
return;
return renderMessage(imageMarkdown, by, dt, references, false, "return");
}
renderMessage(message, by, dt, references);
return renderMessage(message, by, dt, references, false, "return");
}
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
@@ -509,6 +515,8 @@
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
rawResponse += `![${query}](${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image-v3") {
rawResponse += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
}
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
@@ -671,7 +679,7 @@
let firstRunSetupMessageRendered = false;
let chatBody = document.getElementById("chat-body");
chatBody.innerHTML = "";
let chatHistoryUrl = `/api/chat/history?client=desktop`;
let chatHistoryUrl = `${hostURL}/api/chat/history?client=desktop`;
if (chatBody.dataset.conversationId) {
chatHistoryUrl += `&conversation_id=${chatBody.dataset.conversationId}`;
}
@@ -683,7 +691,8 @@
loadingScreen.appendChild(yellowOrb);
chatBody.appendChild(loadingScreen);
fetch(`${hostURL}${chatHistoryUrl}`, { headers })
// Get the most recent 10 chat messages from conversation history
fetch(`${chatHistoryUrl}&n=10`, { headers })
.then(response => response.json())
.then(data => {
if (data.detail) {
@@ -703,11 +712,21 @@
chatBody.dataset.conversationId = response.conversation_id;
chatBody.dataset.conversationTitle = response.slug || `New conversation 🌱`;
const fullChatLog = response.chat || [];
// Create a new IntersectionObserver
let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
entries.forEach(entry => {
// If the element is in the viewport, fetch the remaining message and unobserve the element
if (entry.isIntersecting) {
fetchRemainingChatMessages(chatHistoryUrl);
observer.unobserve(entry.target);
}
});
}, {rootMargin: '0px 0px 0px 0px'});
fullChatLog.forEach(chat_log => {
const fullChatLog = response.chat || [];
fullChatLog.forEach((chat_log, index) => {
if (chat_log.message != null) {
renderMessageWithReference(
let messageElement = renderMessageWithReference(
chat_log.message,
chat_log.by,
chat_log.context,
@@ -715,10 +734,25 @@
chat_log.onlineContext,
chat_log.intent?.type,
chat_log.intent?.["inferred-queries"]);
chatBody.appendChild(messageElement);
// When the 4th oldest message is within viewing distance (~60% scrolled up)
// Fetch the remaining chat messages
if (index === 4) {
fetchRemainingMessagesObserver.observe(messageElement);
}
}
loadingScreen.style.height = chatBody.scrollHeight + 'px';
})
// Scroll to bottom of chat-body element
chatBody.scrollTop = chatBody.scrollHeight;
// Set height of chat-body element to the height of the chat-body-wrapper
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
chatBody.style.height = chatBodyWrapperHeight;
// Add fade out animation to loading screen and remove it after the animation ends
fadeOutLoadingAnimation(loadingScreen);
})
@@ -726,9 +760,9 @@
// If the server returns a 500 error with detail, render a setup hint.
if (!firstRunSetupMessageRendered) {
renderFirstRunSetupMessage();
fadeOutLoadingAnimation(loadingScreen);
}
return;
fadeOutLoadingAnimation(loadingScreen);
return;
});
await refreshChatSessionsPanel();
@@ -778,6 +812,65 @@
}
}
// Lazily fetch and render the older chat messages that were skipped on the
// initial page load (only the latest 10 messages are rendered up-front).
// A lightweight placeholder div is inserted per older message; the real
// message is only rendered once its placeholder scrolls near the viewport.
function fetchRemainingChatMessages(chatHistoryUrl) {
    // Create a new IntersectionObserver
    let observer = new IntersectionObserver((entries, observer) => {
        entries.forEach(entry => {
            // If the element is in the viewport, render the message and unobserve the element
            if (entry.isIntersecting) {
                // chat_log was stashed on the placeholder element when it was created below
                let chat_log = entry.target.chat_log;
                let messageElement = renderMessageWithReference(
                    chat_log.message,
                    chat_log.by,
                    chat_log.context,
                    new Date(chat_log.created),
                    chat_log.onlineContext,
                    chat_log.intent?.type,
                    chat_log.intent?.["inferred-queries"]
                );
                entry.target.replaceWith(messageElement);
                // Remove the observer after the element has been rendered
                observer.unobserve(entry.target);
            }
        });
    }, {rootMargin: '0px 0px 200px 0px'}); // Trigger when the element is within 200px of the viewport

    // Fetch remaining chat messages from conversation history
    // NOTE(review): n=-10 presumably means "all but the latest 10" — confirm against the chat history API.
    // NOTE(review): the initial history fetch passes auth `headers`; this request does not —
    // confirm servers requiring authentication accept it from the desktop client.
    fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
        .then(response => response.json())
        .then(data => {
            if (data.status != "ok") {
                throw new Error(data.message);
            }
            return data.response;
        })
        .then(response => {
            const fullChatLog = response.chat || [];
            let chatBody = document.getElementById("chat-body");
            // Walk newest-to-oldest so each prepend keeps chronological order
            fullChatLog
                .reverse()
                .forEach(chat_log => {
                    if (chat_log.message != null) {
                        // Create a new element for each chat log
                        let placeholder = document.createElement('div');
                        placeholder.chat_log = chat_log;
                        // Insert the message placeholder as the first child of chat body after the welcome message
                        chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
                        // Observe the element
                        placeholder.style.height = "20px";
                        observer.observe(placeholder);
                    }
                });
        })
        .catch(err => {
            console.log(err);
            return;
        });
}
function fadeOutLoadingAnimation(loadingScreen) {
let chatBody = document.getElementById("chat-body");
let chatBodyWrapper = document.getElementById("chat-body-wrapper");

View File

@@ -156,6 +156,8 @@ export class KhojChatModal extends Modal {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
} else if (intentType === "text-to-image-v3") {
imageMarkdown = `![](data:image/webp;base64,${message})`;
}
if (inferredQueries) {
imageMarkdown += "\n\n**Inferred Query**:";
@@ -429,6 +431,8 @@ export class KhojChatModal extends Modal {
responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
responseText += `![${query}](${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image-v3") {
responseText += `![${query}](data:image/webp;base64,${responseAsJson.image})`;
}
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {

View File

@@ -39,7 +39,7 @@ from khoj.routers.twilio import is_twilio_enabled
from khoj.utils import constants, state
from khoj.utils.config import SearchType
from khoj.utils.fs_syncer import collect_files
from khoj.utils.helpers import is_none_or_empty
from khoj.utils.helpers import is_none_or_empty, telemetry_disabled
from khoj.utils.rawconfig import FullConfig
logger = logging.getLogger(__name__)
@@ -232,6 +232,9 @@ def configure_server(
state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent()
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
logger.info(message)
if not init:
initialize_content(regenerate, search_type, user)
@@ -329,9 +332,7 @@ def configure_search_types():
@schedule.repeat(schedule.every(2).minutes)
def upload_telemetry():
if not state.config or not state.config.app or not state.config.app.should_log_telemetry or not state.telemetry:
message = "📡 No telemetry to upload" if not state.telemetry else "📡 Telemetry logging disabled"
logger.debug(message)
if telemetry_disabled(state.config.app) or not state.telemetry:
return
try:

View File

@@ -197,9 +197,6 @@ def get_user_name(user: KhojUser):
def get_user_photo(user: KhojUser):
full_name = user.get_full_name()
if not is_none_or_empty(full_name):
return full_name
google_profile: GoogleUser = GoogleUser.objects.filter(user=user).first()
if google_profile:
return google_profile.picture

View File

@@ -23,6 +23,7 @@ from khoj.database.models import (
TextToImageModelConfig,
UserSearchModelConfig,
)
from khoj.utils.helpers import ImageIntentType
class KhojUserAdmin(UserAdmin):
@@ -114,9 +115,12 @@ class ConversationAdmin(admin.ModelAdmin):
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
and log["intent"]["type"] == "text-to-image"
and (
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
)
):
log["message"] = "image redacted for space"
log["message"] = "inline image redacted for space"
chat_log[idx] = log
modified_log["chat"] = chat_log
@@ -154,9 +158,12 @@ class ConversationAdmin(admin.ModelAdmin):
log["by"] == "khoj"
and log["intent"]
and log["intent"]["type"]
and log["intent"]["type"] == "text-to-image"
and (
log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE.value
or log["intent"]["type"] == ImageIntentType.TEXT_TO_IMAGE_V3.value
)
):
updated_log["message"] = "image redacted for space"
updated_log["message"] = "inline image redacted for space"
chat_log[idx] = updated_log
return_log["chat"] = chat_log

View File

View File

@@ -0,0 +1,40 @@
from django.core.management.base import BaseCommand
from khoj.database.models import Conversation
from khoj.utils.helpers import ImageIntentType
class Command(BaseCommand):
    help = "Convert all images to WebP format or reverse."

    def add_arguments(self, parser):
        # Add a new argument 'reverse' to the command
        parser.add_argument(
            "--reverse",
            action="store_true",
            help="Convert from WebP to PNG instead of PNG to WebP",
        )

    def handle(self, *args, **options):
        """Rewrite image file extensions in saved chat logs between PNG and WebP.

        Only khoj-generated messages with the TEXT_TO_IMAGE2 intent (image
        stored as a URL) are touched; the trailing extension of the message
        URL is swapped in place. Pass --reverse to convert WebP back to PNG.
        """
        updated_count = 0
        for conversation in Conversation.objects.all():
            conversation_updated = False
            for chat in conversation.conversation_log.get("chat", []):
                # Guard against khoj messages without an intent (e.g. plain text replies)
                intent_type = (chat.get("intent") or {}).get("type")
                if chat["by"] == "khoj" and intent_type == ImageIntentType.TEXT_TO_IMAGE2.value:
                    message = chat["message"]
                    if options["reverse"] and message.endswith(".webp"):
                        # Convert WebP url to PNG url. Swap only the trailing extension
                        # so a ".webp" elsewhere in the URL is left untouched.
                        chat["message"] = message[: -len(".webp")] + ".png"
                        conversation_updated = True
                        updated_count += 1
                    elif not options["reverse"] and message.endswith(".png"):
                        # Convert PNG url to WebP url. The `not reverse` guard prevents
                        # PNG urls from being converted to WebP while running in reverse mode.
                        chat["message"] = message[: -len(".png")] + ".webp"
                        conversation_updated = True
                        updated_count += 1
            if conversation_updated:
                conversation.save()

        if updated_count > 0 and options["reverse"]:
            self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} WebP images to PNG format."))
        elif updated_count > 0:
            self.stdout.write(self.style.SUCCESS(f"Successfully converted {updated_count} PNG images to WebP format."))

View File

@@ -0,0 +1,69 @@
# Generated by Django 4.2.10 on 2024-04-13 17:54
import base64
import io
from django.db import migrations
from PIL import Image
from khoj.utils.helpers import ImageIntentType
def convert_png_images_to_webp(apps, schema_editor):
    """Re-encode inline base64 PNG chat images as WebP.

    Khoj messages with the TEXT_TO_IMAGE intent store the image itself as a
    base64 string in the message body. Each is decoded, re-encoded as WebP,
    and the intent type is bumped to TEXT_TO_IMAGE_V3.
    """
    # Get the model from the versioned app registry to ensure the correct version is used
    Conversations = apps.get_model("database", "Conversation")
    for conversation in Conversations.objects.all():
        conversation_updated = False
        for chat in conversation.conversation_log.get("chat", []):
            # Guard against khoj messages without an intent (e.g. plain text replies),
            # which would otherwise raise KeyError and abort the migration
            intent_type = (chat.get("intent") or {}).get("type")
            if chat["by"] == "khoj" and intent_type == ImageIntentType.TEXT_TO_IMAGE.value:
                # Decode the base64 encoded PNG image
                decoded_image = base64.b64decode(chat["message"])

                # Convert image from PNG to WebP format
                image_io = io.BytesIO(decoded_image)
                with Image.open(image_io) as png_image:
                    webp_image_io = io.BytesIO()
                    png_image.save(webp_image_io, "WEBP")

                    # Encode the WebP image back to base64
                    chat["message"] = base64.b64encode(webp_image_io.getvalue()).decode()
                    chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE_V3.value
                    webp_image_io.close()
                conversation_updated = True

        # Save the updated conversation history; skip the write when nothing changed
        if conversation_updated:
            conversation.save()
def convert_webp_images_to_png(apps, schema_editor):
    """Reverse migration: re-encode inline base64 WebP chat images as PNG.

    Khoj messages with the TEXT_TO_IMAGE_V3 intent are decoded, re-encoded
    as PNG, and the intent type is restored to TEXT_TO_IMAGE.
    """
    # Get the model from the versioned app registry to ensure the correct version is used
    Conversations = apps.get_model("database", "Conversation")
    for conversation in Conversations.objects.all():
        conversation_updated = False
        for chat in conversation.conversation_log.get("chat", []):
            # Guard against khoj messages without an intent (e.g. plain text replies),
            # which would otherwise raise KeyError and abort the migration
            intent_type = (chat.get("intent") or {}).get("type")
            if chat["by"] == "khoj" and intent_type == ImageIntentType.TEXT_TO_IMAGE_V3.value:
                # Decode the base64 encoded WebP image
                decoded_image = base64.b64decode(chat["message"])

                # Convert image from WebP back to PNG format
                image_io = io.BytesIO(decoded_image)
                with Image.open(image_io) as webp_image:
                    png_image_io = io.BytesIO()
                    webp_image.save(png_image_io, "PNG")

                    # Encode the PNG image back to base64
                    chat["message"] = base64.b64encode(png_image_io.getvalue()).decode()
                    chat["intent"]["type"] = ImageIntentType.TEXT_TO_IMAGE.value
                    png_image_io.close()
                conversation_updated = True

        # Save the updated conversation history; skip the write when nothing changed
        if conversation_updated:
            conversation.save()
class Migration(migrations.Migration):
    # Data migration: re-encodes inline chat images stored as base64 in
    # conversation logs. Forward pass converts PNG -> WebP (and bumps the
    # image intent to v3); reverse_code restores PNG.

    dependencies = [
        ("database", "0034_alter_chatmodeloptions_chat_model"),
    ]

    operations = [
        migrations.RunPython(convert_png_images_to_webp, reverse_code=convert_webp_images_to_png),
    ]

View File

@@ -4,10 +4,10 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0 maximum-scale=1.0">
<title>Khoj - Chat</title>
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
<link rel="icon" type="image/png" sizes="128x128" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
<link rel="apple-touch-icon" href="/static/assets/icons/favicon-128x128.png?v={{ khoj_version }}">
<link rel="manifest" href="/static/khoj.webmanifest?v={{ khoj_version }}">
<link rel="stylesheet" href="/static/assets/khoj.css?v={{ khoj_version }}">
</head>
<script type="text/javascript" src="/static/assets/utils.js?v={{ khoj_version }}"></script>
<script type="text/javascript" src="/static/assets/markdown-it.min.js?v={{ khoj_version }}"></script>
@@ -160,7 +160,7 @@ To get started, just start typing below. You can also type / to see a list of co
return referenceButton;
}
function renderMessage(message, by, dt=null, annotations=null, raw=false) {
function renderMessage(message, by, dt=null, annotations=null, raw=false, renderType="append") {
let message_time = formatDate(dt ?? new Date());
let by_name = by == "khoj" ? "🏮 Khoj" : "🤔 You";
let formattedMessage = formatHTMLMessage(message, raw);
@@ -183,10 +183,16 @@ To get started, just start typing below. You can also type / to see a list of co
// Append chat message div to chat body
let chatBody = document.getElementById("chat-body");
chatBody.appendChild(chatMessage);
// Scroll to bottom of chat-body element
chatBody.scrollTop = chatBody.scrollHeight;
if (renderType === "append") {
chatBody.appendChild(chatMessage);
// Scroll to bottom of chat-body element
chatBody.scrollTop = chatBody.scrollHeight;
} else if (renderType === "prepend"){
let chatBody = document.getElementById("chat-body");
chatBody.insertBefore(chatMessage, chatBody.firstChild);
} else if (renderType === "return") {
return chatMessage;
}
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
@@ -237,6 +243,7 @@ To get started, just start typing below. You can also type / to see a list of co
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
// If no document or online context is provided, render the message as is
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
@@ -244,24 +251,24 @@ To get started, just start typing below. You can also type / to see a list of co
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
} else if (intentType === "text-to-image-v3") {
imageMarkdown = `![](data:image/webp;base64,${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt);
return;
return renderMessage(imageMarkdown, by, dt, null, false, "return");
}
renderMessage(message, by, dt);
return;
return renderMessage(message, by, dt, null, false, "return");
}
if ((context && context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
renderMessage(message, by, dt);
return;
return renderMessage(message, by, dt, null, false, "return");
}
// If document or online context is provided, render the message with its references
let references = document.createElement('div');
let referenceExpandButton = document.createElement('button');
@@ -312,16 +319,17 @@ To get started, just start typing below. You can also type / to see a list of co
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
} else if (intentType === "text-to-image-v3") {
imageMarkdown = `![](data:image/webp;base64,${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
renderMessage(imageMarkdown, by, dt, references);
return;
return renderMessage(imageMarkdown, by, dt, references, false, "return");
}
renderMessage(message, by, dt, references);
return renderMessage(message, by, dt, references, false, "return");
}
function formatHTMLMessage(htmlMessage, raw=false, willReplace=true) {
@@ -619,6 +627,8 @@ To get started, just start typing below. You can also type / to see a list of co
rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
} else if (imageJson.intentType === "text-to-image2") {
rawResponse += `![generated_image](${imageJson.image})`;
} else if (imageJson.intentType === "text-to-image-v3") {
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
}
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -1064,7 +1074,8 @@ To get started, just start typing below. You can also type / to see a list of co
loadingScreen.appendChild(yellowOrb);
chatBody.appendChild(loadingScreen);
fetch(chatHistoryUrl, { method: "GET" })
// Get the most recent 10 chat messages from conversation history
fetch(`${chatHistoryUrl}&n=10`, { method: "GET" })
.then(response => response.json())
.then(data => {
if (data.detail) {
@@ -1134,11 +1145,22 @@ To get started, just start typing below. You can also type / to see a list of co
agentMetadataElement.style.display = "none";
}
const fullChatLog = response.chat || [];
// Create a new IntersectionObserver
let fetchRemainingMessagesObserver = new IntersectionObserver((entries, observer) => {
entries.forEach(entry => {
// If the element is in the viewport, fetch the remaining message and unobserve the element
if (entry.isIntersecting) {
fetchRemainingChatMessages(chatHistoryUrl);
observer.unobserve(entry.target);
}
});
}, {rootMargin: '0px 0px 0px 0px'});
fullChatLog.forEach(chat_log => {
if (chat_log.message != null){
renderMessageWithReference(
const fullChatLog = response.chat || [];
fullChatLog.forEach((chat_log, index) => {
// Render the last 10 messages immediately
if (chat_log.message != null) {
let messageElement = renderMessageWithReference(
chat_log.message,
chat_log.by,
chat_log.context,
@@ -1146,14 +1168,26 @@ To get started, just start typing below. You can also type / to see a list of co
chat_log.onlineContext,
chat_log.intent?.type,
chat_log.intent?.["inferred-queries"]);
chatBody.appendChild(messageElement);
// When the 4th oldest message is within viewing distance (~60% scroll up)
// Fetch the remaining chat messages
if (index === 4) {
fetchRemainingMessagesObserver.observe(messageElement);
}
}
loadingScreen.style.height = chatBody.scrollHeight + 'px';
});
// Add fade out animation to loading screen and remove it after the animation ends
// Scroll to bottom of chat-body element
chatBody.scrollTop = chatBody.scrollHeight;
// Set height of chat-body element to the height of the chat-body-wrapper
let chatBodyWrapper = document.getElementById("chat-body-wrapper");
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
let chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
chatBody.style.height = chatBodyWrapperHeight;
// Add fade out animation to loading screen and remove it after the animation ends
setTimeout(() => {
loadingScreen.remove();
chatBody.classList.remove("relative-position");
@@ -1211,6 +1245,66 @@ To get started, just start typing below. You can also type / to see a list of co
document.getElementById("chat-input").value = query_via_url;
chat();
}
}
// Lazily pull in the older chat messages skipped on the initial page load
// (only the latest 10 are rendered up-front). A lightweight placeholder div
// is inserted per older message; the real message is rendered only once its
// placeholder nears the viewport.
function fetchRemainingChatMessages(chatHistoryUrl) {
    // Render a placeholder's message once it comes within 200px of the viewport
    let lazyRenderObserver = new IntersectionObserver((entries, observer) => {
        for (let entry of entries) {
            if (!entry.isIntersecting) continue;
            // The chat log entry was stashed on the placeholder when it was created below
            const chat_log = entry.target.chat_log;
            const messageElement = renderMessageWithReference(
                chat_log.message,
                chat_log.by,
                chat_log.context,
                new Date(chat_log.created),
                chat_log.onlineContext,
                chat_log.intent?.type,
                chat_log.intent?.["inferred-queries"]
            );
            entry.target.replaceWith(messageElement);
            // Placeholder has been swapped for the rendered message; stop watching it
            observer.unobserve(entry.target);
        }
    }, {rootMargin: '0px 0px 200px 0px'});

    // Fetch the remaining (older) chat messages from the conversation history
    fetch(`${chatHistoryUrl}&n=-10`, { method: "GET" })
        .then(response => response.json())
        .then(data => {
            if (data.status != "ok") {
                throw new Error(data.message);
            }
            return data.response;
        })
        .then(response => {
            const olderChatLog = response.chat || [];
            const chatBody = document.getElementById("chat-body");
            // Walk newest-to-oldest so each prepend keeps chronological order
            olderChatLog.reverse();
            for (let chat_log of olderChatLog) {
                if (chat_log.message == null) continue;
                // Stash the log entry on a placeholder div for the observer callback
                const placeholder = document.createElement('div');
                placeholder.chat_log = chat_log;
                // Keep the welcome message first; insert the placeholder right after it
                chatBody.insertBefore(placeholder, chatBody.firstChild.nextSibling);
                placeholder.style.height = "20px";
                lazyRenderObserver.observe(placeholder);
            }
        })
        .catch(err => {
            console.log(err);
            return;
        });
}
function flashStatusInChatInput(message) {

View File

@@ -69,6 +69,7 @@ class GithubToEntries(TextToEntries):
markdown_files, org_files, plaintext_files = self.get_files(repo_url, repo)
except ConnectionAbortedError as e:
logger.error(f"Github rate limit reached. Skip indexing github repo {repo_shorthand}")
raise e
except Exception as e:
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
raise e

View File

@@ -100,7 +100,7 @@ class NotionToEntries(TextToEntries):
for response in responses:
with timer("Processing response", logger=logger):
pages_or_databases = response["results"] if response.get("results") else []
pages_or_databases = response.get("results", [])
# Get all pages content
for p_or_d in pages_or_databases:
@@ -125,7 +125,7 @@ class NotionToEntries(TextToEntries):
current_entries = []
curr_heading = ""
for block in content["results"]:
for block in content.get("results", []):
block_type = block.get("type")
if block_type == None:
@@ -178,7 +178,7 @@ class NotionToEntries(TextToEntries):
return f"\n<b>{heading}</b>\n"
def process_nested_children(self, children, raw_content, block_type=None):
results = children["results"] if children.get("results") else []
results = children.get("results", [])
for child in results:
child_type = child.get("type")
if child_type == None:

View File

@@ -30,6 +30,7 @@ def extract_questions_offline(
use_history: bool = True,
should_extract_questions: bool = True,
location_data: LocationData = None,
max_prompt_size: int = None,
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -41,7 +42,7 @@ def extract_questions_offline(
return all_questions
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@@ -67,12 +68,14 @@ def extract_questions_offline(
location=location,
)
messages = generate_chatml_messages_with_context(
example_questions, model_name=model, loaded_model=offline_chat_model
example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
)
state.chat_lock.acquire()
try:
response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
response = send_message_to_model_offline(
messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
)
finally:
state.chat_lock.release()
@@ -138,7 +141,7 @@ def converse_offline(
"""
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
compiled_references_message = "\n\n".join({f"{item}" for item in references})
current_date = datetime.now().strftime("%Y-%m-%d")
@@ -190,18 +193,18 @@ def converse_offline(
)
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: Any):
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
stop_phrases = ["<s>", "INST]", "Notes:"]
state.chat_lock.acquire()
try:
response_iterator = send_message_to_model_offline(
messages, loaded_model=model, stop=stop_phrases, streaming=True
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
g.send(response["choices"][0]["delta"].get("content", ""))
@@ -216,9 +219,10 @@ def send_message_to_model_offline(
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
streaming=False,
stop=[],
max_prompt_size: int = None,
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
if streaming:

View File

@@ -1,18 +1,19 @@
import glob
import logging
import math
import os
from huggingface_hub.constants import HF_HUB_CACHE
from khoj.utils import state
from khoj.utils.helpers import get_device_memory
logger = logging.getLogger(__name__)
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
from llama_cpp.llama import Llama
# Initialize Model Parameters. Use n_ctx=0 to get context size from the model
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
# Initialize Model Parameters
# Use n_ctx=0 to get context size from the model
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
# Decide whether to load model to GPU or CPU
@@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
model_path = load_model_from_cache(repo_id, filename)
chat_model = None
try:
if model_path:
chat_model = Llama(model_path, **kwargs)
else:
Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
chat_model = load_model(model_path, repo_id, filename, kwargs)
except:
# Load model on CPU if GPU is not available
kwargs["n_gpu_layers"], device = 0, "cpu"
if model_path:
chat_model = Llama(model_path, **kwargs)
else:
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
chat_model = load_model(model_path, repo_id, filename, kwargs)
logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
# Now load the model with context size set based on:
# 1. context size supported by model and
# 2. configured size or machine (V)RAM
kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
chat_model = load_model(model_path, repo_id, filename, kwargs)
logger.debug(
f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
)
return chat_model
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = None):
    """Instantiate a llama.cpp chat model.

    Args:
        model_path: Local path to a GGUF model file. If set, load directly from disk.
        repo_id: HuggingFace repo id to download from when no local path is given.
        filename: Glob pattern selecting the model file within the repo.
        kwargs: Extra keyword arguments forwarded to the Llama constructor
            (e.g. n_ctx, n_gpu_layers). Defaults to no overrides.

    Returns:
        A loaded llama_cpp Llama instance.
    """
    from llama_cpp.llama import Llama

    # Avoid the mutable-default-argument pitfall: treat a missing dict as "no overrides"
    kwargs = kwargs or {}
    if model_path:
        # Load from the explicit local file path
        return Llama(model_path, **kwargs)
    # Otherwise download the model (or reuse the HuggingFace cache) by repo id
    return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
# Construct the path to the model file in the cache directory
repo_org, repo_name = repo_id.split("/")
@@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
return paths[0]
else:
return None
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
    """Infer max prompt size from device memory and the model's context window.

    Args:
        model_context_window: Max context size (in tokens) supported by the model.
        configured_max_tokens: User/admin configured prompt size limit. A falsy
            value (None or 0) means "not configured".

    Returns:
        The smaller of the configured limit and the model context window. When
        no limit is configured, fall back to a heuristic sized by device (V)RAM.
    """
    if configured_max_tokens:
        return min(configured_max_tokens, model_context_window)
    # Only query device memory when actually needed for the fallback heuristic:
    # allow roughly 1 token of context per MB of device memory
    vram_based_n_ctx = int(get_device_memory() / 1e6)
    return min(vram_based_n_ctx, model_context_window)

View File

@@ -1,5 +1,6 @@
import json
import logging
import math
import queue
from datetime import datetime
from time import perf_counter
@@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
tokenizer_name=None,
):
"""Generate messages for ChatGPT with context from previous conversation"""
# Set max prompt size from user config, pre-configured for model or to default prompt size
try:
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
except:
max_prompt_size = 2000
logger.warning(
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
)
# Set max prompt size from user config or based on pre-configured for model and machine specs
if not max_prompt_size:
if loaded_model:
max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
else:
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
# Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750
@@ -187,7 +186,7 @@ def truncate_messages(
max_prompt_size,
model_name: str,
loaded_model: Optional[Llama] = None,
tokenizer_name=None,
tokenizer_name="hf-internal-testing/llama-tokenizer",
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
@@ -197,15 +196,11 @@ def truncate_messages(
elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name)
else:
try:
encoder = download_model(model_name).tokenizer()
except:
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
encoder = download_model(model_name).tokenizer()
except:
default_tokenizer = "hf-internal-testing/llama-tokenizer"
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
logger.warning(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
# Extract system message from messages

View File

@@ -289,9 +289,7 @@ async def extract_references_and_questions(
return compiled_references, inferred_queries, q
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
logger.warning(
"No content index loaded, so cannot extract references from knowledge base. Please configure your data sources and update the index to chat with your notes."
)
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
return compiled_references, inferred_queries, q
# Extract filter terms from user message
@@ -317,8 +315,9 @@ async def extract_references_and_questions(
using_offline_chat = True
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
chat_model = default_offline_llm.chat_model
max_tokens = default_offline_llm.max_prompt_size
if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
@@ -328,6 +327,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
should_extract_questions=True,
location_data=location_data,
max_prompt_size=conversation_config.max_prompt_size,
)
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()

View File

@@ -76,6 +76,7 @@ def chat_history(
request: Request,
common: CommonQueryParams,
conversation_id: Optional[int] = None,
n: Optional[int] = None,
):
user = request.user.object
validate_conversation_config()
@@ -109,6 +110,13 @@ def chat_history(
}
)
# Get latest N messages if N > 0
if n > 0:
meta_log["chat"] = meta_log["chat"][-n:]
# Else return all messages except latest N
else:
meta_log["chat"] = meta_log["chat"][:n]
update_telemetry_state(
request=request,
telemetry_type="api",
@@ -425,8 +433,7 @@ async def websocket_endpoint(
api="chat",
metadata={"conversation_command": conversation_commands[0].value},
)
intent_type = "text-to-image"
image, status_code, improved_image_prompt, image_url = await text_to_image(
image, status_code, improved_image_prompt, intent_type = await text_to_image(
q,
user,
meta_log,
@@ -445,9 +452,6 @@ async def websocket_endpoint(
await send_complete_llm_response(json.dumps(content_obj))
continue
if image_url:
intent_type = "text-to-image2"
image = image_url
await sync_to_async(save_to_conversation_log)(
q,
image,
@@ -621,17 +625,13 @@ async def chat(
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
intent_type = "text-to-image"
image, status_code, improved_image_prompt, image_url = await text_to_image(
image, status_code, improved_image_prompt, intent_type = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
)
if image is None:
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
if image_url:
intent_type = "text-to-image2"
image = image_url
await sync_to_async(save_to_conversation_log)(
q,
image,

View File

@@ -1,4 +1,6 @@
import asyncio
import base64
import io
import json
import logging
from concurrent.futures import ThreadPoolExecutor
@@ -18,6 +20,7 @@ from typing import (
import openai
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from PIL import Image
from starlette.authentication import has_required_scope
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
@@ -46,6 +49,7 @@ from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import (
ConversationCommand,
ImageIntentType,
is_none_or_empty,
is_valid_url,
log_telemetry,
@@ -79,9 +83,10 @@ async def is_ready_to_chat(user: KhojUser):
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True
ready = has_openai_config or has_offline_config
@@ -382,10 +387,11 @@ async def send_message_to_model_wrapper(
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
@@ -452,7 +458,9 @@ def generate_chat_response(
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
@@ -508,18 +516,19 @@ async def text_to_image(
references: List[str],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
) -> Tuple[Optional[str], int, Optional[str], str]:
status_code = 200
image = None
response = None
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
return image, status_code, message, image_url
return image_url or image, status_code, message, intent_type.value
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name
@@ -550,21 +559,38 @@ async def text_to_image(
)
image = response.data[0].b64_json
with timer("Convert image to webp", logger):
# Convert png to webp for faster loading
decoded_image = base64.b64decode(image)
image_io = io.BytesIO(decoded_image)
png_image = Image.open(image_io)
webp_image_io = io.BytesIO()
png_image.save(webp_image_io, "WEBP")
webp_image_bytes = webp_image_io.getvalue()
webp_image_io.close()
image_io.close()
with timer("Upload image to S3", logger):
image_url = upload_image(image, user.uuid)
return image, status_code, improved_image_prompt, image_url
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")
return image_url or image, status_code, improved_image_prompt, intent_type.value
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image, status_code, message, image_url
return image_url or image, status_code, message, intent_type.value
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image, status_code, message, image_url
return image, status_code, response, image_url
return image_url or image, status_code, message, intent_type.value
return image_url or image, status_code, response, intent_type.value
class ApiUserRateLimiter:

View File

@@ -1,4 +1,3 @@
import base64
import logging
import os
import uuid
@@ -17,16 +16,15 @@ if aws_enabled:
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
def upload_image(image: str, user_id: uuid.UUID):
def upload_image(image: bytes, user_id: uuid.UUID):
"""Upload the image to the S3 bucket"""
if not aws_enabled:
logger.info("AWS is not enabled. Skipping image upload")
return None
decoded_image = base64.b64decode(image)
image_key = f"{user_id}/{uuid.uuid4()}.png"
image_key = f"{user_id}/{uuid.uuid4()}.webp"
try:
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=image, ACL="public-read")
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
return url
except Exception as e:

View File

@@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
class OfflineChatProcessorModel:
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
self.chat_model = chat_model
self.loaded_model = None
try:
self.loaded_model = download_model(self.chat_model)
self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)

View File

@@ -17,6 +17,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union
from urllib.parse import urlparse
import psutil
import torch
from asgiref.sync import sync_to_async
from magika import Magika
@@ -233,6 +234,10 @@ def get_server_id():
return server_id
def telemetry_disabled(app_config: AppConfig):
    """Return True when usage telemetry must not be logged: either no app
    config is present, or the config has telemetry logging switched off."""
    has_config = bool(app_config)
    return not has_config or not app_config.should_log_telemetry
def log_telemetry(
telemetry_type: str,
api: str = None,
@@ -242,7 +247,7 @@ def log_telemetry(
):
"""Log basic app usage telemetry like client, os, api called"""
# Do not log usage telemetry, if telemetry is disabled via app config
if not app_config or not app_config.should_log_telemetry:
if telemetry_disabled(app_config):
return []
if properties.get("server_id") is None:
@@ -267,6 +272,17 @@ def log_telemetry(
return request_body
def get_device_memory() -> int:
    # NOTE: all three branches below return a byte count, not GB, so the
    # docstring is corrected accordingly.
    """Get total device memory in bytes for the device models run on."""
    device = get_device()
    if device.type == "cuda":
        # Total VRAM of the CUDA device
        return torch.cuda.get_device_properties(device).total_memory
    elif device.type == "mps":
        # Memory currently allocated by the Metal driver on Apple silicon
        return torch.mps.driver_allocated_memory()
    else:
        # CPU fallback: total system RAM
        return psutil.virtual_memory().total
def get_device() -> torch.device:
"""Get device to run model on"""
if torch.cuda.is_available():
@@ -313,6 +329,20 @@ mode_descriptions_for_llm = {
}
class ImageIntentType(Enum):
    """
    Chat message intent by Khoj for image responses.
    Marks the schema used to reference the image in chat messages.
    """

    # Image embedded inline in the message as PNG (presumably base64 — confirm with clients)
    TEXT_TO_IMAGE = "text-to-image"
    # Image referenced by URL (e.g. after upload to object storage)
    TEXT_TO_IMAGE2 = "text-to-image2"
    # Image embedded inline in the message as WebP, for faster loading
    TEXT_TO_IMAGE_V3 = "text-to-image-v3"
def generate_random_name():
# List of adjectives and nouns to choose from
adjectives = [