diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py
index bb88ec38..a6a5901a 100644
--- a/src/khoj/processor/conversation/gpt4all/utils.py
+++ b/src/khoj/processor/conversation/gpt4all/utils.py
@@ -2,7 +2,7 @@ import os
 import logging
 import requests
 from gpt4all import GPT4All
-import tqdm
+from tqdm import tqdm
 
 from khoj.processor.conversation.gpt4all import model_metadata
 
@@ -24,9 +24,17 @@ def download_model(model_name):
         logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
         with requests.get(url, stream=True) as r:
             r.raise_for_status()
-            with open(filename, "wb") as f:
+            total_size = int(r.headers.get("content-length", 0))
+            with open(filename, "wb") as f, tqdm(
+                unit="B",  # unit of measurement shown on the progress bar
+                unit_scale=True,  # let tqdm scale the unit to KB, MB, etc.
+                unit_divisor=1024,  # divisor used when unit_scale is enabled
+                total=total_size,  # total number of bytes expected
+                desc=filename.split("/")[-1],  # prefix shown on the progress bar
+            ) as progress_bar:
                 for chunk in r.iter_content(chunk_size=8192):
                     f.write(chunk)
+                    progress_bar.update(len(chunk))
         return GPT4All(model_name)
     except Exception as e:
         logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index c2103b8d..1dec312e 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -709,21 +709,22 @@ async def extract_references_and_questions(
     if conversation_type == "notes":
         # Infer search queries from user message
         with timer("Extracting search queries took", logger):
-            if state.processor_config.conversation and state.processor_config.conversation.openai_model:
-                api_key = state.processor_config.conversation.openai_model.api_key
-                chat_model = state.processor_config.conversation.openai_model.chat_model
-                inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
-            else:
+            # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
+            if state.processor_config.conversation.enable_offline_chat:
                 loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
                 inferred_queries = extract_questions_offline(
                     q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
                 )
+            elif state.processor_config.conversation.openai_model:
+                api_key = state.processor_config.conversation.openai_model.api_key
+                chat_model = state.processor_config.conversation.openai_model.chat_model
+                inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
 
     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):
         result_list = []
         for query in inferred_queries:
-            n_items = n if state.processor_config.conversation.openai_model else min(n, 3)
+            n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
             result_list.extend(
                 await search(query, request=request, n=n_items, r=True, score_threshold=-5.0, dedupe=False)
             )
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index f0bfb58f..e8516c38 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -86,6 +86,7 @@ def generate_chat_response(
     # Switch to general conversation type if no relevant notes found for the given query
     conversation_type = "notes" if compiled_references else "general"
     logger.debug(f"Conversation Type: {conversation_type}")
+    chat_response = None
 
     try:
         with timer("Generating chat response took", logger):
@@ -98,7 +99,17 @@ def generate_chat_response(
                 meta_log=meta_log,
             )
 
-            if state.processor_config.conversation.openai_model:
+            if state.processor_config.conversation.enable_offline_chat:
+                loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+                chat_response = converse_offline(
+                    references=compiled_references,
+                    user_query=q,
+                    loaded_model=loaded_model,
+                    conversation_log=meta_log,
+                    completion_func=partial_completion,
+                )
+
+            elif state.processor_config.conversation.openai_model:
                 api_key = state.processor_config.conversation.openai_model.api_key
                 chat_model = state.processor_config.conversation.openai_model.chat_model
                 chat_response = converse(
@@ -109,15 +120,6 @@ def generate_chat_response(
                     api_key=api_key,
                     completion_func=partial_completion,
                 )
-            else:
-                loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
-                chat_response = converse_offline(
-                    references=compiled_references,
-                    user_query=q,
-                    loaded_model=loaded_model,
-                    conversation_log=meta_log,
-                    completion_func=partial_completion,
-                )
 
     except Exception as e:
         logger.error(e, exc_info=True)
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 655e6d34..7882edcf 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -94,7 +94,7 @@ class ConversationProcessorConfigModel:
         self.chat_session: List[str] = []
         self.meta_log: dict = {}
 
-        if not self.openai_model and self.enable_offline_chat:
+        if self.enable_offline_chat:
             self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
         else:
             self.gpt4all_model.loaded_model = None
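
For reference, the streaming-download-with-progress pattern that the utils.py hunk introduces can be exercised on its own. The sketch below is a minimal, self-contained version of that pattern; the URL, destination filename, and download_with_progress helper are placeholder assumptions for illustration, not khoj's actual model_metadata lookup or cache path.

import requests
from tqdm import tqdm

def download_with_progress(url: str, filename: str) -> None:
    # Stream the response so large files are never held fully in memory.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        # content-length may be absent; a falsy total makes tqdm show a plain byte counter.
        total_size = int(r.headers.get("content-length", 0))
        with open(filename, "wb") as f, tqdm(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            total=total_size,
            desc=filename.split("/")[-1],
        ) as progress_bar:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
                progress_bar.update(len(chunk))

if __name__ == "__main__":
    # Hypothetical invocation; substitute a real model URL and output path.
    download_with_progress("https://example.com/model.bin", "model.bin")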