diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 5e73fb07..33cc3056 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -30,6 +30,8 @@ from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) +GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY") +GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID") SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY") SERPER_DEV_URL = "https://google.serper.dev/search" @@ -96,19 +98,25 @@ async def search_online( yield response_dict return - logger.info(f"🌐 Searching the Internet for {subqueries}") + if GOOGLE_SEARCH_API_KEY and GOOGLE_SEARCH_ENGINE_ID: + search_engine = "Google" + search_func = search_with_google + elif SERPER_DEV_API_KEY: + search_engine = "Serper" + search_func = search_with_serper + elif JINA_API_KEY: + search_engine = "Jina" + search_func = search_with_jina + else: + search_engine = "Searxng" + search_func = search_with_searxng + + logger.info(f"🌐 Searching the Internet with {search_engine} for {subqueries}") if send_status_func: subqueries_str = "\n- " + "\n- ".join(subqueries) async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"): yield {ChatEvent.STATUS: event} - if SERPER_DEV_API_KEY: - search_func = search_with_serper - elif JINA_API_KEY: - search_func = search_with_jina - else: - search_func = search_with_searxng - with timer(f"Internet searches for {subqueries} took", logger): search_tasks = [search_func(subquery, location) for subquery in subqueries] search_results = await asyncio.gather(*search_tasks) @@ -195,6 +203,56 @@ async def search_with_searxng(query: str, location: LocationData) -> Tuple[str, return query, {} +async def search_with_google(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]: + country_code = location.country_code.lower() if location and location.country_code else "us" + base_url = "https://www.googleapis.com/customsearch/v1" + params = { + "key": GOOGLE_SEARCH_API_KEY, + "cx": GOOGLE_SEARCH_ENGINE_ID, + "q": query, + "cr": f"country{country_code.upper()}", # Country restrict parameter + "gl": country_code, # Geolocation parameter + } + + async with aiohttp.ClientSession() as session: + async with session.get(base_url, params=params) as response: + if response.status != 200: + logger.error(await response.text()) + return query, {} + + json_response = await response.json() + + # Transform Google's response format to match Serper's format + organic_results = [] + if "items" in json_response: + organic_results = [ + { + "title": item.get("title", ""), + "link": item.get("link", ""), + "snippet": item.get("snippet", ""), + "content": None, # Google Search API doesn't provide full content + } + for item in json_response["items"] + ] + + # Format knowledge graph if available + knowledge_graph = {} + if "knowledge_graph" in json_response: + kg = json_response["knowledge_graph"] + knowledge_graph = { + "title": kg.get("name", ""), + "description": kg.get("description", ""), + "type": kg.get("type", ""), + } + + extracted_search_result: Dict[str, Any] = {"organic": organic_results} + + if knowledge_graph: + extracted_search_result["knowledgeGraph"] = knowledge_graph + + return query, extracted_search_result + + async def search_with_serper(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]: country_code = location.country_code.lower() if location and location.country_code else "us" payload = json.dumps({"q": query, "gl": country_code})