Use async/await in tests for query method of text and image search

The text, image search query method has become async. So async/await
is required to get results correctly in tests etc
This commit is contained in:
Debanjum Singh Solanky
2023-06-28 20:11:26 -07:00
parent f516d127c8
commit 56ce97ef9e
3 changed files with 16 additions and 9 deletions

View File

@@ -3,6 +3,9 @@ import logging
from pathlib import Path
from PIL import Image
# External Packages
import pytest
# Internal Packages
from khoj.utils.state import model
from khoj.utils.constants import web_directory
@@ -48,7 +51,8 @@ def test_image_metadata(content_config: ContentConfig):
# ----------------------------------------------------------------------------------------------------
def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
@pytest.mark.anyio
async def test_image_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
@@ -60,7 +64,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# Act
for query, expected_image_name in query_expected_image_pairs:
hits = image_search.query(query, count=1, model=model.image_search)
hits = await image_search.query(query, count=1, model=model.image_search)
results = image_search.collate_results(
hits,
@@ -83,7 +87,8 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
@pytest.mark.anyio
async def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
max_words_supported = 10
@@ -93,7 +98,7 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
# Act
try:
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
image_search.query(query, count=1, model=model.image_search)
await image_search.query(query, count=1, model=model.image_search)
# Assert
except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
@@ -102,7 +107,8 @@ def test_image_search_query_truncated(content_config: ContentConfig, search_conf
# ----------------------------------------------------------------------------------------------------
def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
@pytest.mark.anyio
async def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange
output_directory = resolve_absolute_path(web_directory)
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
@@ -113,7 +119,7 @@ def test_image_search_by_filepath(content_config: ContentConfig, search_config:
# Act
with caplog.at_level(logging.INFO, logger="khoj.search_type.image_search"):
hits = image_search.query(query, count=1, model=model.image_search)
hits = await image_search.query(query, count=1, model=model.image_search)
results = image_search.collate_results(
hits,