Merge pull request #183 from debanjum/improve-chat-interface: Improve Khoj Chat

# Improve Khoj Chat
## Main Changes
- Use the new [API](https://openai.com/blog/introducing-chatgpt-and-whisper-apis) for [ChatGPT](https://openai.com/blog/chatgpt) to improve conversation quality and reduce cost
- Improve the prompt used to answer queries from indexed notes
  - Previously, GPT was asked to summarize the notes
  - Both the chat and answer APIs use this new prompt
- Support multi-turn conversations
  - Pass previous messages and associated reference notes to ChatGPT for context (see the sketch after this list)
- Show the note snippets referenced to generate each response
  - Allows fact-checking and getting details
- Simplify the chat interface by using a single, unified chat type for now
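
Multi-turn support works by replaying prior chat turns, along with the note snippets that grounded them, as ChatML messages. Below is a minimal sketch of the idea, assuming the conversation log shape used in this PR; `sketch_converse` is a hypothetical name, and the message ordering is simplified relative to `generate_chatml_messages_with_context` in the `gpt.py` diff below.

```python
import os

import openai


def sketch_converse(notes: str, user_query: str, chat_history: list) -> str:
    """Illustrative only: answer a query using retrieved notes plus prior turns."""
    openai.api_key = os.getenv("OPENAI_API_KEY")
    messages = [{"role": "system", "content": "You are a friendly, helpful personal assistant."}]
    # Replay earlier turns; each turn carries the note snippets that grounded it
    for turn in chat_history:  # turn: {"by": "you"|"khoj", "message": ..., "context": ...}
        role = "assistant" if turn["by"] == "khoj" else "user"
        messages.append({"role": role, "content": f'{turn["message"]}\n\nNotes:\n{turn.get("context", "")}'})
    # Ground the current question in freshly retrieved notes
    messages.append({"role": "user", "content": f"Notes:\n{notes}\n\nQuestion: {user_query}"})
    response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=messages, temperature=0)
    return response["choices"][0]["message"]["content"]
```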

## Miscellaneous
- Replace the summarize API with an answer API; summarization via the API isn't useful for now
- Only pass Khoj search results above a confidence threshold to GPT for context (sketched below)
  - Allows Khoj to say it doesn't know when it can't find an answer to the query in the notes
  - Allows relying on conversation history alone to generate a response in a multi-turn conversation
- Move the Chat API out of beta. Update the README
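
A minimal sketch of the confidence filtering, assuming hit dictionaries shaped like those in the `text_search.py` diff below (the `cross-score`/`score` keys come from the diff; the helper name is hypothetical):

```python
import math
from typing import Optional


def filter_hits_by_confidence(hits: list, score_threshold: Optional[float] = None) -> list:
    """Keep only hits scoring at or above the threshold, preferring the
    cross-encoder score when results were re-ranked."""
    threshold = score_threshold if score_threshold is not None else -math.inf
    return [hit for hit in hits if hit.get("cross-score", hit.get("score", -math.inf)) >= threshold]
```

The new chat endpoint calls `search(q, n=2, r=True, score_threshold=0, dedupe=False)`, so only results with positive similarity scores reach GPT as reference notes.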
Authored by Debanjum on 2023-03-10 19:03:44 -06:00; committed by GitHub.
11 changed files with 330 additions and 271 deletions

View File

@@ -9,6 +9,7 @@ import schedule
from fastapi.staticfiles import StaticFiles
# Internal Packages
+from khoj.processor.conversation.gpt import summarize
from khoj.processor.ledger.beancount_to_jsonl import BeancountToJsonl
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
@@ -186,3 +187,39 @@ def configure_conversation_processor(conversation_processor_config):
        conversation_processor.chat_session = ""

    return conversation_processor


+@schedule.repeat(schedule.every(15).minutes)
+def save_chat_session():
+    # No need to create an empty log file
+    if not (
+        state.processor_config
+        and state.processor_config.conversation
+        and state.processor_config.conversation.meta_log
+        and state.processor_config.conversation.chat_session
+    ):
+        return
+
+    # Summarize Conversation Logs for this Session
+    chat_session = state.processor_config.conversation.chat_session
+    openai_api_key = state.processor_config.conversation.openai_api_key
+    conversation_log = state.processor_config.conversation.meta_log
+    model = state.processor_config.conversation.model
+    session = {
+        "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
+        "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
+        "session-end": len(conversation_log["chat"]),
+    }
+    if "session" in conversation_log:
+        conversation_log["session"].append(session)
+    else:
+        conversation_log["session"] = [session]
+
+    # Save Conversation Metadata Logs to Disk
+    conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
+    conversation_logfile.parent.mkdir(parents=True, exist_ok=True)  # create conversation directory if it doesn't exist
+    with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
+        json.dump(conversation_log, logfile)
+
+    state.processor_config.conversation.chat_session = None
+    logger.info("📩 Saved current chat session to conversation logs")

View File

@@ -6,15 +6,9 @@
<link rel="icon" href="data:image/svg+xml,<svg xmlns=%22http://www.w3.org/2000/svg%22 viewBox=%220 0 144 144%22><text y=%22.86em%22 font-size=%22144%22>🦅</text></svg>">
<link rel="icon" type="image/png" sizes="144x144" href="/static/assets/icons/favicon-144x144.png">
<link rel="manifest" href="/static/khoj.webmanifest">
<link rel="manifest" href="/static/khoj_chat.webmanifest">
</head>
<script>
function setTypeFieldInUrl(type) {
let url = new URL(window.location.href);
url.searchParams.set("t", type.value);
window.history.pushState({}, "", url.href);
}
function formatDate(date) {
// Format date in HH:MM, DD MMM YYYY format
let time_string = date.toLocaleTimeString('en-IN', { hour: '2-digit', minute: '2-digit', hour12: false });
@@ -22,6 +16,11 @@
        return `${time_string}, ${date_string}`;
    }
+    function generateReference(reference, index) {
+        // Generate HTML for Chat Reference
+        return `<sup><abbr title="${reference}" tabindex="0">${index}</abbr></sup>`;
+    }
    function renderMessage(message, by, dt=null) {
        let message_time = formatDate(dt ?? new Date());
        let by_name = by == "khoj" ? "🦅 Khoj" : "🤔 You";
@@ -31,15 +30,26 @@
<div class="chat-message-text ${by}">${message}</div>
</div>
`;
// Scroll to bottom of input-body element
// Scroll to bottom of chat-body element
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}
function renderMessageWithReference(message, by, context=null, dt=null) {
let references = '';
if (context) {
references = context
.split("\n\n# ")
.map((reference, index) => generateReference(reference, index))
.join("<sup>,</sup>");
}
renderMessage(message+references, by, dt);
}
function chat() {
// Extract required fields for search from form
query = document.getElementById("chat-input").value.trim();
type_ = document.getElementById("chat-type").value;
console.log(`Query: ${query}, Type: ${type_}`);
let query = document.getElementById("chat-input").value.trim();
console.log(`Query: ${query}`);
// Short circuit on empty query
if (query.length === 0)
@@ -50,18 +60,15 @@
document.getElementById("chat-input").value = "";
// Generate backend API URL to execute query
url = type_ === "chat"
? `/api/beta/chat?q=${encodeURIComponent(query)}`
: `/api/beta/summarize?q=${encodeURIComponent(query)}`;
let url = `/api/chat?q=${encodeURIComponent(query)}`;
// Call specified Khoj API
fetch(url)
.then(response => response.json())
.then(data => data.response)
.then(response => {
.then(data => {
// Render message by Khoj to chat body
console.log(response);
renderMessage(response, "khoj");
console.log(data.response);
renderMessageWithReference(data.response, "khoj", data.context);
});
}
@@ -73,18 +80,13 @@
    }

    window.onload = function () {
-        // Fill type field with value passed in URL query parameters, if any.
-        var type_via_url = new URLSearchParams(window.location.search).get("t");
-        if (type_via_url)
-            document.getElementById("chat-type").value = type_via_url;

-        fetch('/api/beta/chat')
+        fetch('/api/chat')
            .then(response => response.json())
            .then(data => data.response)
            .then(chat_logs => {
                // Render conversation history, if any
                chat_logs.forEach(chat_log => {
-                    renderMessage(chat_log.message, chat_log.by, new Date(chat_log.created));
+                    renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created));
                });
            });
@@ -109,12 +111,6 @@
        <!-- Chat Footer -->
        <div id="chat-footer">
            <input type="text" id="chat-input" class="option" onkeyup=incrementalChat(event) autofocus="autofocus" placeholder="What is the meaning of life?">
-            <!--Select Chat Type from: Chat, Summarize -->
-            <select id="chat-type" class="option" onchange="setTypeFieldInUrl(this)">
-                <option value="chat">Chat</option>
-                <option value="summarize">Summarize</option>
-            </select>
        </div>
    </body>
</body>
@@ -217,7 +213,7 @@
    #chat-footer {
        padding: 0;
        display: grid;
-        grid-template-columns: minmax(70px, 85%) auto;
+        grid-template-columns: minmax(70px, 100%);
        grid-column-gap: 10px;
        grid-row-gap: 10px;
    }
@@ -234,6 +230,29 @@
        font-size: medium;
    }
+    @media (pointer: coarse), (hover: none) {
+        abbr[title] {
+            position: relative;
+            padding-left: 4px;  /* space references out to ease tapping */
+        }
+        abbr[title]:focus:after {
+            content: attr(title);
+
+            /* position tooltip */
+            position: absolute;
+            left: 16px;  /* open tooltip to right of ref link, instead of on top of it */
+            width: auto;
+            z-index: 1;  /* show tooltip above chat messages */
+
+            /* style tooltip */
+            background-color: #aaa;
+            color: #f8fafc;
+            border-radius: 2px;
+            box-shadow: 1px 1px 4px 0 rgba(0, 0, 0, 0.4);
+            font-size: 14px;
+            padding: 2px 4px;
+        }
+    }
    @media only screen and (max-width: 600px) {
        body {
            grid-template-columns: 1fr;

View File

@@ -0,0 +1,16 @@
+{
+    "name": "Khoj Chat",
+    "short_name": "Khoj Chat",
+    "description": "A personal assistant for your notes",
+    "icons": [
+        {
+            "src": "/static/assets/icons/favicon-144x144.png",
+            "sizes": "144x144",
+            "type": "image/png"
+        }
+    ],
+    "theme_color": "#ffffff",
+    "background_color": "#ffffff",
+    "display": "standalone",
+    "start_url": "/chat"
+}

View File

@@ -1,6 +1,7 @@
# Standard Packages
import os
import json
+import logging
from datetime import datetime
# External Packages
@@ -8,6 +9,38 @@ import openai
# Internal Packages
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import merge_dicts
+logger = logging.getLogger(__name__)
+
+
+def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=500):
+    """
+    Answer user query using provided text as reference with OpenAI's GPT
+    """
+    # Initialize Variables
+    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
+
+    # Setup Prompt based on Summary Type
+    prompt = f"""
+You are a friendly, helpful personal assistant.
+Using the user's notes below, answer their following question. If the answer is not contained within the notes, say "I don't know."
+
+Notes:
+{text}
+
+Question: {user_query}
+
+Answer (in second person):"""
+
+    # Get Response from GPT
+    logger.debug(f"Prompt for GPT: {prompt}")
+    response = openai.Completion.create(
+        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, stop='"""'
+    )
+
+    # Extract, Clean Message from GPT's Response
+    story = response["choices"][0]["text"]
+    return str(story).replace("\n\n", "")
def summarize(text, summary_type, model, user_query=None, api_key=None, temperature=0.5, max_tokens=200):
@@ -34,6 +67,7 @@ Summarize the below notes about {user_query}:
Summarize the notes in second person perspective:"""

    # Get Response from GPT
+    logger.debug(f"Prompt for GPT: {prompt}")
    response = openai.Completion.create(
        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop='"""'
    )
@@ -77,6 +111,7 @@ A:{ "search-type": "notes" }"""
print(f"Message -> Prompt: {text} -> {prompt}")
# Get Response from GPT
logger.debug(f"Prompt for GPT: {prompt}")
response = openai.Completion.create(
prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
)
@@ -86,104 +121,68 @@ A:{ "search-type": "notes" }"""
    return json.loads(story.strip(empty_escape_sequences))


-def understand(text, model, api_key=None, temperature=0.5, max_tokens=100, verbose=0):
+def converse(text, user_query, conversation_log=None, api_key=None, temperature=0):
    """
-    Understand user input using OpenAI's GPT
+    Converse with user using OpenAI's ChatGPT
    """
    # Initialize Variables
    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
-    understand_primer = """
-Objective: Extract intent and trigger emotion information as JSON from each chat message
-
-Potential intent types and valid argument values are listed below:
-- intent
-  - remember(memory-type, query);
-    - memory-type=["companion","notes","ledger","image","music"]
-  - search(search-type, query);
-    - search-type=["google"]
-  - generate(activity, query);
-    - activity=["paint","write","chat"]
-- trigger-emotion(emotion)
-  - emotion=["happy","confidence","fear","surprise","sadness","disgust","anger","shy","curiosity","calm"]
-
-Some examples are given below for reference:
-Q: How are you doing?
-A: { "intent": {"type": "generate", "activity": "chat", "query": "How are you doing?"}, "trigger-emotion": "happy" }
-Q: Do you remember what I told you about my brother Antoine when we were at the beach?
-A: { "intent": {"type": "remember", "memory-type": "companion", "query": "Brother Antoine when we were at the beach"}, "trigger-emotion": "curiosity" }
-Q: what was that fantasy story you told me last time?
-A: { "intent": {"type": "remember", "memory-type": "companion", "query": "fantasy story told last time"}, "trigger-emotion": "curiosity" }
-Q: Let's make some drawings about the stars on a clear full moon night!
-A: { "intent": {"type": "generate", "activity": "paint", "query": "stars on a clear full moon night"}, "trigger-emotion": "happy" }
-Q: Do you know anything about Lebanon cuisine in the 18th century?
-A: { "intent": {"type": "search", "search-type": "google", "query": "lebanon cuisine in the 18th century"}, "trigger-emotion": "confidence" }
-Q: Tell me a scary story
-A: { "intent": {"type": "generate", "activity": "write", "query": "A scary story"}, "trigger-emotion": "fear" }
-Q: What fiction book was I reading last week about AI starship?
-A: { "intent": {"type": "remember", "memory-type": "notes", "query": "fiction book about AI starship last week"}, "trigger-emotion": "curiosity" }
-Q: How much did I spend at Subway for dinner last time?
-A: { "intent": {"type": "remember", "memory-type": "ledger", "query": "last Subway dinner"}, "trigger-emotion": "calm" }
-Q: I'm feeling sleepy
-A: { "intent": {"type": "generate", "activity": "chat", "query": "I'm feeling sleepy"}, "trigger-emotion": "calm" }
-Q: What was that popular Sri lankan song that Alex had mentioned?
-A: { "intent": {"type": "remember", "memory-type": "music", "query": "popular Sri lankan song mentioned by Alex"}, "trigger-emotion": "curiosity" }
-Q: You're pretty funny!
-A: { "intent": {"type": "generate", "activity": "chat", "query": "You're pretty funny!"}, "trigger-emotion": "shy" }
-Q: Can you recommend a movie to watch from my notes?
-A: { "intent": {"type": "remember", "memory-type": "notes", "query": "recommend movie to watch"}, "trigger-emotion": "curiosity" }
-Q: When did I go surfing last?
-A: { "intent": {"type": "remember", "memory-type": "notes", "query": "When did I go surfing last"}, "trigger-emotion": "calm" }
-Q: Can you dance for me?
-A: { "intent": {"type": "generate", "activity": "chat", "query": "Can you dance for me?"}, "trigger-emotion": "sad" }"""
-
-    # Setup Prompt with Understand Primer
-    prompt = message_to_prompt(text, understand_primer, start_sequence="\nA:", restart_sequence="\nQ:")
-    if verbose > 1:
-        print(f"Message -> Prompt: {text} -> {prompt}")
-
-    # Get Response from GPT
-    response = openai.Completion.create(
-        prompt=prompt, model=model, temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, stop=["\n"]
-    )
-
-    # Extract, Clean Message from GPT's Response
-    story = str(response["choices"][0]["text"])
-    return json.loads(story.strip(empty_escape_sequences))
-def converse(text, model, conversation_history=None, api_key=None, temperature=0.9, max_tokens=150):
-    """
-    Converse with user using OpenAI's GPT
-    """
-    # Initialize Variables
-    max_words = 500
+    model = "gpt-3.5-turbo"
    openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
+    personality_primer = "You are a friendly, helpful personal assistant."
    conversation_primer = f"""
-The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and a very friendly companion.
+Using the notes and our chats as context, answer the following question.
Current Date: {datetime.now().strftime("%Y-%m-%d")}

-Human: Hello, who are you?
-AI: Hi, I am an AI conversational companion created by OpenAI. How can I help you today?"""
+Notes:
+{text}
+
+Question: {user_query}"""

    # Setup Prompt with Primer or Conversation History
-    prompt = message_to_prompt(text, conversation_history or conversation_primer)
-    prompt = " ".join(prompt.split()[:max_words])
+    messages = generate_chatml_messages_with_context(
+        conversation_primer,
+        personality_primer,
+        conversation_log,
+    )

    # Get Response from GPT
-    response = openai.Completion.create(
-        prompt=prompt,
+    logger.debug(f"Conversation Context for GPT: {messages}")
+    response = openai.ChatCompletion.create(
+        messages=messages,
        model=model,
        temperature=temperature,
-        max_tokens=max_tokens,
-        presence_penalty=0.6,
-        stop=["\n", "Human:", "AI:"],
    )

    # Extract, Clean Message from GPT's Response
-    story = str(response["choices"][0]["text"])
+    story = str(response["choices"][0]["message"]["content"])
    return story.strip(empty_escape_sequences)
+def generate_chatml_messages_with_context(user_message, system_message, conversation_log=None):
+    """Generate messages for ChatGPT with context from previous conversation"""
+    # Extract Chat History for Context
+    chat_logs = [f'{chat["message"]}\n\nNotes:\n{chat.get("context","")}' for chat in conversation_log.get("chat", [])]
+    last_backnforth = reciprocal_conversation_to_chatml(chat_logs[-2:])
+    rest_backnforth = reciprocal_conversation_to_chatml(chat_logs[-4:-2])
+
+    # Format user and system messages to chatml format
+    system_chatml_message = [message_to_chatml(system_message, "system")]
+    user_chatml_message = [message_to_chatml(user_message, "user")]
+
+    return rest_backnforth + system_chatml_message + last_backnforth + user_chatml_message
+
+
+def reciprocal_conversation_to_chatml(message_pair):
+    """Convert a single back and forth between user and assistant to chatml format"""
+    return [message_to_chatml(message, role) for message, role in zip(message_pair, ["user", "assistant"])]
+
+
+def message_to_chatml(message, role="assistant"):
+    """Create chatml message from message and role"""
+    return {"role": role, "content": message}
def message_to_prompt(
    user_message, conversation_history="", gpt_message=None, start_sequence="\nAI:", restart_sequence="\nHuman:"
):
@@ -193,22 +192,20 @@ def message_to_prompt(
return f"{conversation_history}{restart_sequence} {user_message}{start_sequence}{gpt_message}"
def message_to_log(user_message, gpt_message, user_message_metadata={}, conversation_log=[]):
def message_to_log(user_message, gpt_message, khoj_message_metadata={}, conversation_log=[]):
"""Create json logs from messages, metadata for conversation log"""
default_user_message_metadata = {
default_khoj_message_metadata = {
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
"trigger-emotion": "calm",
}
current_dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Create json log from Human's message
human_log = user_message_metadata or default_user_message_metadata
human_log["message"] = user_message
human_log["by"] = "you"
human_log["created"] = current_dt
human_log = {"message": user_message, "by": "you", "created": current_dt}
# Create json log from GPT's response
khoj_log = {"message": gpt_message, "by": "khoj", "created": current_dt}
khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": current_dt}, khoj_log)
conversation_log.extend([human_log, khoj_log])
return conversation_log
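# For illustration (hypothetical values, not part of the diff): one logged turn is
#   {"message": "When is my flight?", "by": "you", "created": "2023-03-10 19:03:44"}
# followed by a khoj entry that also carries the "context", "intent" and
# "trigger-emotion" metadata keys merged in above.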

View File

@@ -1,7 +1,8 @@
# Standard Packages
+import math
import yaml
import logging
-from typing import List, Optional
+from typing import List, Optional, Union
# External Packages
from fastapi import APIRouter
@@ -9,6 +10,7 @@ from fastapi import HTTPException
# Internal Packages
from khoj.configure import configure_processor, configure_search
+from khoj.processor.conversation.gpt import converse, message_to_log, message_to_prompt
from khoj.search_type import image_search, text_search
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import FullConfig, SearchResponse
@@ -53,7 +55,14 @@ async def set_config_data(updated_config: FullConfig):
@api.get("/search", response_model=List[SearchResponse])
-def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
+def search(
+    q: str,
+    n: Optional[int] = 5,
+    t: Optional[SearchType] = None,
+    r: Optional[bool] = False,
+    score_threshold: Optional[Union[float, None]] = None,
+    dedupe: Optional[bool] = True,
+):
    results: List[SearchResponse] = []
    if q is None or q == "":
        logger.warn(f"No query param (q) passed in API call to initiate search")
@@ -62,9 +71,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    # initialize variables
    user_query = q.strip()
    results_count = n
+    score_threshold = score_threshold if score_threshold is not None else -math.inf

    # return cached results, if available
-    query_cache_key = f"{user_query}-{n}-{t}-{r}"
+    query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
    if query_cache_key in state.query_cache:
        logger.debug(f"Return response from query cache")
        return state.query_cache[query_cache_key]
@@ -72,7 +82,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    if (t == SearchType.Org or t == None) and state.model.orgmode_search:
        # query org-mode notes
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):
@@ -81,7 +93,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
        # query markdown files
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):
@@ -90,7 +104,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
        # query transactions
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):
@@ -99,7 +115,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Music or t == None) and state.model.music_search:
        # query music library
        with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
+            )

        # collate and return results
        with timer("Collating results took", logger):
@@ -108,7 +126,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
    elif (t == SearchType.Image or t == None) and state.model.image_search:
        # query images
        with timer("Query took", logger):
-            hits = image_search.query(user_query, results_count, state.model.image_search)
+            hits = image_search.query(
+                user_query, results_count, state.model.image_search, score_threshold=score_threshold
+            )
        output_directory = constants.web_directory / "images"

        # collate and return results
@@ -129,6 +149,8 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
                # Get plugin search model for specified search type, or the first one if none specified
                state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
                rank_results=r,
+                score_threshold=score_threshold,
+                dedupe=dedupe,
            )

        # collate and return results
@@ -162,3 +184,40 @@ def update(t: Optional[SearchType] = None, force: Optional[bool] = False):
logger.info("📬 Processor reconfigured via API")
return {"status": "ok", "message": "khoj reloaded"}
@api.get("/chat")
def chat(q: Optional[str] = None):
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# If user query is empty, return chat history
if not q:
if meta_log.get("chat"):
return {"status": "ok", "response": meta_log["chat"]}
else:
return {"status": "ok", "response": []}
# Collate context for GPT
result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False)
collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list])
logger.debug(f"Reference Context:\n{collated_result}")
try:
gpt_response = converse(collated_result, q, meta_log, api_key=api_key)
status = "ok"
except Exception as e:
gpt_response = str(e)
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, khoj_message_metadata={"context": collated_result}, conversation_log=meta_log.get("chat", [])
)
return {"status": status, "response": gpt_response, "context": collated_result}

View File

@@ -1,24 +1,18 @@
# Standard Packages
-import json
import logging
from typing import Optional

# External Packages
-import schedule
from fastapi import APIRouter

# Internal Packages
from khoj.routers.api import search
from khoj.processor.conversation.gpt import (
-    converse,
+    answer,
    extract_search_type,
-    message_to_log,
-    message_to_prompt,
-    understand,
-    summarize,
)
-from khoj.utils.state import SearchType
-from khoj.utils.helpers import get_from_dict, resolve_absolute_path
+from khoj.utils.helpers import get_from_dict
from khoj.utils import state
@@ -48,116 +42,23 @@ def search_beta(q: str, n: Optional[int] = 1):
return {"status": "ok", "result": search_results, "type": search_type}
@api_beta.get("/summarize")
def summarize_beta(q: str):
@api_beta.get("/answer")
def answer_beta(q: str):
# Initialize Variables
model = state.processor_config.conversation.model
api_key = state.processor_config.conversation.openai_api_key
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# Collate context for GPT
result_list = search(q, n=2, r=True, score_threshold=0, dedupe=False)
collated_result = "\n\n".join([f"# {item.additional['compiled']}" for item in result_list])
logger.debug(f"Reference Context:\n{collated_result}")
# Converse with OpenAI GPT
result_list = search(q, n=1, r=True)
collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
# Make GPT respond to user query using provided context
try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
gpt_response = answer(collated_result, user_query=q, model=model, api_key=api_key)
status = "ok"
except Exception as e:
gpt_response = str(e)
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, conversation_log=meta_log.get("chat", [])
)
return {"status": status, "response": gpt_response}
@api_beta.get("/chat")
def chat(q: Optional[str] = None):
# Initialize Variables
model = state.processor_config.conversation.model
api_key = state.processor_config.conversation.openai_api_key
# Load Conversation History
chat_session = state.processor_config.conversation.chat_session
meta_log = state.processor_config.conversation.meta_log
# If user query is empty, return chat history
if not q:
if meta_log.get("chat"):
return {"status": "ok", "response": meta_log["chat"]}
else:
return {"status": "ok", "response": []}
# Converse with OpenAI GPT
metadata = understand(q, model=model, api_key=api_key, verbose=state.verbose)
logger.debug(f'Understood: {get_from_dict(metadata, "intent")}')
if get_from_dict(metadata, "intent", "memory-type") == "notes":
query = get_from_dict(metadata, "intent", "query")
result_list = search(query, n=1, t=SearchType.Org, r=True)
collated_result = "\n".join([item.entry for item in result_list])
logger.debug(f"Semantically Similar Notes:\n{collated_result}")
try:
gpt_response = summarize(collated_result, summary_type="notes", user_query=q, model=model, api_key=api_key)
status = "ok"
except Exception as e:
gpt_response = str(e)
status = "error"
else:
try:
gpt_response = converse(q, model, chat_session, api_key=api_key)
status = "ok"
except Exception as e:
gpt_response = str(e)
status = "error"
# Update Conversation History
state.processor_config.conversation.chat_session = message_to_prompt(q, chat_session, gpt_message=gpt_response)
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q, gpt_response, metadata, meta_log.get("chat", [])
)
return {"status": status, "response": gpt_response}
-@schedule.repeat(schedule.every(5).minutes)
-def save_chat_session():
-    # No need to create empty log file
-    if not (
-        state.processor_config
-        and state.processor_config.conversation
-        and state.processor_config.conversation.meta_log
-        and state.processor_config.conversation.chat_session
-    ):
-        return
-
-    # Summarize Conversation Logs for this Session
-    chat_session = state.processor_config.conversation.chat_session
-    openai_api_key = state.processor_config.conversation.openai_api_key
-    conversation_log = state.processor_config.conversation.meta_log
-    model = state.processor_config.conversation.model
-    session = {
-        "summary": summarize(chat_session, summary_type="chat", model=model, api_key=openai_api_key),
-        "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
-        "session-end": len(conversation_log["chat"]),
-    }
-    if "session" in conversation_log:
-        conversation_log["session"].append(session)
-    else:
-        conversation_log["session"] = [session]
-
-    # Save Conversation Metadata Logs to Disk
-    conversation_logfile = resolve_absolute_path(state.processor_config.conversation.conversation_logfile)
-    conversation_logfile.parent.mkdir(parents=True, exist_ok=True)  # create conversation directory if it doesn't exist
-    with open(conversation_logfile, "w+", encoding="utf-8") as logfile:
-        json.dump(conversation_log, logfile)
-
-    state.processor_config.conversation.chat_session = None
-    logger.info("📩 Saved current chat session to conversation logs")

View File

@@ -1,5 +1,6 @@
# Standard Packages
import glob
+import math
import pathlib
import copy
import shutil
@@ -142,7 +143,7 @@ def extract_metadata(image_name):
    return image_processed_metadata


-def query(raw_query, count, model: ImageSearchModel):
+def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
    # Set query to image content if query is of form file:/path/to/file.png
    if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
        query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@@ -198,6 +199,9 @@ def query(raw_query, count, model: ImageSearchModel):
        for corpus_id, scores in image_hits.items()
    ]

+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit["image_score"] >= score_threshold]

    # Sort the images based on their combined metadata, image scores
    return sorted(hits, key=lambda hit: hit["score"], reverse=True)

View File

@@ -1,5 +1,6 @@
# Standard Packages
import logging
+import math
from pathlib import Path
from typing import List, Tuple, Type
@@ -99,7 +100,13 @@ def compute_embeddings(
    return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]:
+def query(
+    raw_query: str,
+    model: TextSearchModel,
+    rank_results: bool = False,
+    score_threshold: float = -math.inf,
+    dedupe: bool = True,
+) -> Tuple[List[dict], List[Entry]]:
    "Search for entries that answer the query"
    query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
@@ -129,11 +136,15 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
    if rank_results:
        hits = cross_encoder_score(model.cross_encoder, query, entries, hits)

+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]

    # Order results by cross-encoder score followed by bi-encoder score
    hits = sort_results(rank_results, hits)

    # Deduplicate entries by raw entry text before showing to users
-    hits = deduplicate_results(entries, hits)
+    if dedupe:
+        hits = deduplicate_results(entries, hits)

    return hits, entries
@@ -143,7 +154,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
        SearchResponse.parse_obj(
            {
                "entry": entries[hit["corpus_id"]].raw,
-                "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
+                "score": f"{hit.get('cross-score', hit.get('score')):.3f}",
                "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
            }
        )