from __future__ import annotations
import traceback
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Union
from loguru import logger
from pydantic import BaseModel, ConfigDict
from sherpa_ai.actions.base import BaseAction
from sherpa_ai.actions.exceptions import (SherpaActionExecutionException,
SherpaMissingInformationException)
from sherpa_ai.config.task_result import TaskResult
from sherpa_ai.events import EventType
from sherpa_ai.memory import Belief, SharedMemory
from sherpa_ai.output_parsers.base import BaseOutputProcessor
from sherpa_ai.policies.base import BasePolicy, PolicyOutput
from sherpa_ai.policies.exceptions import SherpaPolicyException
# [docs]  (Sphinx "view source" link residue from HTML extraction; not code)
class BaseAgent(ABC, BaseModel):
    """
    Abstract base class for agents that select and execute actions.

    A run prepares the belief with the available actions, asks the policy to
    pick an action up to ``num_runs`` times, executes each selected action, and
    finally post-processes the output (validation and/or synthesis) before
    wrapping it in a ``TaskResult``.

    Subclasses must implement ``create_actions`` and ``synthesize_output``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    # Name of the agent; used in log messages and as the source of result events.
    name: str
    # Description of the agent (not read by this class itself).
    description: str
    # Optional shared memory; when set it observes the belief before a run and
    # receives the final result afterwards.
    shared_memory: Optional[SharedMemory] = None
    # Belief holding the actions, internal events and (optionally) a state
    # machine. Most methods dereference it, so it must be set before running.
    belief: Optional[Belief] = None
    # Policy used to choose the next action from the belief.
    policy: Optional[BasePolicy] = None
    # Maximum number of select-and-execute iterations per run.
    num_runs: int = 1
    # Predefined actions; when empty, ``create_actions()`` is used instead.
    actions: List[BaseAction] = []
    # Per-validator limit, compared against each validator's ``count``.
    validation_steps: int = 1
    # Output processors applied by ``validate_output``.
    validations: List[BaseOutputProcessor] = []
    # Name under which error and feedback events are recorded in the belief.
    feedback_agent_name: str = "critic"
    # Upper bound on the number of output regenerations in ``validate_output``.
    global_regen_max: int = 12
    # When True (and no validations are configured), ``agent_finished`` replaces
    # the result with ``synthesize_output()``.
    do_synthesize_output: bool = False
    # Passed through to each validator's ``process_output``.
    llm: Any = None

    @abstractmethod
    def create_actions(self) -> List[BaseAction]:
        """
        Create the actions available to the agent.

        Called during preparation when the belief has no actions and the
        ``actions`` field is empty.

        Returns:
            List[BaseAction]: The actions to register in the belief.
        """

    @abstractmethod
    def synthesize_output(self) -> str:
        """
        Produce the output text of the agent.

        Used by ``agent_finished`` and to regenerate the output after a failed
        validation.

        Returns:
            str: The synthesized output.
        """

    def send_event(self, event: str, args: dict):
        """
        Send an event to the state machine in the belief

        Logs an error and returns if the belief has no state machine.

        Args:
            event (str): The event name
            args (dict): The arguments for the event

        Raises:
            AttributeError: If the state machine has no attribute named ``event``.
        """
        if self.belief.state_machine is None:
            logger.error("State machine is not defined in the belief")
            return

        getattr(self.belief.state_machine, event)(**args)

    async def async_send_event(self, event: str, args: dict):
        """
        Send an event to the state machine in the belief

        Logs an error and returns if the belief has no state machine or the
        state machine does not define the event.

        Args:
            event (str): The event name
            args (dict): The arguments for the event
        """
        if self.belief.state_machine is None:
            logger.error("State machine is not defined in the belief")
            return

        # A default is required here: without it getattr raises AttributeError
        # and the "not defined" branch below could never be reached.
        func = getattr(self.belief.state_machine, event, None)
        if func is None:
            logger.error(f"Event {event} is not defined in the state machine")
            return

        await func(**args)

    def agent_preparation(self):
        """
        Prepare the belief before running.

        Lets the shared memory (if any) observe the belief, then registers the
        agent's actions in the belief when it does not have any yet.
        """
        logger.debug(f"```⏳{self.name} is thinking...```")

        if self.shared_memory is not None:
            self.shared_memory.observe(self.belief)

        if len(self.belief.get_actions()) == 0:
            actions = self.actions if len(self.actions) > 0 else self.create_actions()
            self.belief.set_actions(actions)

    def select_action(self) -> Union[Optional[PolicyOutput], Exception]:
        """
        Ask the policy to select the next action.

        Returns:
            The policy output on success; ``None`` when the policy raised a
            ``SherpaPolicyException`` (the error is recorded in the belief);
            or the exception object itself for any other exception.
        """
        try:
            return self.policy.select_action(self.belief)
        except SherpaPolicyException as e:
            self.belief.update_internal(
                EventType.action_output,
                self.feedback_agent_name,
                f"Error in selecting action: {e}",
            )
            logger.exception(e)
            return None
        except Exception as e:
            # Returned rather than raised so the caller can turn it into a
            # failed TaskResult.
            logger.exception(e)
            return e

    async def async_select_action(self) -> Union[Optional[PolicyOutput], Exception]:
        """
        Asynchronously ask the policy to select the next action.

        Returns:
            The policy output on success; ``None`` when the policy raised a
            ``SherpaPolicyException`` (the error is recorded in the belief);
            or the exception object itself for any other exception.
        """
        try:
            return await self.policy.async_select_action(self.belief)
        except SherpaPolicyException as e:
            self.belief.update_internal(
                EventType.action_output,
                self.feedback_agent_name,
                f"Error in selecting action: {e}",
            )
            logger.exception(e)
            return None
        except Exception as e:
            # Returned rather than raised so the caller can turn it into a
            # failed TaskResult.
            logger.exception(e)
            return e

    def agent_finished(self, result: str) -> str:
        """
        Post-process the output of a run and publish it.

        If validations are configured the result is replaced by the validated
        output; otherwise, if ``do_synthesize_output`` is set, it is replaced by
        the synthesized output. The final result is added to the shared memory
        when one is set.

        Args:
            result (str): The output of the last executed action.

        Returns:
            str: The final result.
        """
        if len(self.validations) > 0:
            result = self.validate_output()
        elif self.do_synthesize_output:
            result = self.synthesize_output()

        logger.debug(f"```🤖{self.name} wrote: {result}```")

        if self.shared_memory is not None:
            self.shared_memory.add(EventType.result, self.name, result)
        return result

    def _failure_result(self, error: Exception) -> TaskResult:
        """Build a failed TaskResult whose content is the formatted traceback."""
        tb_exception = traceback.TracebackException.from_exception(error)
        stack_trace = "".join(tb_exception.format())
        return TaskResult(content=stack_trace, status="failed")

    def run(self) -> TaskResult:
        """
        Run the agent: select and execute actions, then finalize the output.

        Returns:
            TaskResult: ``status="failed"`` with a stack trace when selecting or
            executing an action raised an unexpected exception;
            ``status="waiting"`` with the question when an action reported
            missing information; otherwise ``status="success"`` with the final
            output.
        """
        self.agent_preparation()

        # Output of the last action that completed successfully. Initialised so
        # the finalisation below also works when no action was executed (no
        # actions available, or every iteration was skipped).
        final_output = ""
        for _ in range(self.num_runs):
            if len(self.belief.get_actions()) == 0:
                break

            result = self.select_action()
            if result is None:
                # this means no action is selected
                continue
            if isinstance(result, Exception):
                return self._failure_result(result)

            logger.debug(f"Action selected: {result}")
            logger.debug(
                f"🤖{self.name} is executing`````` {result.action.name}...```"
            )

            action_output = self.act(result.action, result.args)
            if action_output is None:
                continue
            if isinstance(action_output, SherpaMissingInformationException):
                return TaskResult(content=action_output.message, status="waiting")
            if isinstance(action_output, Exception):
                return self._failure_result(action_output)

            # Prefer the value stored in the belief under the action's name,
            # falling back to the direct output of the action.
            final_output = self.belief.get(result.action.name, action_output)
            logger.debug(f"```Action output: {final_output}```")

        final_output = self.agent_finished(final_output)
        return TaskResult(content=final_output, status="success")

    async def async_run(self) -> TaskResult:
        """
        Asynchronous counterpart of ``run``.

        Returns:
            TaskResult: ``status="failed"`` with a stack trace when selecting or
            executing an action raised an unexpected exception;
            ``status="waiting"`` with the question when an action reported
            missing information; otherwise ``status="success"`` with the final
            output.
        """
        logger.debug(f"```⏳{self.name} is thinking...```")

        if self.shared_memory is not None:
            self.shared_memory.observe(self.belief)

        actions = await self.belief.async_get_actions()
        if len(actions) == 0:
            actions = self.actions if len(self.actions) > 0 else self.create_actions()
            self.belief.set_actions(actions)

        # Output of the last action that completed successfully. Initialised so
        # the finalisation below also works when no action was executed.
        final_output = ""
        for _ in range(self.num_runs):
            actions = await self.belief.async_get_actions()
            if len(actions) == 0:
                break

            result = await self.async_select_action()
            if result is None:
                # this means no action is selected
                continue
            if isinstance(result, Exception):
                return self._failure_result(result)

            logger.debug(f"Action selected: {result}")
            logger.debug(
                f"🤖{self.name} is executing`````` {result.action.name}...```"
            )

            action_output = await self.async_act(result.action, result.args)
            if action_output is None:
                continue
            if isinstance(action_output, SherpaMissingInformationException):
                return TaskResult(content=action_output.message, status="waiting")
            if isinstance(action_output, Exception):
                return self._failure_result(action_output)

            # Prefer the value stored in the belief under the action's name,
            # falling back to the direct output of the action.
            final_output = self.belief.get(result.action.name, action_output)
            logger.debug(f"```Action output: {final_output}```")

        final_output = self.agent_finished(final_output)
        return TaskResult(content=final_output, status="success")

    def validation_iterator(
        self,
        validations: List[BaseOutputProcessor],
        global_regen_count: int,
        all_pass: bool,
        validation_is_scaped: bool,
        result: str,
    ) -> tuple[int, bool, bool, str]:
        """
        Run one pass over the validations, regenerating the output on failure.

        Each validation that has not yet reached the ``validation_steps`` limit
        is applied to the current result. On the first failure the feedback is
        recorded in the belief, the output is regenerated with
        ``synthesize_output`` and the pass stops. Validations that already
        reached the limit are skipped ("escaped").

        Args:
            validations: The validations to run, in order.
            global_regen_count (int): Number of regenerations done so far.
            all_pass (bool): Whether all validations have passed so far.
            validation_is_scaped (bool): Whether any validation was skipped
                because it reached the step limit.
            result (str): The current output text.

        Returns:
            tuple: The updated ``(global_regen_count, all_pass,
            validation_is_scaped, result)``.
        """
        if not validations:
            # Nothing to validate: vacuously passed. Without this the caller's
            # loop would never see all_pass become True.
            return global_regen_count, True, validation_is_scaped, result

        last_index = len(validations) - 1
        for index, validation in enumerate(validations):
            logger.info(f"validation_running: {validation.__class__.__name__}")
            logger.info(f"validation_count: {validation.count}")
            is_last = index == last_index

            # The validator already reached the validation steps limit: skip it
            # and remember that a validation was escaped. If it is the last
            # one, the pass is finished.
            if validation.count >= self.validation_steps:
                validation_is_scaped = True
                if is_last:
                    all_pass = True
                continue

            self.belief.update_internal(EventType.result, self.name, result)
            validation_result = validation.process_output(
                text=result, belief=self.belief, llm=self.llm
            )
            logger.info(f"validation_result: {validation_result}")

            if not validation_result.is_valid:
                # Record the feedback, regenerate, and restart from the caller.
                self.belief.update_internal(
                    EventType.feedback,
                    self.feedback_agent_name,
                    validation_result.feedback,
                )
                result = self.synthesize_output()
                global_regen_count += 1
                break

            result = validation_result.result
            if is_last:
                # Reached the end without a failure in this pass.
                all_pass = True

        return global_regen_count, all_pass, validation_is_scaped, result

    def validate_output(self):
        """
        Validate the synthesized output through a series of validation steps.

        The output is first produced with ``synthesize_output``. Validation
        passes are then repeated (see ``validation_iterator``) until a pass
        completes without failure or ``global_regen_max`` regenerations have
        been performed. Afterwards the validations are applied once more
        without regeneration, and the failure messages of those that still
        fail are appended to the result.

        Returns:
            str: The synthesized output after validation.
        """
        all_pass = False
        validation_is_scaped = False
        iteration_count = 0
        result = self.synthesize_output()
        global_regen_count = 0

        # reset the state of all the validation before starting the validation
        # process.
        for validation in self.validations:
            validation.reset_state()

        validations = self.validations

        # this loop will run until max regeneration reached or all validations
        # have passed
        while self.global_regen_max > global_regen_count and not all_pass:
            logger.info(f"validations_size: {len(validations)}")
            iteration_count += 1
            logger.info(f"main_iteration: {iteration_count}")
            logger.info(f"regen_count: {global_regen_count}")

            (
                global_regen_count,
                all_pass,
                validation_is_scaped,
                result,
            ) = self.validation_iterator(
                all_pass=all_pass,
                global_regen_count=global_regen_count,
                validation_is_scaped=validation_is_scaped,
                validations=validations,
                result=result,
            )

        # if all didn't pass or validation reached max regeneration run the
        # validation one more time but no regeneration.
        # NOTE(review): the loop above never lets global_regen_count exceed
        # global_regen_max, so this condition is always true and the else
        # branch is unreachable. Possibly `<=` was intended — confirm before
        # changing, since that would alter when the final pass runs.
        if validation_is_scaped or self.global_regen_max >= global_regen_count:
            failed_validations = []

            for validation in validations:
                validation_result = validation.process_output(
                    text=result, belief=self.belief, llm=self.llm
                )
                if not validation_result.is_valid:
                    failed_validations.append(validation)
                else:
                    result = validation_result.result

            result += "\n".join(
                failed_validation.get_failure_message()
                for failed_validation in failed_validations
            )
        else:
            # check if validation is not passed after all the attempts if so
            # return the error message.
            result += "\n".join(
                (
                    inst_val.get_failure_message()
                    if inst_val.count == self.validation_steps
                    else ""
                )
                for inst_val in validations
            )

        self.belief.update_internal(EventType.result, self.name, result)
        return result

    def observe(self):
        """
        Let the shared memory observe the agent's belief.

        Returns:
            Whatever ``shared_memory.observe`` returns, or ``None`` when the
            agent has no shared memory.
        """
        if self.shared_memory is None:
            return None
        return self.shared_memory.observe(self.belief)

    def act(self, action: BaseAction, inputs: dict) -> Union[Optional[str], Exception]:
        """
        Execute an action with the given inputs.

        Args:
            action (BaseAction): The action to execute.
            inputs (dict): Keyword arguments passed to the action.

        Returns:
            The action output on success; ``None`` when the action raised a
            ``SherpaActionExecutionException`` (the error is recorded in the
            belief); or the exception object itself for any other exception.
        """
        try:
            return action(**inputs)
        except SherpaActionExecutionException as e:
            self.belief.update_internal(
                EventType.action_output,
                self.feedback_agent_name,
                f"Error in executing action: {action.name}. Error: {e}",
            )
            logger.exception(e)
            return None
        except Exception as e:
            # Returned rather than raised so the caller can inspect its type
            # (e.g. missing information vs. unexpected failure).
            logger.exception(e)
            return e

    async def async_act(
        self, action: BaseAction, inputs: dict
    ) -> Union[Optional[str], Exception]:
        """
        Asynchronously execute an action with the given inputs.

        Args:
            action (BaseAction): The action to execute; calling it must return
                an awaitable.
            inputs (dict): Keyword arguments passed to the action.

        Returns:
            The action output on success; ``None`` when the action raised a
            ``SherpaActionExecutionException`` (the error is recorded in the
            belief); or the exception object itself for any other exception.
        """
        try:
            return await action(**inputs)
        except SherpaActionExecutionException as e:
            self.belief.update_internal(
                EventType.action_output,
                self.feedback_agent_name,
                f"Error in executing action: {action.name}. Error: {e}",
            )
            logger.exception(e)
            return None
        except Exception as e:
            # Returned rather than raised so the caller can inspect its type
            # (e.g. missing information vs. unexpected failure).
            logger.exception(e)
            return e