Return enabled scrapers as WebScraper objects for more ergonomic code

This commit is contained in:
Debanjum Singh Solanky
2024-10-17 17:15:53 -07:00
parent 0db52786ed
commit 2c20f49bc5
3 changed files with 40 additions and 24 deletions

View File

@@ -1045,41 +1045,59 @@ class ConversationAdapters:
return None return None
@staticmethod @staticmethod
async def aget_enabled_webscrapers(): async def aget_enabled_webscrapers() -> list[WebScraper]:
enabled_scrapers = [] enabled_scrapers: list[WebScraper] = []
server_webscraper = await ConversationAdapters.aget_server_webscraper() server_webscraper = await ConversationAdapters.aget_server_webscraper()
if server_webscraper: if server_webscraper:
# Only use the webscraper set in the server chat settings # Only use the webscraper set in the server chat settings
enabled_scrapers = [ enabled_scrapers = [server_webscraper]
(server_webscraper.type, server_webscraper.api_key, server_webscraper.api_url, server_webscraper.name)
]
if not enabled_scrapers: if not enabled_scrapers:
# Use the enabled web scrapers, ordered by priority, until we get web page content # Use the enabled web scrapers, ordered by priority, until we get web page content
enabled_scrapers = [ enabled_scrapers = [scraper async for scraper in WebScraper.objects.all().order_by("priority").aiterator()]
(scraper.type, scraper.api_key, scraper.api_url, scraper.name)
async for scraper in WebScraper.objects.all().order_by("priority").aiterator()
]
if not enabled_scrapers: if not enabled_scrapers:
# Use scrapers enabled via environment variables # Use scrapers enabled via environment variables
if os.getenv("FIRECRAWL_API_KEY"): if os.getenv("FIRECRAWL_API_KEY"):
api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev") api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
enabled_scrapers.append( enabled_scrapers.append(
(WebScraper.WebScraperType.FIRECRAWL, os.getenv("FIRECRAWL_API_KEY"), api_url, "Firecrawl") WebScraper(
type=WebScraper.WebScraperType.FIRECRAWL,
name=WebScraper.WebScraperType.FIRECRAWL.capitalize(),
api_key=os.getenv("FIRECRAWL_API_KEY"),
api_url=api_url,
)
) )
if os.getenv("OLOSTEP_API_KEY"): if os.getenv("OLOSTEP_API_KEY"):
api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI") api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
enabled_scrapers.append( enabled_scrapers.append(
(WebScraper.WebScraperType.OLOSTEP, os.getenv("OLOSTEP_API_KEY"), api_url, "Olostep") WebScraper(
type=WebScraper.WebScraperType.OLOSTEP,
name=WebScraper.WebScraperType.OLOSTEP.capitalize(),
api_key=os.getenv("OLOSTEP_API_KEY"),
api_url=api_url,
)
) )
# Jina is the default fallback scraper to use as it does not require an API key # Jina is the default fallback scraper to use as it does not require an API key
api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/") api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
enabled_scrapers.append((WebScraper.WebScraperType.JINA, os.getenv("JINA_API_KEY"), api_url, "Jina")) enabled_scrapers.append(
WebScraper(
type=WebScraper.WebScraperType.JINA,
name=WebScraper.WebScraperType.JINA.capitalize(),
api_key=os.getenv("JINA_API_KEY"),
api_url=api_url,
)
)
# Only enable the direct web page scraper by default in self-hosted single user setups. # Only enable the direct web page scraper by default in self-hosted single user setups.
# Useful for reading webpages on your intranet. # Useful for reading webpages on your intranet.
if state.anonymous_mode or in_debug_mode(): if state.anonymous_mode or in_debug_mode():
enabled_scrapers.append((WebScraper.WebScraperType.DIRECT, None, None, "Direct")) enabled_scrapers.append(
WebScraper(
type=WebScraper.WebScraperType.DIRECT,
name=WebScraper.WebScraperType.DIRECT.capitalize(),
api_key=None,
api_url=None,
)
)
return enabled_scrapers return enabled_scrapers

View File

@@ -198,16 +198,18 @@ async def read_webpage_and_extract_content(
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers() web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
# Only use the direct web scraper for internal URLs # Only use the direct web scraper for internal URLs
if is_internal_url(url): if is_internal_url(url):
web_scrapers = [scraper for scraper in web_scrapers if scraper[0] == WebScraper.WebScraperType.DIRECT] web_scrapers = [scraper for scraper in web_scrapers if scraper.type == WebScraper.WebScraperType.DIRECT]
# Fallback through enabled web scrapers until we successfully read the web page # Fallback through enabled web scrapers until we successfully read the web page
extracted_info = None extracted_info = None
for scraper_type, api_key, api_url, api_name in web_scrapers: for scraper in web_scrapers:
try: try:
# Read the web page # Read the web page
if is_none_or_empty(content): if is_none_or_empty(content):
with timer(f"Reading web page with {scraper_type} at '{url}' took", logger, log_level=logging.INFO): with timer(f"Reading web page with {scraper.type} at '{url}' took", logger, log_level=logging.INFO):
content, extracted_info = await read_webpage(url, scraper_type, api_key, api_url, subqueries, agent) content, extracted_info = await read_webpage(
url, scraper.type, scraper.api_key, scraper.api_url, subqueries, agent
)
# Extract relevant information from the web page # Extract relevant information from the web page
if is_none_or_empty(extracted_info): if is_none_or_empty(extracted_info):
@@ -218,9 +220,9 @@ async def read_webpage_and_extract_content(
if not is_none_or_empty(extracted_info): if not is_none_or_empty(extracted_info):
break break
except Exception as e: except Exception as e:
logger.warning(f"Failed to read web page with {scraper_type} at '{url}' with {e}") logger.warning(f"Failed to read web page with {scraper.type} at '{url}' with {e}")
# If this is the last web scraper in the list, log an error # If this is the last web scraper in the list, log an error
if api_name == web_scrapers[-1][-1]: if scraper.name == web_scrapers[-1].name:
logger.error(f"All web scrapers failed for '{url}'") logger.error(f"All web scrapers failed for '{url}'")
return subqueries, url, extracted_info return subqueries, url, extracted_info

View File

@@ -468,10 +468,6 @@ def is_internal_url(url: str) -> bool:
if any(hostname.endswith(tld) for tld in internal_tlds): if any(hostname.endswith(tld) for tld in internal_tlds):
return True return True
# Check for non-standard ports
# if parsed_url.port and parsed_url.port not in [80, 443]:
# return True
# Check for URLs without a TLD # Check for URLs without a TLD
if "." not in hostname: if "." not in hostname:
return True return True