# Source code for sherpa_ai.output_parsers.entity_validation
from enum import Enum
from typing import Tuple
from langchain_core.language_models import BaseLanguageModel
from sherpa_ai.events import EventType
from sherpa_ai.memory import Belief
from sherpa_ai.output_parsers.base import BaseOutputProcessor
from sherpa_ai.output_parsers.validation_result import ValidationResult
from sherpa_ai.utils import (
extract_entities,
text_similarity,
text_similarity_by_llm,
text_similarity_by_metrics,
)
class TextSimilarityMethod(Enum):
    """Strategies available for comparing entities between two texts.

    The integer values match the keys used by
    ``EntityValidation.similarity_picker`` to map a failure count to a
    strategy.
    """

    # Plain text-similarity comparison.
    BASIC = 0
    # Metric-based text-similarity comparison.
    METRICS = 1
    # Comparison delegated to a language model.
    LLM = 2
# [docs] -- Sphinx viewcode link residue from the rendered HTML page; not part of the module.
class EntityValidation(BaseOutputProcessor):
    """
    Process and validate the presence of entities in the generated text.

    This class inherits from the BaseOutputProcessor and checks that the
    entities found in a generated text also occur in the source text taken
    from the agent's belief. Each failed validation advances ``self.count``,
    which selects a different similarity strategy for the next attempt
    (see ``similarity_picker``).

    Methods:
        - process_output(text, belief, llm=None, **kwargs) -> ValidationResult:
            Process the generated text and validate the presence of entities.
        - similarity_picker(value) -> TextSimilarityMethod:
            Map a failure count to a similarity strategy.
        - get_failure_message() -> str:
            Returns a failure message to be displayed when the validation fails.
        - check_entities_match(result, source, stage, llm):
            Run the selected similarity strategy on the extracted entities.
    """

    def process_output(
        self, text: str, belief: Belief, llm: BaseLanguageModel = None, **kwargs
    ) -> ValidationResult:
        """
        Verifies that entities within `text` exist in the `belief` source text.

        The similarity strategy is chosen from ``self.count`` via
        ``similarity_picker``: 0 selects basic text similarity, 1 selects
        similarity by metrics, and any other value selects similarity by LLM.
        ``self.count`` is incremented every time validation fails.

        Args:
            text: The text to be processed.
            belief: The belief object of the agent that generated the output.
                Its history, excluding feedback, result and action events, is
                used as the source text.
            llm: Language model used for the LLM-based strategy. When it is
                None, the LLM stage falls back to similarity by metrics.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            ValidationResult: The result of the validation. If any entity in the
                text to be processed doesn't exist in the source text,
                validation is invalid and contains a feedback string.
                Otherwise, validation is valid.
        """
        source = belief.get_histories_excluding_types(
            exclude_types=[EventType.feedback, EventType.result, EventType.action],
        )

        # NOTE(review): `self.count` is not defined in this class; presumably it
        # is provided by BaseOutputProcessor -- confirm against the base class.
        entity_exist_in_source, error_message = self.check_entities_match(
            text, source, self.similarity_picker(self.count), llm
        )

        if entity_exist_in_source:
            return ValidationResult(is_valid=True, result=text, feedback="")

        # Move on to the next similarity strategy for the following attempt.
        self.count += 1
        return ValidationResult(is_valid=False, result=text, feedback=error_message)

    def similarity_picker(self, value: int) -> TextSimilarityMethod:
        """
        Picks a text similarity method based on the provided iteration count value.

        Args:
            value (int): The iteration count value used to determine the method.
                - 0: Use BASIC text similarity.
                - 1: Use text similarity by METRICS.
                - Any other value: Use text similarity by LLM.

        Returns:
            TextSimilarityMethod: The selected text similarity method.
        """
        switch_dict = {0: TextSimilarityMethod.BASIC, 1: TextSimilarityMethod.METRICS}
        return switch_dict.get(value, TextSimilarityMethod.LLM)

    def get_failure_message(self) -> str:
        """Return the generic message shown when entity validation fails."""
        # NOTE(review): "enitities" is a typo for "entities"; the text is left
        # unchanged because callers or tests may match on the exact message.
        return "Some enitities from the source might not be mentioned."

    def check_entities_match(
        self,
        result: str,
        source: str,
        stage: TextSimilarityMethod,
        llm: BaseLanguageModel,
    ) -> Tuple[bool, str]:
        """
        Check whether the entities of the generated text match those of the source.

        Args:
            result (str): The generated text whose entities are checked.
            source (str): The source text the entities are checked against.
            stage (TextSimilarityMethod): The similarity strategy to apply. The
                matching integer value (0, 1 or 2) is accepted as well.
            llm (BaseLanguageModel): Language model for the LLM strategy. If it
                is None while the LLM strategy is selected, similarity by
                metrics is used instead.

        Returns:
            The value returned by the selected similarity helper, unchanged.
            ``process_output`` unpacks it as ``(entities_match, error_message)``.

        Raises:
            ValueError: If `stage` is not a valid TextSimilarityMethod value.
        """
        # Normalise so both enum members and their integer values are handled.
        method = TextSimilarityMethod(stage)
        source_entity = extract_entities(source)

        if method is TextSimilarityMethod.LLM and llm is not None:
            # The LLM helper receives the raw result text, so the entities of
            # `result` do not need to be extracted on this path.
            return text_similarity_by_llm(
                llm=llm,
                source_entity=source_entity,
                result=result,
                source=source,
            )

        check_entity = extract_entities(result)
        if method is TextSimilarityMethod.BASIC:
            return text_similarity(
                check_entity=check_entity, source_entity=source_entity
            )

        # METRICS stage, and the fallback for the LLM stage without a model.
        return text_similarity_by_metrics(
            check_entity=check_entity, source_entity=source_entity
        )