# Source code for sherpa_ai.policies.chat_sm_policy

from __future__ import annotations

import asyncio
import json
from typing import TYPE_CHECKING, ClassVar, Optional

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.prompts.chat import ChatPromptTemplate
from loguru import logger
from pydantic import BaseModel, ConfigDict

from sherpa_ai.actions.base import BaseAction
from sherpa_ai.policies.base import BasePolicy, PolicyOutput
from sherpa_ai.policies.exceptions import SherpaPolicyException
from sherpa_ai.policies.utils import (
    construct_conversation_from_belief,
    is_selection_trivial,
    transform_json_output,
)
from sherpa_ai.prompts.prompt_template_loader import PromptTemplate

if TYPE_CHECKING:
    from sherpa_ai.memory.belief import Belief


class ChatStateMachinePolicy(BasePolicy):
    """
    The policy to select an action from the belief based on the ReACT framework.

    A chat prompt is assembled from a system message, the conversation
    reconstructed from the belief, and a final human message describing the
    current state and the possible actions. The reply of the chat model is
    parsed with ``transform_json_output`` into an action name and its
    arguments, and the action is then looked up on the belief.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    # Shared loader for the prompt definitions stored in prompts.json.
    _prompt_template: ClassVar[PromptTemplate] = PromptTemplate(
        "./sherpa_ai/prompts/prompts.json"
    )

    # Base templates. Every variable is mapped to a "{name}" placeholder so the
    # resulting strings still contain those placeholders for ChatPromptTemplate
    # to fill in at invocation time.
    SYSTEM_PROMPT: ClassVar[str] = _prompt_template.format_prompt(
        prompt_parent_id="system_prompt",
        prompt_id="SYSTEM_PROMPT",
        version="1.0",
        variables={
            "context": "{context}",
            "task_description": "{task_description}",
            "response_format": "{response_format}",
        },
    )

    ACTION_SELECTION_PROMPT: ClassVar[str] = _prompt_template.format_prompt(
        prompt_parent_id="action_selection_prompt",
        prompt_id="ACTION_SELECTION_PROMPT",
        version="1.0",
        variables={
            "state": "{state}",
            "state_description": "{state_description}",
            "possible_actions": "{possible_actions}",
        },
    )

    # Chat template: system prompt, then the conversation placeholder, then
    # the action-selection request as the final human message.
    chat_template: ChatPromptTemplate = ChatPromptTemplate(
        [
            ("system", str(SYSTEM_PROMPT)),
            ("placeholder", "{conversation}"),
            ("human", str(ACTION_SELECTION_PROMPT)),
        ]
    )

    # Emit the assembled template through the logger (debug level) rather than
    # printing to stdout every time this module is imported.
    logger.debug("Chat template: {}", chat_template)

    # Chat model used to pick the action; must be set before an LLM-backed
    # selection is attempted.
    llm: Optional[BaseChatModel] = None

    # Example of the JSON structure the model is asked to answer with; it is
    # serialized into the prompt's "response_format" variable.
    response_format: dict = {
        "action": {
            "name": "action name you choose",
            "args": {"arg name": "value"},
        },
    }

    # Passed to construct_conversation_from_belief together with the model's
    # token counter -- presumably the token budget for the conversation
    # history (confirm against that helper).
    max_conversation_tokens: int = 8000

    def get_prompt_data(self, belief: Belief, actions: list[BaseAction]) -> dict:
        """Create the prompt based on information from the belief

        Args:
            belief: Belief providing the current task, the current state
                object and the history used to build the conversation.
            actions: Candidate actions; each one is rendered with ``str()``
                and the results are joined with newlines.

        Returns:
            dict with the keys ``task_description``, ``possible_actions``,
            ``state``, ``state_description``, ``response_format`` and
            ``conversation``.

        Raises:
            SherpaPolicyException: If no ``llm`` is configured, since its
                token counter is needed to build the conversation.
        """
        if self.llm is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise SherpaPolicyException(
                "ChatStateMachinePolicy requires an llm to build the prompt"
            )

        task_description = belief.current_task.content
        possible_actions = "\n".join([str(action) for action in actions])

        # Fetch the state object once and read both fields from it.
        state_obj = belief.get_state_obj()
        current_state = state_obj.name
        state_description = state_obj.description

        conversations = construct_conversation_from_belief(
            belief, self.llm.get_num_tokens, self.max_conversation_tokens
        )

        # The state-description prompt is rendered only when the state has a
        # description; otherwise the prompt variable stays empty.
        formatted_state_description = ""
        if state_description:
            variables = {"state": current_state, "state_description": state_description}
            formatted_state_description = self._prompt_template.format_prompt(
                prompt_parent_id="state_description_prompt",
                prompt_id="STATE_DESCRIPTION_PROMPT",
                version="1.0",
                variables=variables,
            )

        return {
            "task_description": task_description,
            "possible_actions": possible_actions,
            "state": current_state,
            "state_description": formatted_state_description,
            "response_format": json.dumps(self.response_format, indent=4),
            "conversation": conversations,
        }

    def select_action(self, belief: Belief) -> Optional[PolicyOutput]:
        """Synchronous wrapper for async_select_action

        Runs the coroutine with ``asyncio.run``, so this method cannot be
        called from code that is already running inside an event loop
        (``asyncio.run`` raises ``RuntimeError`` there); await
        ``async_select_action`` directly in that case.
        """
        return asyncio.run(self.async_select_action(belief))

    async def async_select_action(self, belief: Belief) -> Optional[PolicyOutput]:
        """Select an action based on the current belief state

        Args:
            belief: Belief supplying the available actions and the data used
                to build the prompt.

        Returns:
            PolicyOutput holding the selected action and its arguments. When
            ``is_selection_trivial`` reports the choice as trivial, the first
            action is returned with empty arguments and the LLM is not called.

        Raises:
            SherpaPolicyException: If no ``llm`` is configured when the model
                has to be consulted, or if the action name returned by the
                model is not found on the belief.
        """
        actions = await belief.async_get_actions()

        if is_selection_trivial(actions):
            return PolicyOutput(action=actions[0], args={})

        prompt_data = self.get_prompt_data(belief, actions)
        logger.debug(f"Prompt: {self.chat_template.invoke(prompt_data)}")

        chain = self.chat_template | self.llm
        # Await the async entry point so the LLM round-trip does not block
        # the event loop this coroutine is running on.
        response = await chain.ainvoke(prompt_data)
        result = response.content
        logger.debug(f"Result: {result}")

        name, args = transform_json_output(result)
        action = await belief.async_get_action(name)

        if action is None:
            raise SherpaPolicyException(
                f"Action {name} not found in the list of possible actions"
            )

        return PolicyOutput(action=action, args=args)