From e2abc1a257b3b58e1a0355ea23aef20b86869eb6 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 17 Oct 2024 23:05:43 -0700 Subject: [PATCH] Handle multiple images shared in query to chat API Previously Khoj could respond to a single shared image at a time. This change updates the chat API to accept multiple images shared by the user and send them to the appropriate chat actors, including the OpenAI response generation chat actor, for getting an image-aware response --- .../conversation/google/gemini_chat.py | 3 +- src/khoj/processor/conversation/openai/gpt.py | 8 +-- src/khoj/processor/conversation/utils.py | 30 +++++---- src/khoj/processor/image/generate.py | 4 +- src/khoj/processor/tools/online_search.py | 8 +-- src/khoj/routers/api.py | 4 +- src/khoj/routers/api_chat.py | 62 +++++++++---------- src/khoj/routers/helpers.py | 52 ++++++++-------- 8 files changed, 90 insertions(+), 81 deletions(-) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 7359b3eb..e8848806 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -6,7 +6,7 @@ from typing import Dict, Optional from langchain.schema import ChatMessage -from khoj.database.models import Agent, KhojUser +from khoj.database.models import Agent, ChatModelOptions, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.google.utils import ( format_messages_for_gemini, @@ -187,6 +187,7 @@ def converse_gemini( model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, + model_type=ChatModelOptions.ModelType.GOOGLE, ) messages, system_prompt = format_messages_for_gemini(messages, system_prompt) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index ad02b10e..4a656fac 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ 
b/src/khoj/processor/conversation/openai/gpt.py @@ -30,7 +30,7 @@ def extract_questions( api_base_url=None, location_data: LocationData = None, user: KhojUser = None, - uploaded_image_url: Optional[str] = None, + query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, ): @@ -74,7 +74,7 @@ def extract_questions( prompt = construct_structured_message( message=prompt, - image_url=uploaded_image_url, + images=query_images, model_type=ChatModelOptions.ModelType.OPENAI, vision_enabled=vision_enabled, ) @@ -135,7 +135,7 @@ def converse( location_data: LocationData = None, user_name: str = None, agent: Agent = None, - image_url: Optional[str] = None, + query_images: Optional[list[str]] = None, vision_available: bool = False, ): """ @@ -191,7 +191,7 @@ def converse( model_name=model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, - uploaded_image_url=image_url, + query_images=query_images, vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.OPENAI, ) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e841c484..8d799745 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -109,7 +109,7 @@ def save_to_conversation_log( client_application: ClientApplication = None, conversation_id: str = None, automation_id: str = None, - uploaded_image_url: str = None, + query_images: List[str] = None, ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") updated_conversation = message_to_log( @@ -117,7 +117,7 @@ def save_to_conversation_log( chat_response=chat_response, user_message_metadata={ "created": user_message_time, - "uploadedImageData": uploaded_image_url, + "images": query_images, }, khoj_message_metadata={ "context": compiled_references, @@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response} ) -# Format 
user and system messages to chatml format -def construct_structured_message(message, image_url, model_type, vision_enabled): - if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI: - return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}] +def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool): + """ + Format messages into appropriate multimedia format for supported chat model types + """ + if not images or not vision_enabled: + return message + + if model_type == ChatModelOptions.ModelType.OPENAI: + return [ + {"type": "text", "text": message}, + *[{"type": "image_url", "image_url": {"url": image}} for image in images], + ] return message @@ -160,7 +168,7 @@ def generate_chatml_messages_with_context( loaded_model: Optional[Llama] = None, max_prompt_size=None, tokenizer_name=None, - uploaded_image_url=None, + query_images=None, vision_enabled=False, model_type="", ): @@ -183,9 +191,7 @@ def generate_chatml_messages_with_context( message_content = chat["message"] + message_notes - message_content = construct_structured_message( - message_content, chat.get("uploadedImageData"), model_type, vision_enabled - ) + message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled) reconstructed_message = ChatMessage(content=message_content, role=role) @@ -198,7 +204,7 @@ def generate_chatml_messages_with_context( if not is_none_or_empty(user_message): messages.append( ChatMessage( - content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled), + content=construct_structured_message(user_message, query_images, model_type, vision_enabled), role="user", ) ) @@ -222,7 +228,6 @@ def truncate_messages( tokenizer_name=None, ) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" - default_tokenizer = "gpt-4o" try: @@ -252,6 +257,7 @@ 
def truncate_messages( system_message = messages.pop(idx) break + # TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string system_message_tokens = ( len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0 ) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 59073731..ee39bdc5 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -26,7 +26,7 @@ async def text_to_image( references: List[Dict[str, Any]], online_results: Dict[str, Any], send_status_func: Optional[Callable] = None, - uploaded_image_url: Optional[str] = None, + query_images: Optional[List[str]] = None, agent: Agent = None, ): status_code = 200 @@ -65,7 +65,7 @@ async def text_to_image( note_references=references, online_results=online_results, model_type=text_to_image_config.model_type, - uploaded_image_url=uploaded_image_url, + query_images=query_images, user=user, agent=agent, ) diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 70972eac..fdf1ba9f 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -62,7 +62,7 @@ async def search_online( user: KhojUser, send_status_func: Optional[Callable] = None, custom_filters: List[str] = [], - uploaded_image_url: str = None, + query_images: List[str] = None, agent: Agent = None, ): query += " ".join(custom_filters) @@ -73,7 +73,7 @@ async def search_online( # Breakdown the query into subqueries to get the correct answer subqueries = await generate_online_subqueries( - query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent + query, conversation_history, location, user, query_images=query_images, agent=agent ) response_dict = {} @@ -151,7 +151,7 @@ async def read_webpages( location: LocationData, user: KhojUser, 
send_status_func: Optional[Callable] = None, - uploaded_image_url: str = None, + query_images: List[str] = None, agent: Agent = None, ): "Infer web pages to read from the query and extract relevant information from them" @@ -159,7 +159,7 @@ async def read_webpages( if send_status_func: async for event in send_status_func(f"**Inferring web pages to read**"): yield {ChatEvent.STATUS: event} - urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url) + urls = await infer_webpage_urls(query, conversation_history, location, user, query_images) logger.info(f"Reading web pages at: {urls}") if send_status_func: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 59948b47..075c8c47 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -340,7 +340,7 @@ async def extract_references_and_questions( conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], location_data: LocationData = None, send_status_func: Optional[Callable] = None, - uploaded_image_url: Optional[str] = None, + query_images: Optional[List[str]] = None, agent: Agent = None, ): user = request.user.object if request.user.is_authenticated else None @@ -431,7 +431,7 @@ async def extract_references_and_questions( conversation_log=meta_log, location_data=location_data, user=user, - uploaded_image_url=uploaded_image_url, + query_images=query_images, vision_enabled=vision_enabled, personality_context=personality_context, ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index d57b5530..ee84c554 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -535,7 +535,7 @@ class ChatRequestBody(BaseModel): country: Optional[str] = None country_code: Optional[str] = None timezone: Optional[str] = None - image: Optional[str] = None + images: Optional[list[str]] = None create_new: Optional[bool] = False @@ -564,9 +564,9 @@ async def chat( country = body.country or 
get_country_name_from_timezone(body.timezone) country_code = body.country_code or get_country_code_from_timezone(body.timezone) timezone = body.timezone - image = body.image + raw_images = body.images - async def event_generator(q: str, image: str): + async def event_generator(q: str, images: list[str]): start_time = time.perf_counter() ttft = None chat_metadata: dict = {} @@ -576,16 +576,16 @@ async def chat( q = unquote(q) nonlocal conversation_id - uploaded_image_url = None - if image: - decoded_string = unquote(image) - base64_data = decoded_string.split(",", 1)[1] - image_bytes = base64.b64decode(base64_data) - webp_image_bytes = convert_image_to_webp(image_bytes) - try: - uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) - except: - uploaded_image_url = None + uploaded_images: list[str] = [] + if images: + for image in images: + decoded_string = unquote(image) + base64_data = decoded_string.split(",", 1)[1] + image_bytes = base64.b64decode(base64_data) + webp_image_bytes = convert_image_to_webp(image_bytes) + uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id) + if uploaded_image: + uploaded_images.append(uploaded_image) async def send_event(event_type: ChatEvent, data: str | dict): nonlocal connection_alive, ttft @@ -692,7 +692,7 @@ async def chat( meta_log, is_automated_task, user=user, - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, agent=agent, ) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) @@ -701,7 +701,7 @@ async def chat( ): yield result - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent) + mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent) async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): yield result if mode not in conversation_commands: @@ -764,7 +764,7 @@ async def chat( 
q, contextual_data, conversation_history=meta_log, - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, user=user, agent=agent, ) @@ -785,7 +785,7 @@ async def chat( intent_type="summarize", client_application=request.user.client_app, conversation_id=conversation_id, - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, ) return @@ -828,7 +828,7 @@ async def chat( conversation_id=conversation_id, inferred_queries=[query_to_run], automation_id=automation.id, - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, ) async for result in send_llm_response(llm_response): yield result @@ -848,7 +848,7 @@ async def chat( conversation_commands, location, partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, agent=agent, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -892,7 +892,7 @@ async def chat( user, partial(send_event, ChatEvent.STATUS), custom_filters, - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, agent=agent, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -916,7 +916,7 @@ async def chat( location, user, partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, agent=agent, ): if isinstance(result, dict) and ChatEvent.STATUS in result: @@ -966,20 +966,20 @@ async def chat( references=compiled_references, online_results=online_results, send_status_func=partial(send_event, ChatEvent.STATUS), - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, agent=agent, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] else: - image, status_code, improved_image_prompt, intent_type = result + generated_image, status_code, improved_image_prompt, intent_type = result - if image is None or status_code != 200: + if generated_image is None or status_code != 200: content_obj = { "content-type": 
"application/json", "intentType": intent_type, "detail": improved_image_prompt, - "image": image, + "image": None, } async for result in send_llm_response(json.dumps(content_obj)): yield result @@ -987,7 +987,7 @@ async def chat( await sync_to_async(save_to_conversation_log)( q, - image, + generated_image, user, meta_log, user_message_time, @@ -997,12 +997,12 @@ async def chat( conversation_id=conversation_id, compiled_references=compiled_references, online_results=online_results, - uploaded_image_url=uploaded_image_url, + query_images=uploaded_images, ) content_obj = { "intentType": intent_type, "inferredQueries": [improved_image_prompt], - "image": image, + "image": generated_image, } async for result in send_llm_response(json.dumps(content_obj)): yield result @@ -1024,7 +1024,7 @@ async def chat( conversation_id, location, user_name, - uploaded_image_url, + uploaded_images, ) # Send Response @@ -1050,9 +1050,9 @@ async def chat( ## Stream Text Response if stream: - return StreamingResponse(event_generator(q, image=image), media_type="text/plain") + return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain") ## Non-Streaming Text Response else: - response_iterator = event_generator(q, image=image) + response_iterator = event_generator(q, images=raw_images) response_data = await read_chat_stream(response_iterator) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 12616e36..7ed9c72d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -290,7 +290,7 @@ async def aget_relevant_information_sources( conversation_history: dict, is_task: bool, user: KhojUser, - uploaded_image_url: str = None, + query_images: List[str] = None, agent: Agent = None, ): """ @@ -309,8 +309,8 @@ async def aget_relevant_information_sources( chat_history = construct_chat_history(conversation_history) - if 
uploaded_image_url: - query = f"[placeholder for user attached image]\n{query}" + if query_images: + query = f"[placeholder for {len(query_images)} user attached images]\n{query}" personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" @@ -367,7 +367,7 @@ async def aget_relevant_output_modes( conversation_history: dict, is_task: bool = False, user: KhojUser = None, - uploaded_image_url: str = None, + query_images: List[str] = None, agent: Agent = None, ): """ @@ -389,8 +389,8 @@ async def aget_relevant_output_modes( chat_history = construct_chat_history(conversation_history) - if uploaded_image_url: - query = f"[placeholder for user attached image]\n{query}" + if query_images: + query = f"[placeholder for {len(query_images)} user attached images]\n{query}" personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" @@ -433,7 +433,7 @@ async def infer_webpage_urls( conversation_history: dict, location_data: LocationData, user: KhojUser, - uploaded_image_url: str = None, + query_images: List[str] = None, agent: Agent = None, ) -> List[str]: """ @@ -459,7 +459,7 @@ async def infer_webpage_urls( with timer("Chat actor: Infer webpage urls to read", logger): response = await send_message_to_model_wrapper( - online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user + online_queries_prompt, query_images=query_images, response_type="json_object", user=user ) # Validate that the response is a non-empty, JSON-serializable list of URLs @@ -479,7 +479,7 @@ async def generate_online_subqueries( conversation_history: dict, location_data: LocationData, user: KhojUser, - uploaded_image_url: str = None, + query_images: List[str] = None, agent: Agent = None, ) -> List[str]: """ @@ -505,7 +505,7 @@ async def generate_online_subqueries( with timer("Chat actor: Generate online search subqueries", logger): 
response = await send_message_to_model_wrapper( - online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user + online_queries_prompt, query_images=query_images, response_type="json_object", user=user ) # Validate that the response is a non-empty, JSON-serializable list @@ -524,7 +524,7 @@ async def generate_online_subqueries( async def schedule_query( - q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None + q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None ) -> Tuple[str, ...]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. @@ -537,7 +537,7 @@ async def schedule_query( ) raw_response = await send_message_to_model_wrapper( - crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user + crontime_prompt, query_images=query_images, response_type="json_object", user=user ) # Validate that the response is a non-empty, JSON-serializable list @@ -583,7 +583,7 @@ async def extract_relevant_summary( q: str, corpus: str, conversation_history: dict, - uploaded_image_url: str = None, + query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, ) -> Union[str, None]: @@ -612,7 +612,7 @@ async def extract_relevant_summary( extract_relevant_information, prompts.system_prompt_extract_relevant_summary, user=user, - uploaded_image_url=uploaded_image_url, + query_images=query_images, ) return response.strip() @@ -624,7 +624,7 @@ async def generate_better_image_prompt( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, model_type: Optional[str] = None, - uploaded_image_url: Optional[str] = None, + query_images: Optional[List[str]] = None, user: KhojUser = None, agent: Agent = None, ) -> str: @@ -676,7 +676,7 @@ async def generate_better_image_prompt( ) with timer("Chat actor: Generate contextual image prompt", logger): - response = await 
send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user) + response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): response = response[1:-1] @@ -689,11 +689,11 @@ async def send_message_to_model_wrapper( system_message: str = "", response_type: str = "text", user: KhojUser = None, - uploaded_image_url: str = None, + query_images: List[str] = None, ): conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled - if not vision_available and uploaded_image_url: + if not vision_available and query_images: vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() if vision_enabled_config: conversation_config = vision_enabled_config @@ -746,7 +746,7 @@ async def send_message_to_model_wrapper( max_prompt_size=max_tokens, tokenizer_name=tokenizer, vision_enabled=vision_available, - uploaded_image_url=uploaded_image_url, + query_images=query_images, model_type=conversation_config.model_type, ) @@ -766,7 +766,7 @@ async def send_message_to_model_wrapper( max_prompt_size=max_tokens, tokenizer_name=tokenizer, vision_enabled=vision_available, - uploaded_image_url=uploaded_image_url, + query_images=query_images, model_type=conversation_config.model_type, ) @@ -784,7 +784,8 @@ async def send_message_to_model_wrapper( max_prompt_size=max_tokens, tokenizer_name=tokenizer, vision_enabled=vision_available, - uploaded_image_url=uploaded_image_url, + query_images=query_images, + model_type=conversation_config.model_type, ) return gemini_send_message_to_model( @@ -875,6 +876,7 @@ def send_message_to_model_wrapper_sync( model_name=chat_model, max_prompt_size=max_tokens, vision_enabled=vision_available, + model_type=conversation_config.model_type, ) return gemini_send_message_to_model( @@ 
-900,7 +902,7 @@ def generate_chat_response( conversation_id: str = None, location_data: LocationData = None, user_name: Optional[str] = None, - uploaded_image_url: Optional[str] = None, + query_images: Optional[List[str]] = None, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -919,12 +921,12 @@ def generate_chat_response( inferred_queries=inferred_queries, client_application=client_application, conversation_id=conversation_id, - uploaded_image_url=uploaded_image_url, + query_images=query_images, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) vision_available = conversation_config.vision_enabled - if not vision_available and uploaded_image_url: + if not vision_available and query_images: vision_enabled_config = ConversationAdapters.get_vision_enabled_config() if vision_enabled_config: conversation_config = vision_enabled_config @@ -955,7 +957,7 @@ def generate_chat_response( chat_response = converse( compiled_references, q, - image_url=uploaded_image_url, + query_images=query_images, online_results=online_results, conversation_log=meta_log, model=chat_model,