From 8f36572a9bc5224e2cdef1bf8e9380d12118343f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 5 Jul 2023 20:49:25 -0700 Subject: [PATCH] Improve typing, null checks in controllers and gpt functions --- src/khoj/processor/conversation/gpt.py | 9 +++++++-- src/khoj/routers/api.py | 8 ++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index 2ee93053..423a2124 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from typing import Optional # Internal Packages from khoj.utils.constants import empty_escape_sequences @@ -47,6 +48,8 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat prompt = prompts.summarize_chat.format(text=text) elif summary_type == "notes": prompt = prompts.summarize_notes.format(text=text, user_query=user_query) + else: + raise ValueError(f"Invalid summary type: {summary_type}") # Get Response from GPT logger.debug(f"Prompt for GPT: {prompt}") @@ -64,7 +67,9 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat return str(response).replace("\n\n", "") -def extract_questions(text, model="text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100): +def extract_questions( + text, model: Optional[str] = "text-davinci-003", conversation_log={}, api_key=None, temperature=0, max_tokens=100 +): """ Infer search queries to retrieve relevant notes to answer user query """ @@ -148,7 +153,7 @@ def converse( references, user_query, conversation_log={}, - model="gpt-3.5-turbo", + model: Optional[str] = "gpt-3.5-turbo", api_key=None, temperature=0.2, completion_func=None, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 87f428da..fd53c5f2 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -324,7 +324,7 @@ async def search( state.query_cache[query_cache_key] = results user_state = { - "client_host": request.client.host, + "client_host": request.client.host if request.client else "unknown", "user_agent": user_agent or "unknown", "referer": referer or "unknown", "host": host or "unknown", @@ -380,7 +380,7 @@ def update( logger.info("📬 Processor reconfigured via API") user_state = { - "client_host": request.client.host, + "client_host": request.client.host if request.client else None, "user_agent": user_agent or "unknown", "referer": referer or "unknown", "host": host or "unknown", @@ -416,7 +416,7 @@ def chat_init( meta_log = state.processor_config.conversation.meta_log user_state = { - "client_host": request.client.host, + "client_host": request.client.host if request.client else None, "user_agent": user_agent or "unknown", "referer": referer or "unknown", "host": host or "unknown", @@ -503,7 +503,7 @@ async def chat( logger.debug(f"Conversation Type: {conversation_type}") user_state = { - "client_host": request.client.host, + "client_host": request.client.host if request.client else None, "user_agent": user_agent or "unknown", "referer": referer or "unknown", "host": host or "unknown",