# Source code for sherpa_ai.memory.belief

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, List, Optional

import pydash
import transitions as ts
from loguru import logger

from sherpa_ai.actions.base import BaseAction, BaseRetrievalAction
from sherpa_ai.events import Event, EventType

if TYPE_CHECKING:
    from sherpa_ai.memory.state_machine import SherpaStateMachine


[docs] class Belief: """ The belief of the agent. it contains 1. events: the events observed by the agent, synchronized with the shared memory 2. internal_events: the internal events generated by the agent through its reasoning process (actions) """ # noqa E501 def __init__(self): self.events: List[Event] = [] self.internal_events: List[Event] = [] self.current_task: Event = None self.state_machine: SherpaStateMachine = None self.actions = [] self.dict: dict = {} self.max_tokens = 4000
[docs] def update(self, observation: Event): if observation in self.events: return self.events.append(observation)
[docs] def update_internal( self, event_type: EventType, agent: str, content: str, ): event = Event(event_type=event_type, agent=agent, content=content) self.internal_events.append(event)
[docs] def get_by_type(self, event_type): return [ event for event in self.internal_events if event.event_type == event_type ]
[docs] def set_current_task(self, task: Event): self.current_task = task
[docs] def get_context(self, token_counter: Callable[[str], int]): """ Get the context of the agent Args: token_counter: Token counter max_tokens: Maximum number of tokens Returns: str: Context of the agent """ context = "" for event in reversed(self.events): if event.event_type in [ EventType.task, EventType.result, EventType.user_input, ]: message = f"{event.agent}: {event.content}({event.event_type})" context = message + "\n" + context if token_counter(context) > self.max_tokens: break return context
[docs] def get_internal_history(self, token_counter: Callable[[str], int]): """ Get the internal history of the agent Args: token_counter: Token counter Returns: str: Internal history of the agent with event content separated by newlines. History is truncated if the number of tokens exceeds `max_tokens`. """ results = [] current_tokens = 0 for event in reversed(self.internal_events): results.append(event.content) current_tokens += token_counter(event.content) if current_tokens > self.max_tokens: break context = "\n".join(reversed(results)) return context
[docs] def clear_short_term_memory(self): self.dict.clear() self.internal_events.clear()
[docs] def get_histories_excluding_types( self, exclude_types: list[EventType], token_counter: Optional[Callable[[str], int]] = None, max_tokens=4000, ): """ Get the internal history of the agent without events of excluded_type Args: token_counter: Token counter max_tokens: Maximum number of tokens exclude_types: List of events to be excluded Returns: str: Internal history of the agent with event content separated by newlines. History is truncated if the number of tokens exceeds `max_tokens`. """ if token_counter is None: # if no token counter is provided, use the default word counter def token_counter(x): return len(x.split()) results = [] feedback = [] current_tokens = 0 for event in reversed(self.internal_events): if event.event_type not in exclude_types: if event.event_type == EventType.feedback: feedback.append(event.content) else: results.append(event.content) current_tokens += token_counter(event.content) if current_tokens > max_tokens: break context = "\n".join(set(reversed(results))) + "\n".join(set(feedback)) return context
[docs] def set_actions(self, actions: List[BaseAction]): if self.state_machine is not None: logger.warning( "State machine exists, please add actions as transitions directly to the state machine" # noqa E501 ) return self.actions = actions # TODO: This is a quick an dirty way to set the current task # in actions, need to find a better way for action in actions: if isinstance(action, BaseRetrievalAction): action.current_task = self.current_task.content
@property def action_description(self): return "\n".join([str(action) for action in self.get_actions()])
[docs] def get_state(self) -> str: if self.state_machine is None: return None return self.state_machine.get_current_state().name
[docs] def get_state_obj(self) -> ts.State: if self.state_machine is None: return None return self.state_machine.get_current_state()
[docs] def get_actions(self) -> List[BaseAction]: if self.state_machine is None: return self.actions return self.state_machine.get_actions()
[docs] def get_action(self, action_name) -> BaseAction: if self.state_machine is not None: self.actions = self.state_machine.get_actions() result = None for action in self.actions: if action.name == action_name: result = action break return result
[docs] async def async_get_actions(self) -> List[BaseAction]: if self.state_machine is None: return self.actions return await self.state_machine.async_get_actions()
[docs] async def async_get_action(self, action_name) -> BaseAction: if self.state_machine is not None: self.actions = await self.state_machine.async_get_actions() result = None for action in self.actions: if action.name == action_name: result = action break return result
[docs] def get_dict(self): return self.dict
[docs] def get(self, key, default=None): """ Get value from the dict, the key can be a dot separated string if the value is nested """ # noqa E501 return pydash.get(self.dict, key, default)
[docs] def get_all_keys(self): def get_all_keys(d, parent_key=""): keys = [] for k, v in d.items(): full_key = parent_key + "." + k if parent_key else k keys.append(full_key) if isinstance(v, dict): keys.extend(get_all_keys(v, full_key)) return keys return get_all_keys(self.dict)
[docs] def has(self, key): """ Check if the key exists in the dict """ return pydash.has(self.dict, key)
[docs] def set(self, key, value): """ Set value in the dict, the key can be a dot separated string if the value is nested """ # noqa E501 pydash.set_(self.dict, key, value)
@property def __dict__(self): return { "events": [event.__dict__ for event in self.events], "internal_events": [event.__dict__ for event in self.internal_events], "current_task": self.current_task.__dict__ if self.current_task else None, "dict": self.dict, }
[docs] @classmethod def from_dict(cls, data): belief = cls() belief.events = [Event.from_dict(event) for event in data["events"]] belief.internal_events = [ Event.from_dict(event) for event in data["internal_events"] ] belief.current_task = ( Event.from_dict(data["current_task"]) if data["current_task"] else None ) belief.dict = data["dict"] return belief