import nltk
from loguru import logger
from nltk.tokenize import sent_tokenize, word_tokenize
from sherpa_ai.actions.base import ActionResource, BaseRetrievalAction
from sherpa_ai.memory import Belief
from sherpa_ai.output_parsers.base import BaseOutputProcessor
from sherpa_ai.output_parsers.validation_result import ValidationResult
# Download the punkt tokenizer data, which NLTK's sent_tokenize and
# word_tokenize require. The download only happens once; the result is cached.
nltk.download("punkt_tab")
class CitationValidation(BaseOutputProcessor):
"""
A class for adding citations to generated text based on a list of resources.
This class inherits from the abstract class BaseOutputParser and provides
methods to add citations to each sentence in the generated text based on
reference texts and links provided in the 'resources' list.
Attributes:
sequence_threshold (float): Threshold for common longest subsequence / text. Default is 0.7.
jaccard_threshold (float): Jaccard similarity threshold. Default is 0.7.
token_overlap (float): Token overlap threshold. Default is 0.7.
Typical usage example:
```python
citation_parser = CitationValidation(seq_thresh=0.7, jaccard_thresh=0.7, token_overlap=0.7)
result = citation_parser.parse_output(generated_text, list_of_resources)
```
"""
def __init__(
self, sequence_threshold=0.7, jaccard_threshold=0.7, token_overlap=0.7
):
self.sequence_threshold = sequence_threshold
self.jaccard_threshold = jaccard_threshold
self.token_overlap = token_overlap
def calculate_token_overlap(self, sentence1, sentence2) -> tuple:
"""
Calculates the percentage of token overlap between two sentences.
Tokenizes the input sentences and calculates the percentage of token overlap
by finding the intersection of the token sets and dividing it by the length
of each sentence's token set.
Args:
sentence1 (str): The first sentence for token overlap calculation.
sentence2 (str): The second sentence for token overlap calculation.
Returns:
tuple: A tuple containing two float values representing the percentage
of token overlap for sentence1 and sentence2, respectively.
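
        Example (illustrative; the exact values depend on NLTK's tokenizer):

        ```python
        citation_parser = CitationValidation()
        overlap1, overlap2 = citation_parser.calculate_token_overlap(
            "the cat sat on the mat", "the cat sat"
        )
        # overlap1 == 0.5 (3 shared tokens out of 6 in the first sentence);
        # overlap2 == 1.0 (3 shared tokens out of 3 in the second).
        ```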
"""
# Tokenize the sentences
tokens1 = word_tokenize(sentence1)
tokens2 = word_tokenize(sentence2)
# Calculate the set intersection to find the overlapping tokens
overlapping_tokens = set(tokens1) & set(tokens2)
# Calculate the percentage of token overlap
if len(tokens1) == 0:
overlap_percentage = 0
else:
overlap_percentage = len(overlapping_tokens) / (len(tokens1))
if len(tokens2) == 0:
overlap_percentage_2 = 0
else:
overlap_percentage_2 = len(overlapping_tokens) / (len(tokens2))
return overlap_percentage, overlap_percentage_2
    def jaccard_index(self, sentence1, sentence2) -> float:
"""
Calculates the Jaccard index between two sentences.
The Jaccard index is a measure of similarity between two sets, defined as the
size of the intersection divided by the size of the union of the sets.
Args:
sentence1 (str): The first sentence for Jaccard index calculation.
sentence2 (str): The second sentence for Jaccard index calculation.
Returns:
float: The Jaccard index representing the similarity between the two sentences.
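
        Example (illustrative; assumes `citation_parser` is a
        CitationValidation instance):

        ```python
        score = citation_parser.jaccard_index("the cat sat", "the cat ran")
        # score == 0.5: the token sets share {"the", "cat"} out of a
        # union of four tokens {"the", "cat", "sat", "ran"}.
        ```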
"""
# Convert the sentences to sets of words
set1 = set(word_tokenize(sentence1))
set2 = set(word_tokenize(sentence2))
# Calculate the Jaccard index
intersection = len(set1.intersection(set2))
union = len(set1.union(set2))
jaccard_index = intersection / union if union != 0 else 0.0
return jaccard_index
def longest_common_subsequence(self, text1: str, text2: str) -> int:
"""
Calculates the length of the longest common subsequence between two texts.
A subsequence of a string is a new string generated from the original
string with some characters (can be none) deleted without changing
the relative order of the remaining characters.
Args:
- text1 (str): The first text for calculating the longest common subsequence.
- text2 (str): The second text for calculating the longest common subsequence.
Returns:
- int: The length of the longest common subsequence between the two texts.
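
        Example (illustrative; assumes `citation_parser` as above):

        ```python
        length = citation_parser.longest_common_subsequence("abcde", "ace")
        # length == 3, since "a", "c", and "e" appear in order in both strings.
        ```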
"""
        dp = [[0] * (len(text1) + 1) for _ in range(len(text2) + 1)]
        for i in range(1, len(text2) + 1):
            for j in range(1, len(text1) + 1):
                diagonal = dp[i - 1][j - 1]
                if text1[j - 1] == text2[i - 1]:
                    diagonal += 1
                dp[i][j] = max(diagonal, dp[i - 1][j], dp[i][j - 1])
return dp[-1][-1]
def flatten_nested_list(self, nested_list: list[list[str]]) -> list[str]:
"""
Flattens a nested list of strings into a single list of strings.
Args:
nested_list (list[list[str]]): The nested list of strings to be flattened.
Returns:
list[str]: A flat list containing all non-empty strings from the nested list.
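
        Example (illustrative; assumes `citation_parser` as above):

        ```python
        flat = citation_parser.flatten_nested_list([["a", ""], ["b", "c"]])
        # flat == ["a", "b", "c"]; empty strings are dropped.
        ```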
"""
sentences = []
for sublist in nested_list:
for item in sublist:
if len(item) > 0:
sentences.append(item)
return sentences
def split_paragraph_into_sentences(self, paragraph: str) -> list[str]:
"""
Uses NLTK's sent_tokenize to split the given paragraph into a list of sentences.
Args:
paragraph (str): The input paragraph to be tokenized into sentences.
Returns:
list[str]: A list of sentences extracted from the input paragraph.
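
        Example (illustrative; assumes `citation_parser` as above and the
        default punkt model):

        ```python
        sentences = citation_parser.split_paragraph_into_sentences(
            "It rained. We stayed inside."
        )
        # sentences == ["It rained.", "We stayed inside."]
        ```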
"""
sentences = sent_tokenize(paragraph)
return sentences
def resources_from_belief(self, belief: Belief) -> list[ActionResource]:
"""
        Collects the resources of every BaseRetrievalAction in `belief.actions`.
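
        Example (illustrative; assumes `belief` comes from an agent whose
        actions include a retrieval action holding ActionResource items):

        ```python
        resources = citation_parser.resources_from_belief(belief)
        ```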
"""
resources = []
for action in belief.actions:
if isinstance(action, BaseRetrievalAction):
resources.extend(action.resources)
return resources
def process_output(self, text: str, belief: Belief, **kwargs) -> ValidationResult:
"""
Add citations to sentences in the generated text using resources based on fact checking model.
Args:
text (str): The text which needs citations/references added
belief (Belief): Belief of the agent that generated `text`
Returns:
ValidationResult: The result of citation processing.
`is_valid` is True when citation processing succeeds or no citation resources are provided,
False otherwise.
`result` contains the formatted text with citations.
`feedback` providing additional optional information.
Typical usage example:
```python
resources = ActionResource(source="http://example.com/source1", content="Some reference text.")]
citation_parser = CitationValidation()
result = citation_parser.parse_output("Text needing citations.", resources)
```
"""
resources = self.resources_from_belief(belief)
if len(resources) == 0:
# no resources used, return the original text
return ValidationResult(
is_valid=True,
result=text,
feedback="",
)
return self.add_citations(text, resources)
def add_citation_to_sentence(self, sentence: str, resources: list[ActionResource]):
"""
Uses a list of resources to add citations to a sentence
Returns:
citation_ids: a list of citation identifiers
citation_links: a list of citation links (URLs)
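
        Example (a hypothetical resource; whether a sentence is cited
        depends on the similarity thresholds):

        ```python
        resources = [
            ActionResource(
                source="http://example.com/source1",
                content="The Nile is the longest river in Africa.",
            )
        ]
        ids, links = citation_parser.add_citation_to_sentence(
            "The Nile is the longest river in Africa.", resources
        )
        # ids == [1], links == ["http://example.com/source1"]:
        # an exact match clears the subsequence threshold.
        ```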
"""
citation_ids = []
citation_links = []
        # Skip trivially short sentences (five characters or fewer).
        if len(sentence) <= 5:
return citation_ids, citation_links
for index, resource in enumerate(resources):
cited = False
resource_link = resource.source
resource_text = resource.content
resource_sentences = resource_text.split(".")
# TODO: verify that splitting each sentence on newlines improves citation results
nested_sentence_lines = [s.split("\n") for s in resource_sentences]
resource_lines = self.flatten_nested_list(nested_sentence_lines)
for resource_line in resource_lines:
                if not cited and resource_link not in citation_links:
seq = self.longest_common_subsequence(sentence, resource_line)
if (
(seq / len(sentence)) > self.sequence_threshold
or sentence in resource_line
or self.jaccard_index(sentence, resource_line)
> self.jaccard_threshold
):
citation_links.append(resource_link)
citation_ids.append(index + 1)
cited = True
return citation_ids, citation_links
    def add_citations(
        self, text: str, resources: list[ActionResource]
    ) -> ValidationResult:
        """
        Adds citations to every sentence in `text`.

        Splits the text into paragraphs and sentences, matches each sentence
        against the resources, and appends citation markers to the sentences
        that match.

        Args:
            text (str): The text to annotate with citations.
            resources (list[ActionResource]): Resources to cite.

        Returns:
            ValidationResult: `is_valid` is True and `result` contains the
                text with citations added.
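
        Example (illustrative; assumes `citation_parser` and `resources` as
        in the add_citation_to_sentence example):

        ```python
        result = citation_parser.add_citations(
            "The Nile is the longest river in Africa.", resources
        )
        # result.result holds the sentence annotated with a citation for
        # resource 1; the marker format comes from
        # format_sentence_with_citations.
        ```
        """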
        paragraphs = [p for p in text.split("\n") if len(p.strip()) > 0]
        paragraph_sentences = [
            self.split_paragraph_into_sentences(p) for p in paragraphs
        ]
        new_paragraph = []
        for sentences in paragraph_sentences:
            new_sentences = []
            # for each sentence in each paragraph
            for sentence in sentences:
sentence = sentence.strip()
if len(sentence) == 0:
continue
ids, links = self.add_citation_to_sentence(sentence, resources)
formatted_sentence = self.format_sentence_with_citations(
sentence, ids, links
)
new_sentences.append(formatted_sentence)
new_paragraph.append(" ".join(new_sentences) + "\n")
return ValidationResult(
is_valid=True,
result="".join(new_paragraph),
feedback="",
)
def get_failure_message(self) -> str:
return "Unable to add citations to the generated text. Please pay attention to the cited sources."