Add Ability to Summarize Documents (#800)

* Uses the entire file text and a summarizer model to generate a document summary.
* Uses the contents of the user's query to create a tailored summary.
* Integrates with File Filters #788 for a better UX.
This commit is contained in:
Raghav Tirumale
2024-06-18 10:01:07 -04:00
committed by GitHub
parent 677d49d438
commit d4e5c95711
21 changed files with 791 additions and 85 deletions

View File

@@ -23,13 +23,14 @@ def test_extract_markdown_with_no_headings(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
# Ensure raw entry with no headings does not get heading prefix prepended
assert not entries[0].raw.startswith("#")
assert not entries[1][0].raw.startswith("#")
# Ensure compiled entry has filename prepended as top level heading
assert entries[0].compiled.startswith(expected_heading)
assert entries[1][0].compiled.startswith(expected_heading)
# Ensure compiled entry also includes the file name
assert str(tmp_path) in entries[0].compiled
assert str(tmp_path) in entries[1][0].compiled
def test_extract_single_markdown_entry(tmp_path):
@@ -48,7 +49,8 @@ def test_extract_single_markdown_entry(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 1
assert len(entries) == 2
assert len(entries[1]) == 1
def test_extract_multiple_markdown_entries(tmp_path):
@@ -72,8 +74,9 @@ def test_extract_multiple_markdown_entries(tmp_path):
# Assert
assert len(entries) == 2
assert len(entries[1]) == 2
# Ensure entry compiled strings include the markdown files they originate from
assert all([tmp_path.stem in entry.compiled for entry in entries])
assert all([tmp_path.stem in entry.compiled for entry in entries[1]])
def test_extract_entries_with_different_level_headings(tmp_path):
@@ -94,8 +97,9 @@ def test_extract_entries_with_different_level_headings(tmp_path):
# Assert
assert len(entries) == 2
assert entries[0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1].raw == "# Heading 2\n"
assert len(entries[1]) == 2
assert entries[1][0].raw == "# Heading 1\n## Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1][1].raw == "# Heading 2\n"
def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
@@ -116,10 +120,11 @@ def test_extract_entries_with_non_incremental_heading_levels(tmp_path):
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 3
assert entries[0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
assert entries[2].raw == "# Heading 2\n"
assert len(entries) == 2
assert len(entries[1]) == 3
assert entries[1][0].raw == "# Heading 1\n#### Sub-Heading 1.1", "Ensure entry includes heading ancestory"
assert entries[1][1].raw == "# Heading 1\n## Sub-Heading 1.2", "Ensure entry includes heading ancestory"
assert entries[1][2].raw == "# Heading 2\n"
def test_extract_entries_with_text_before_headings(tmp_path):
@@ -141,10 +146,13 @@ body line 2
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=3)
# Assert
assert len(entries) == 3
assert entries[0].raw == "\nText before headings"
assert entries[1].raw == "# Heading 1\nbody line 1"
assert entries[2].raw == "# Heading 1\n## Heading 2\nbody line 2\n", "Ensure raw entry includes heading ancestory"
assert len(entries) == 2
assert len(entries[1]) == 3
assert entries[1][0].raw == "\nText before headings"
assert entries[1][1].raw == "# Heading 1\nbody line 1"
assert (
entries[1][2].raw == "# Heading 1\n## Heading 2\nbody line 2\n"
), "Ensure raw entry includes heading ancestory"
def test_parse_markdown_file_into_single_entry_if_small(tmp_path):
@@ -165,8 +173,9 @@ body line 1.1
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
assert len(entries) == 1
assert entries[0].raw == entry
assert len(entries) == 2
assert len(entries[1]) == 1
assert entries[1][0].raw == entry
def test_parse_markdown_entry_with_children_as_single_entry_if_small(tmp_path):
@@ -191,13 +200,14 @@ longer body line 2.1
entries = MarkdownToEntries.extract_markdown_entries(markdown_files=data, max_tokens=12)
# Assert
assert len(entries) == 3
assert len(entries) == 2
assert len(entries[1]) == 3
assert (
entries[0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
entries[1][0].raw == "# Heading 1\nbody line 1\n## Subheading 1.1\nbody line 1.1"
), "First entry includes children headings"
assert entries[1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
assert entries[1][1].raw == "# Heading 2\nbody line 2", "Second entry does not include children headings"
assert (
entries[2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
entries[1][2].raw == "# Heading 2\n## Subheading 2.1\nlonger body line 2.1\n"
), "Third entry is second entries child heading"