mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-04 13:20:17 +00:00
Add Ability to Summarize Documents (#800)
* Uses entire file text and summarizer model to generate document summary. * Uses the contents of the user's query to create a tailored summary. * Integrates with File Filters #788 for a better UX.
This commit is contained in:
@@ -28,6 +28,7 @@ from khoj.database.models import (
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
Entry,
|
||||
FileObject,
|
||||
GithubConfig,
|
||||
GithubRepoConfig,
|
||||
GoogleUser,
|
||||
@@ -731,7 +732,7 @@ class ConversationAdapters:
|
||||
if server_chat_settings is None or (
|
||||
server_chat_settings.summarizer_model is None and server_chat_settings.default_model is None
|
||||
):
|
||||
return await ChatModelOptions.objects.filter().afirst()
|
||||
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
||||
return server_chat_settings.summarizer_model or server_chat_settings.default_model
|
||||
|
||||
@staticmethod
|
||||
@@ -846,6 +847,58 @@ class ConversationAdapters:
|
||||
return await TextToImageModelConfig.objects.filter().afirst()
|
||||
|
||||
|
||||
class FileObjectAdapters:
|
||||
@staticmethod
|
||||
def update_raw_text(file_object: FileObject, new_raw_text: str):
|
||||
file_object.raw_text = new_raw_text
|
||||
file_object.save()
|
||||
|
||||
@staticmethod
|
||||
def create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||
return FileObject.objects.create(user=user, file_name=file_name, raw_text=raw_text)
|
||||
|
||||
@staticmethod
|
||||
def get_file_objects_by_name(user: KhojUser, file_name: str):
|
||||
return FileObject.objects.filter(user=user, file_name=file_name).first()
|
||||
|
||||
@staticmethod
|
||||
def get_all_file_objects(user: KhojUser):
|
||||
return FileObject.objects.filter(user=user).all()
|
||||
|
||||
@staticmethod
|
||||
def delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||
return FileObject.objects.filter(user=user, file_name=file_name).delete()
|
||||
|
||||
@staticmethod
|
||||
def delete_all_file_objects(user: KhojUser):
|
||||
return FileObject.objects.filter(user=user).delete()
|
||||
|
||||
@staticmethod
|
||||
async def async_update_raw_text(file_object: FileObject, new_raw_text: str):
|
||||
file_object.raw_text = new_raw_text
|
||||
await file_object.asave()
|
||||
|
||||
@staticmethod
|
||||
async def async_create_file_object(user: KhojUser, file_name: str, raw_text: str):
|
||||
return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text)
|
||||
|
||||
@staticmethod
|
||||
async def async_get_file_objects_by_name(user: KhojUser, file_name: str):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name))
|
||||
|
||||
@staticmethod
|
||||
async def async_get_all_file_objects(user: KhojUser):
|
||||
return await sync_to_async(list)(FileObject.objects.filter(user=user))
|
||||
|
||||
@staticmethod
|
||||
async def async_delete_file_object_by_name(user: KhojUser, file_name: str):
|
||||
return await FileObject.objects.filter(user=user, file_name=file_name).adelete()
|
||||
|
||||
@staticmethod
|
||||
async def async_delete_all_file_objects(user: KhojUser):
|
||||
return await FileObject.objects.filter(user=user).adelete()
|
||||
|
||||
|
||||
class EntryAdapters:
|
||||
word_filer = WordFilter()
|
||||
file_filter = FileFilter()
|
||||
|
||||
37
src/khoj/database/migrations/0045_fileobject.py
Normal file
37
src/khoj/database/migrations/0045_fileobject.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Generated by Django 4.2.11 on 2024-06-14 06:13
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0044_conversation_file_filters"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="FileObject",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("file_name", models.CharField(blank=True, default=None, max_length=400, null=True)),
|
||||
("raw_text", models.TextField()),
|
||||
(
|
||||
"user",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
||||
@@ -326,6 +326,13 @@ class Entry(BaseModel):
|
||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||
|
||||
|
||||
class FileObject(BaseModel):
|
||||
# Same as Entry but raw will be a much larger string
|
||||
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
|
||||
raw_text = models.TextField()
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class EntryDates(BaseModel):
|
||||
date = models.DateField()
|
||||
entry = models.ForeignKey(Entry, on_delete=models.CASCADE, related_name="embeddings_dates")
|
||||
|
||||
@@ -2145,7 +2145,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
||||
<div style="border-top: 1px solid black; ">
|
||||
<div style="display: flex; align-items: center; justify-content: space-between; margin-bottom: 5px; margin-top: 5px;">
|
||||
<p style="margin: 0;">Files</p>
|
||||
<svg id="file-toggle-button" class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
|
||||
<svg class="file-toggle-button" style="width:20px; height:20px; position: relative; top: 2px" viewBox="0 0 40 40" fill="#000000" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M16 0c-8.836 0-16 7.163-16 16s7.163 16 16 16c8.837 0 16-7.163 16-16s-7.163-16-16-16zM16 30.032c-7.72 0-14-6.312-14-14.032s6.28-14 14-14 14 6.28 14 14-6.28 14.032-14 14.032zM23 15h-6v-6c0-0.552-0.448-1-1-1s-1 0.448-1 1v6h-6c-0.552 0-1 0.448-1 1s0.448 1 1 1h6v6c0 0.552 0.448 1 1 1s1-0.448 1-1v-6h6c0.552 0 1-0.448 1-1s-0.448-1-1-1z"></path>
|
||||
</svg>
|
||||
</div>
|
||||
|
||||
@@ -33,7 +33,7 @@ class MarkdownToEntries(TextToEntries):
|
||||
max_tokens = 256
|
||||
# Extract Entries from specified Markdown files
|
||||
with timer("Extract entries from specified Markdown files", logger):
|
||||
current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
|
||||
file_to_text_map, current_entries = MarkdownToEntries.extract_markdown_entries(files, max_tokens)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
@@ -50,27 +50,30 @@ class MarkdownToEntries(TextToEntries):
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_markdown_entries(markdown_files, max_tokens=256) -> List[Entry]:
|
||||
def extract_markdown_entries(markdown_files, max_tokens=256) -> Tuple[Dict, List[Entry]]:
|
||||
"Extract entries by heading from specified Markdown files"
|
||||
entries: List[str] = []
|
||||
entry_to_file_map: List[Tuple[str, str]] = []
|
||||
file_to_text_map = dict()
|
||||
for markdown_file in markdown_files:
|
||||
try:
|
||||
markdown_content = markdown_files[markdown_file]
|
||||
entries, entry_to_file_map = MarkdownToEntries.process_single_markdown_file(
|
||||
markdown_content, markdown_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
file_to_text_map[markdown_file] = markdown_content
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unable to process file: {markdown_file}. This file will not be indexed.\n{e}", exc_info=True
|
||||
)
|
||||
|
||||
return MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
|
||||
return file_to_text_map, MarkdownToEntries.convert_markdown_entries_to_maps(entries, dict(entry_to_file_map))
|
||||
|
||||
@staticmethod
|
||||
def process_single_markdown_file(
|
||||
|
||||
@@ -33,7 +33,7 @@ class OrgToEntries(TextToEntries):
|
||||
# Extract Entries from specified Org files
|
||||
max_tokens = 256
|
||||
with timer("Extract entries from specified Org files", logger):
|
||||
current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
|
||||
file_to_text_map, current_entries = self.extract_org_entries(files, max_tokens=max_tokens)
|
||||
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
current_entries = self.split_entries_by_max_tokens(current_entries, max_tokens=max_tokens)
|
||||
@@ -49,6 +49,7 @@ class OrgToEntries(TextToEntries):
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
@@ -56,26 +57,32 @@ class OrgToEntries(TextToEntries):
|
||||
@staticmethod
|
||||
def extract_org_entries(
|
||||
org_files: dict[str, str], index_heading_entries: bool = False, max_tokens=256
|
||||
) -> List[Entry]:
|
||||
) -> Tuple[Dict, List[Entry]]:
|
||||
"Extract entries from specified Org files"
|
||||
entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
|
||||
return OrgToEntries.convert_org_nodes_to_entries(entries, entry_to_file_map, index_heading_entries)
|
||||
file_to_text_map, entries, entry_to_file_map = OrgToEntries.extract_org_nodes(org_files, max_tokens)
|
||||
return file_to_text_map, OrgToEntries.convert_org_nodes_to_entries(
|
||||
entries, entry_to_file_map, index_heading_entries
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def extract_org_nodes(org_files: dict[str, str], max_tokens) -> Tuple[List[List[Orgnode]], Dict[Orgnode, str]]:
|
||||
def extract_org_nodes(
|
||||
org_files: dict[str, str], max_tokens
|
||||
) -> Tuple[Dict, List[List[Orgnode]], Dict[Orgnode, str]]:
|
||||
"Extract org nodes from specified org files"
|
||||
entries: List[List[Orgnode]] = []
|
||||
entry_to_file_map: List[Tuple[Orgnode, str]] = []
|
||||
file_to_text_map = {}
|
||||
for org_file in org_files:
|
||||
try:
|
||||
org_content = org_files[org_file]
|
||||
entries, entry_to_file_map = OrgToEntries.process_single_org_file(
|
||||
org_content, org_file, entries, entry_to_file_map, max_tokens
|
||||
)
|
||||
file_to_text_map[org_file] = org_content
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to process file: {org_file}. Skipped indexing it.\nError; {e}", exc_info=True)
|
||||
|
||||
return entries, dict(entry_to_file_map)
|
||||
return file_to_text_map, entries, dict(entry_to_file_map)
|
||||
|
||||
@staticmethod
|
||||
def process_single_org_file(
|
||||
|
||||
@@ -2,10 +2,12 @@ import base64
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from langchain_community.document_loaders import PyMuPDFLoader
|
||||
|
||||
# importing FileObjectAdapter so that we can add new files and debug file object db.
|
||||
# from khoj.database.adapters import FileObjectAdapters
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
@@ -33,7 +35,7 @@ class PdfToEntries(TextToEntries):
|
||||
|
||||
# Extract Entries from specified Pdf files
|
||||
with timer("Extract entries from specified PDF files", logger):
|
||||
current_entries = PdfToEntries.extract_pdf_entries(files)
|
||||
file_to_text_map, current_entries = PdfToEntries.extract_pdf_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
@@ -50,14 +52,15 @@ class PdfToEntries(TextToEntries):
|
||||
deletion_file_names,
|
||||
user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
|
||||
@staticmethod
|
||||
def extract_pdf_entries(pdf_files) -> List[Entry]:
|
||||
def extract_pdf_entries(pdf_files) -> Tuple[Dict, List[Entry]]: # important function
|
||||
"""Extract entries by page from specified PDF files"""
|
||||
|
||||
file_to_text_map = dict()
|
||||
entries: List[str] = []
|
||||
entry_to_location_map: List[Tuple[str, str]] = []
|
||||
for pdf_file in pdf_files:
|
||||
@@ -73,9 +76,14 @@ class PdfToEntries(TextToEntries):
|
||||
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
||||
except ImportError:
|
||||
loader = PyMuPDFLoader(f"{tmp_file}")
|
||||
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
||||
entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
|
||||
pdf_entries_per_file = [
|
||||
page.page_content for page in loader.load()
|
||||
] # page_content items list for a given pdf.
|
||||
entry_to_location_map += zip(
|
||||
pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
|
||||
) # this is an indexed map of pdf_entries for the pdf.
|
||||
entries.extend(pdf_entries_per_file)
|
||||
file_to_text_map[pdf_file] = pdf_entries_per_file
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
|
||||
logger.warning(e, exc_info=True)
|
||||
@@ -83,7 +91,7 @@ class PdfToEntries(TextToEntries):
|
||||
if os.path.exists(f"{tmp_file}"):
|
||||
os.remove(f"{tmp_file}")
|
||||
|
||||
return PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
|
||||
|
||||
@staticmethod
|
||||
def convert_pdf_entries_to_maps(parsed_entries: List[str], entry_to_file_map) -> List[Entry]:
|
||||
|
||||
@@ -32,7 +32,7 @@ class PlaintextToEntries(TextToEntries):
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
with timer("Extract entries from specified Plaintext files", logger):
|
||||
current_entries = PlaintextToEntries.extract_plaintext_entries(files)
|
||||
file_to_text_map, current_entries = PlaintextToEntries.extract_plaintext_entries(files)
|
||||
|
||||
# Split entries by max tokens supported by model
|
||||
with timer("Split entries by max token size supported by model", logger):
|
||||
@@ -49,6 +49,7 @@ class PlaintextToEntries(TextToEntries):
|
||||
deletion_filenames=deletion_file_names,
|
||||
user=user,
|
||||
regenerate=regenerate,
|
||||
file_to_text_map=file_to_text_map,
|
||||
)
|
||||
|
||||
return num_new_embeddings, num_deleted_embeddings
|
||||
@@ -63,21 +64,23 @@ class PlaintextToEntries(TextToEntries):
|
||||
return soup.get_text(strip=True, separator="\n")
|
||||
|
||||
@staticmethod
|
||||
def extract_plaintext_entries(text_files: Dict[str, str]) -> List[Entry]:
|
||||
def extract_plaintext_entries(text_files: Dict[str, str]) -> Tuple[Dict, List[Entry]]:
|
||||
entries: List[str] = []
|
||||
entry_to_file_map: List[Tuple[str, str]] = []
|
||||
file_to_text_map = dict()
|
||||
for text_file in text_files:
|
||||
try:
|
||||
text_content = text_files[text_file]
|
||||
entries, entry_to_file_map = PlaintextToEntries.process_single_plaintext_file(
|
||||
text_content, text_file, entries, entry_to_file_map
|
||||
)
|
||||
file_to_text_map[text_file] = text_content
|
||||
except Exception as e:
|
||||
logger.warning(f"Unable to read file: {text_file} as plaintext. Skipping file.")
|
||||
logger.warning(e, exc_info=True)
|
||||
|
||||
# Extract Entries from specified plaintext files
|
||||
return PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
|
||||
return file_to_text_map, PlaintextToEntries.convert_text_files_to_entries(entries, dict(entry_to_file_map))
|
||||
|
||||
@staticmethod
|
||||
def process_single_plaintext_file(
|
||||
|
||||
@@ -9,7 +9,11 @@ from typing import Any, Callable, List, Set, Tuple
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
get_user_search_model_or_default,
|
||||
)
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import EntryDates, KhojUser
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
@@ -120,6 +124,7 @@ class TextToEntries(ABC):
|
||||
deletion_filenames: Set[str] = None,
|
||||
user: KhojUser = None,
|
||||
regenerate: bool = False,
|
||||
file_to_text_map: dict[str, List[str]] = None,
|
||||
):
|
||||
with timer("Constructed current entry hashes in", logger):
|
||||
hashes_by_file = dict[str, set[str]]()
|
||||
@@ -186,6 +191,18 @@ class TextToEntries(ABC):
|
||||
logger.error(f"Error adding entries to database:\n{batch_indexing_error}\n---\n{e}", exc_info=True)
|
||||
logger.debug(f"Added {len(added_entries)} {file_type} entries to database")
|
||||
|
||||
if file_to_text_map:
|
||||
# get the list of file_names using added_entries
|
||||
filenames_to_update = [entry.file_path for entry in added_entries]
|
||||
# for each file_name in filenames_to_update, try getting the file object and updating raw_text and if it fails create a new file object
|
||||
for file_name in filenames_to_update:
|
||||
raw_text = " ".join(file_to_text_map[file_name])
|
||||
file_object = FileObjectAdapters.get_file_objects_by_name(user, file_name)
|
||||
if file_object:
|
||||
FileObjectAdapters.update_raw_text(file_object, raw_text)
|
||||
else:
|
||||
FileObjectAdapters.create_file_object(user, file_name, raw_text)
|
||||
|
||||
new_dates = []
|
||||
with timer("Indexed dates from added entries in", logger):
|
||||
for added_entry in added_entries:
|
||||
@@ -210,6 +227,7 @@ class TextToEntries(ABC):
|
||||
for file_path in deletion_filenames:
|
||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||
num_deleted_entries += deleted_count
|
||||
FileObjectAdapters.delete_file_object_by_name(user, file_path)
|
||||
|
||||
return len(added_entries), num_deleted_entries
|
||||
|
||||
|
||||
@@ -321,6 +321,27 @@ Collate only relevant information from the website to answer the target query.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
system_prompt_extract_relevant_summary = """As a professional analyst, create a comprehensive report of the most relevant information from the document in response to a user's query. The text provided is directly from within the document. The report you create should be multiple paragraphs, and it should represent the content of the document. Tell the user exactly what the document says in response to their query, while adhering to these guidelines:
|
||||
|
||||
1. Answer the user's query as specifically as possible. Include many supporting details from the document.
|
||||
2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
|
||||
3. Rely strictly on the provided text, without including external information.
|
||||
4. Format the report in multiple paragraphs with a clear structure.
|
||||
5. Be as specific as possible in your answer to the user's query.
|
||||
6. Reproduce as much of the provided text as possible, while maintaining readability.
|
||||
""".strip()
|
||||
|
||||
extract_relevant_summary = PromptTemplate.from_template(
|
||||
"""
|
||||
Target Query: {query}
|
||||
|
||||
Document Contents:
|
||||
{corpus}
|
||||
|
||||
Collate only relevant information from the document to answer the target query.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
pick_relevant_output_mode = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query. You have access to a limited set of modes for your response. You can only use one of these modes.
|
||||
|
||||
@@ -16,6 +16,7 @@ from websockets import ConnectionClosedOK
|
||||
from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
PublicConversationAdapters,
|
||||
aget_user_name,
|
||||
)
|
||||
@@ -42,6 +43,7 @@ from khoj.routers.helpers import (
|
||||
aget_relevant_output_modes,
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
extract_relevant_summary,
|
||||
get_conversation_command,
|
||||
is_query_empty,
|
||||
is_ready_to_chat,
|
||||
@@ -586,6 +588,51 @@ async def websocket_endpoint(
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
||||
if ConversationCommand.Summarize in conversation_commands:
|
||||
file_filters = conversation.file_filters
|
||||
response_log = ""
|
||||
if len(file_filters) == 0:
|
||||
response_log = "No files selected for summarization. Please add files using the section on the left."
|
||||
await send_complete_llm_response(response_log)
|
||||
elif len(file_filters) > 1:
|
||||
response_log = "Only one file can be selected for summarization."
|
||||
await send_complete_llm_response(response_log)
|
||||
else:
|
||||
try:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
if len(file_object) == 0:
|
||||
response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
await send_complete_llm_response(response_log)
|
||||
continue
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
if not q:
|
||||
q = "Create a general summary of the file"
|
||||
await send_status_update(f"**🧑🏾💻 Constructing Summary Using:** {file_object[0].file_name}")
|
||||
response = await extract_relevant_summary(q, contextual_data)
|
||||
response_log = str(response)
|
||||
await send_complete_llm_response(response_log)
|
||||
except Exception as e:
|
||||
response_log = "Error summarizing file."
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||
await send_complete_llm_response(response_log)
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
response_log,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
intent_type="summarize",
|
||||
client_application=websocket.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=websocket,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
)
|
||||
continue
|
||||
|
||||
custom_filters = []
|
||||
if conversation_commands == [ConversationCommand.Help]:
|
||||
if not q:
|
||||
@@ -828,6 +875,49 @@ async def chat(
|
||||
_custom_filters.append("site:khoj.dev")
|
||||
conversation_commands.append(ConversationCommand.Online)
|
||||
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(user, conversation_id=conversation_id)
|
||||
conversation_id = conversation.id if conversation else None
|
||||
if ConversationCommand.Summarize in conversation_commands:
|
||||
file_filters = conversation.file_filters
|
||||
llm_response = ""
|
||||
if len(file_filters) == 0:
|
||||
llm_response = "No files selected for summarization. Please add files using the section on the left."
|
||||
elif len(file_filters) > 1:
|
||||
llm_response = "Only one file can be selected for summarization."
|
||||
else:
|
||||
try:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
if len(file_object) == 0:
|
||||
llm_response = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
summarizeStr = "/" + ConversationCommand.Summarize
|
||||
if q.strip() == summarizeStr:
|
||||
q = "Create a general summary of the file"
|
||||
response = await extract_relevant_summary(q, contextual_data)
|
||||
llm_response = str(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}")
|
||||
llm_response = "Error summarizing file."
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
llm_response,
|
||||
user,
|
||||
conversation.conversation_log,
|
||||
user_message_time,
|
||||
intent_type="summarize",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="chat",
|
||||
metadata={"conversation_command": conversation_commands[0].value},
|
||||
**common.__dict__,
|
||||
)
|
||||
return StreamingResponse(content=llm_response, media_type="text/event-stream", status_code=200)
|
||||
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, request.user.client_app, conversation_id, title
|
||||
)
|
||||
|
||||
@@ -200,6 +200,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
||||
return ConversationCommand.Image
|
||||
elif query.startswith("/automated_task"):
|
||||
return ConversationCommand.AutomatedTask
|
||||
elif query.startswith("/summarize"):
|
||||
return ConversationCommand.Summarize
|
||||
# If no relevant notes found for the given query
|
||||
elif not any_references:
|
||||
return ConversationCommand.General
|
||||
@@ -418,7 +420,30 @@ async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
|
||||
prompts.system_prompt_extract_relevant_information,
|
||||
chat_model_option=summarizer_model,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
||||
async def extract_relevant_summary(q: str, corpus: str) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
"""
|
||||
|
||||
if is_none_or_empty(corpus) or is_none_or_empty(q):
|
||||
return None
|
||||
|
||||
extract_relevant_information = prompts.extract_relevant_summary.format(
|
||||
query=q,
|
||||
corpus=corpus.strip(),
|
||||
)
|
||||
|
||||
summarizer_model: ChatModelOptions = await ConversationAdapters.aget_summarizer_conversation_config()
|
||||
|
||||
with timer("Chat actor: Extract relevant information from data", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
extract_relevant_information,
|
||||
prompts.system_prompt_extract_relevant_summary,
|
||||
chat_model_option=summarizer_model,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
||||
|
||||
@@ -307,6 +307,7 @@ class ConversationCommand(str, Enum):
|
||||
Text = "text"
|
||||
Automation = "automation"
|
||||
AutomatedTask = "automated_task"
|
||||
Summarize = "summarize"
|
||||
|
||||
|
||||
command_descriptions = {
|
||||
@@ -318,6 +319,7 @@ command_descriptions = {
|
||||
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
||||
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
|
||||
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
|
||||
ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
|
||||
}
|
||||
|
||||
tool_descriptions_for_llm = {
|
||||
@@ -326,6 +328,7 @@ tool_descriptions_for_llm = {
|
||||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
|
||||
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
|
||||
ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
|
||||
Reference in New Issue
Block a user