mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-06 21:29:12 +00:00
Make streaming optional for the /chat endpoint (#287)
* Update the /chat endpoint to conditionally support streaming
  - If streaming is enabled, return the threadgenerator as it does currently
  - If streaming is disabled, return a JSON response with the response and compiled references separated out
  - Correspondingly, update the chat.html UI, as well as Obsidian, to use the streamed API
  - Rename chat/init/ to chat/history
* Update khoj.el to use the /history endpoint
  - Update corresponding unit tests to use stream=true
* Remove & from call to /chat for Obsidian
* Abstract functions out into a helpers.py file and clean up some of the error-catching
This commit is contained in:
@@ -688,7 +688,7 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
|
||||
|
||||
(defun khoj--load-chat-history (buffer-name)
|
||||
"Load Khoj Chat conversation history into BUFFER-NAME."
|
||||
(let ((json-response (cdr (assoc 'response (khoj--query-chat-api "")))))
|
||||
(let ((json-response (cdr (assoc 'response (khoj--get-chat-history-api)))))
|
||||
(with-current-buffer (get-buffer-create buffer-name)
|
||||
(erase-buffer)
|
||||
(insert "* Khoj Chat\n")
|
||||
@@ -766,7 +766,21 @@ Render results in BUFFER-NAME using QUERY, CONTENT-TYPE."
|
||||
"Send QUERY to Khoj Chat API."
|
||||
(let* ((url-request-method "GET")
|
||||
(encoded-query (url-hexify-string query))
|
||||
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url khoj-results-count encoded-query)))
|
||||
(query-url (format "%s/api/chat?q=%s&n=%s&client=emacs" khoj-server-url encoded-query khoj-results-count)))
|
||||
(with-temp-buffer
|
||||
(condition-case ex
|
||||
(progn
|
||||
(url-insert-file-contents query-url)
|
||||
(json-parse-buffer :object-type 'alist))
|
||||
('file-error (cond ((string-match "Internal server error" (nth 2 ex))
|
||||
(message "Chat processor not configured. Configure OpenAI API key and restart it. Exception: [%s]" ex))
|
||||
(t (message "Chat exception: [%s]" ex))))))))
|
||||
|
||||
|
||||
(defun khoj--get-chat-history-api ()
|
||||
"Send QUERY to Khoj Chat History API."
|
||||
(let* ((url-request-method "GET")
|
||||
(query-url (format "%s/api/chat/history?client=emacs" khoj-server-url)))
|
||||
(with-temp-buffer
|
||||
(condition-case ex
|
||||
(progn
|
||||
|
||||
@@ -140,7 +140,7 @@ export class KhojChatModal extends Modal {
|
||||
|
||||
async getChatHistory(): Promise<void> {
|
||||
// Get chat history from Khoj backend
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat/init?client=obsidian`;
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat/history?client=obsidian`;
|
||||
let response = await request(chatUrl);
|
||||
let chatLogs = JSON.parse(response).response;
|
||||
chatLogs.forEach((chatLog: any) => {
|
||||
@@ -157,7 +157,7 @@ export class KhojChatModal extends Modal {
|
||||
|
||||
// Get chat response from Khoj backend
|
||||
let encodedQuery = encodeURIComponent(query);
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian`;
|
||||
let chatUrl = `${this.setting.khojUrl}/api/chat?q=${encodedQuery}&n=${this.setting.resultsCount}&client=obsidian&stream=true`;
|
||||
let responseElement = this.createKhojResponseDiv();
|
||||
|
||||
// Temporary status message to indicate that Khoj is thinking
|
||||
|
||||
@@ -63,7 +63,7 @@
|
||||
document.getElementById("chat-input").value = "";
|
||||
|
||||
// Generate backend API URL to execute query
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web`;
|
||||
let url = `/api/chat?q=${encodeURIComponent(query)}&n=${results_count}&client=web&stream=true`;
|
||||
|
||||
let chat_body = document.getElementById("chat-body");
|
||||
let new_response = document.createElement("div");
|
||||
@@ -130,7 +130,7 @@
|
||||
}
|
||||
|
||||
window.onload = function () {
|
||||
fetch('/api/chat/init?client=web')
|
||||
fetch('/api/chat/history?client=web')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
|
||||
@@ -4,9 +4,8 @@ import math
|
||||
import time
|
||||
import yaml
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
from functools import partial
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
@@ -14,8 +13,6 @@ from sentence_transformers import util
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_processor, configure_search
|
||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
@@ -35,7 +32,11 @@ from khoj.utils.rawconfig import (
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils import state, constants
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from khoj.routers.helpers import perform_chat_checks, generate_chat_response
|
||||
from khoj.processor.conversation.gpt import extract_questions
|
||||
from fastapi.requests import Request
|
||||
|
||||
|
||||
# Initialize Router
|
||||
api = APIRouter()
|
||||
@@ -408,22 +409,15 @@ def update(
|
||||
return {"status": "ok", "message": "khoj reloaded"}
|
||||
|
||||
|
||||
@api.get("/chat/init")
|
||||
def chat_init(
|
||||
@api.get("/chat/history")
|
||||
def chat_history(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
):
|
||||
if (
|
||||
state.processor_config is None
|
||||
or state.processor_config.conversation is None
|
||||
or state.processor_config.conversation.openai_api_key is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||
)
|
||||
perform_chat_checks()
|
||||
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
@@ -444,53 +438,71 @@ def chat_init(
|
||||
return {"status": "ok", "response": meta_log.get("chat", [])}
|
||||
|
||||
|
||||
@api.get("/chat", response_class=StreamingResponse)
|
||||
@api.get("/chat", response_class=Response)
|
||||
async def chat(
|
||||
request: Request,
|
||||
q: Optional[str] = None,
|
||||
q: str,
|
||||
n: Optional[int] = 5,
|
||||
client: Optional[str] = None,
|
||||
stream: Optional[bool] = False,
|
||||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
) -> StreamingResponse:
|
||||
def _save_to_conversation_log(
|
||||
q: str,
|
||||
gpt_response: str,
|
||||
user_message_time: str,
|
||||
compiled_references: List[str],
|
||||
inferred_queries: List[str],
|
||||
meta_log,
|
||||
):
|
||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
q,
|
||||
gpt_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
) -> Response:
|
||||
perform_chat_checks()
|
||||
compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
|
||||
|
||||
if (
|
||||
state.processor_config is None
|
||||
or state.processor_config.conversation is None
|
||||
or state.processor_config.conversation.openai_api_key is None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||
)
|
||||
# Get the (streamed) chat response from GPT.
|
||||
gpt_response = generate_chat_response(
|
||||
q,
|
||||
meta_log=state.processor_config.conversation.meta_log,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
)
|
||||
if gpt_response is None:
|
||||
return Response(content=gpt_response, media_type="text/plain", status_code=500)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||
|
||||
# Get the full response from the generator if the stream is not requested.
|
||||
aggregated_gpt_response = ""
|
||||
while True:
|
||||
try:
|
||||
aggregated_gpt_response += next(gpt_response)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
actual_response = aggregated_gpt_response.split("### compiled references:")[0]
|
||||
|
||||
response_obj = {"response": actual_response, "context": compiled_references}
|
||||
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
}
|
||||
|
||||
state.telemetry += [
|
||||
log_telemetry(
|
||||
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
|
||||
)
|
||||
]
|
||||
|
||||
return Response(content=json.dumps(response_obj), media_type="application/json", status_code=200)
|
||||
|
||||
|
||||
async def extract_references_and_questions(
|
||||
request: Request,
|
||||
q: str,
|
||||
n: int,
|
||||
):
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# If user query is empty, return nothing
|
||||
if not q:
|
||||
return StreamingResponse(None)
|
||||
|
||||
# Initialize Variables
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||
compiled_references = []
|
||||
inferred_queries = []
|
||||
@@ -509,39 +521,4 @@ async def chat(
|
||||
)
|
||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
# Switch to general conversation type if no relevant notes found for the given query
|
||||
conversation_type = "notes" if compiled_references else "general"
|
||||
logger.debug(f"Conversation Type: {conversation_type}")
|
||||
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
}
|
||||
|
||||
state.telemetry += [
|
||||
log_telemetry(
|
||||
telemetry_type="api", api="chat", client=client, app_config=state.config.app, properties=user_state
|
||||
)
|
||||
]
|
||||
|
||||
try:
|
||||
with timer("Generating chat response took", logger):
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
gpt_response = converse(
|
||||
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
gpt_response = str(e)
|
||||
|
||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||
return compiled_references, inferred_queries
|
||||
|
||||
82
src/khoj/routers/helpers.py
Normal file
82
src/khoj/routers/helpers.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Standard Packages
import logging
from datetime import datetime
from functools import partial
from typing import List, Optional

# External Packages
from fastapi import HTTPException

# Internal Packages
from khoj.processor.conversation.gpt import converse
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
from khoj.utils import state
from khoj.utils.helpers import timer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def perform_chat_checks():
    """Ensure Khoj Chat is configured before a chat request is served.

    Raises:
        HTTPException: With status 500 when the processor config, its
            conversation section, or the OpenAI API key is unset.
    """
    processor_config = state.processor_config
    conversation_config = processor_config.conversation if processor_config is not None else None
    # Only read the key once a conversation config exists, so a missing
    # parent never triggers an attribute lookup on None.
    api_key = conversation_config.openai_api_key if conversation_config is not None else None

    if api_key is not None:
        return

    raise HTTPException(
        status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
    )
|
||||
|
||||
|
||||
def generate_chat_response(
    q: str,
    meta_log: dict,
    compiled_references: Optional[List[str]] = None,
    inferred_queries: Optional[List[str]] = None,
):
    """Ask the chat model to answer `q` and arrange for the exchange to be logged.

    Args:
        q: The user's chat message.
        meta_log: Conversation log; its "chat" entry is read as the prior
            conversation. When None, the log held in the shared processor
            state is used instead.
        compiled_references: Reference texts handed to `converse` and recorded
            as the "context" of the logged response. Defaults to an empty list.
        inferred_queries: Queries recorded under the logged response's
            "intent". Defaults to an empty list.

    Returns:
        Whatever `converse` returns.
        NOTE(review): presumably an iterator of response chunks - confirm
        against `converse`.

    Raises:
        HTTPException: With status 500 when setting up the chat response fails.
    """

    def _save_to_conversation_log(
        q: str,
        gpt_response: str,
        user_message_time: str,
        compiled_references: List[str],
        inferred_queries: List[str],
        meta_log,
    ):
        """Record the finished exchange in the shared chat session and conversation log."""
        state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
        state.processor_config.conversation.meta_log["chat"] = message_to_log(
            q,
            gpt_response,
            user_message_metadata={"created": user_message_time},
            khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
            conversation_log=meta_log.get("chat", []),
        )

    # Fresh lists per call: these objects end up stored in the conversation
    # log, so a shared mutable default would be aliased across requests.
    compiled_references = [] if compiled_references is None else compiled_references
    inferred_queries = [] if inferred_queries is None else inferred_queries

    # Honor the caller's conversation log; only fall back to the shared state
    # when none was supplied.
    if meta_log is None:
        meta_log = state.processor_config.conversation.meta_log

    # Initialize Variables
    api_key = state.processor_config.conversation.openai_api_key
    chat_model = state.processor_config.conversation.chat_model
    user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # The conversation type is informational only here: "notes" when there are
    # references to ground the answer in, "general" otherwise.
    conversation_type = "notes" if compiled_references else "general"
    logger.debug("Conversation Type: %s", conversation_type)

    try:
        with timer("Generating chat response took", logger):
            # Pre-bind everything except the response text, which is supplied
            # later by whoever invokes the completion callback.
            partial_completion = partial(
                _save_to_conversation_log,
                q,
                user_message_time=user_message_time,
                compiled_references=compiled_references,
                inferred_queries=inferred_queries,
                meta_log=meta_log,
            )

            gpt_response = converse(
                compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
            )

    except Exception as e:
        # Keep the traceback in the logs and chain the cause so the original
        # failure is not lost behind the HTTP error.
        logger.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    return gpt_response
|
||||
Reference in New Issue
Block a user