mirror of
https://github.com/khoaliber/khoj.git
synced 2026-03-02 21:19:12 +00:00
Continue interrupt queries only after previous query written to DB
This commit is contained in:
@@ -682,6 +682,7 @@ async def chat(
|
|||||||
timezone = body.timezone
|
timezone = body.timezone
|
||||||
raw_images = body.images
|
raw_images = body.images
|
||||||
raw_query_files = body.files
|
raw_query_files = body.files
|
||||||
|
interrupt_flag = body.interrupt
|
||||||
|
|
||||||
async def event_generator(q: str, images: list[str]):
|
async def event_generator(q: str, images: list[str]):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -920,6 +921,34 @@ async def chat(
|
|||||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
meta_log = conversation.conversation_log
|
meta_log = conversation.conversation_log
|
||||||
|
|
||||||
|
# If interrupt flag is set, wait for the previous turn to be saved before proceeding
|
||||||
|
if interrupt_flag:
|
||||||
|
max_wait_time = 20.0 # seconds
|
||||||
|
wait_interval = 0.3 # seconds
|
||||||
|
wait_start = wait_current = time.time()
|
||||||
|
while wait_current - wait_start < max_wait_time:
|
||||||
|
# Refresh conversation to check if interrupted message saved to DB
|
||||||
|
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||||
|
user,
|
||||||
|
client_application=request.user.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
conversation
|
||||||
|
and conversation.messages
|
||||||
|
and conversation.messages[-1].by == "khoj"
|
||||||
|
and not conversation.messages[-1].message
|
||||||
|
):
|
||||||
|
logger.info(f"Detected interrupted message save to conversation {conversation_id}.")
|
||||||
|
break
|
||||||
|
await asyncio.sleep(wait_interval)
|
||||||
|
wait_current = time.time()
|
||||||
|
|
||||||
|
if wait_current - wait_start >= max_wait_time:
|
||||||
|
logger.warning(
|
||||||
|
f"Timeout waiting to load interrupted context from conversation {conversation_id}. Proceed without previous context."
|
||||||
|
)
|
||||||
|
|
||||||
# If interrupted message in DB
|
# If interrupted message in DB
|
||||||
if (
|
if (
|
||||||
conversation
|
conversation
|
||||||
|
|||||||
@@ -168,6 +168,7 @@ class ChatRequestBody(BaseModel):
|
|||||||
images: Optional[list[str]] = None
|
images: Optional[list[str]] = None
|
||||||
files: Optional[list[FileAttachment]] = []
|
files: Optional[list[FileAttachment]] = []
|
||||||
create_new: Optional[bool] = False
|
create_new: Optional[bool] = False
|
||||||
|
interrupt: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
class Entry:
|
class Entry:
|
||||||
|
|||||||
Reference in New Issue
Block a user