"""Chat model wrapper that logs LLM inputs and outputs (sherpa_ai.models.chat_model_with_logging)."""
import json
import typing
from typing import Any, Coroutine, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from loguru import logger
# [docs]
class ChatModelWithLogging(BaseChatModel):
    """Wrapper around a ``BaseChatModel`` that logs every generation.

    Each call to ``_generate`` delegates to the wrapped ``llm`` and writes a
    single-line JSON record (input messages, output text, model name) to the
    supplied ``logger``.
    """

    # The wrapped chat model every call is delegated to.
    llm: BaseChatModel
    # A loguru logger used to emit one JSON line per generation.
    logger: type(logger)

    @property
    def _llm_type(self):
        """Delegate the LLM type string to the wrapped model."""
        return self.llm._llm_type

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result and log the exchange as one JSON line.

        Args:
            messages: Conversation messages sent to the wrapped model.
            stop: Optional stop sequences forwarded to the wrapped model.
            run_manager: Optional callback manager for this run.
            **kwargs: Extra arguments forwarded to the wrapped model.

        Returns:
            The ``ChatResult`` produced by the wrapped model.
        """
        # Name of the underlying model: for providers like OpenAI this is the
        # model name (e.g. gpt-3.5-turbo); for other LLMs, the LLM type string.
        llm_name = (
            self.llm.model_name if hasattr(self.llm, "model_name") else self._llm_type
        )
        # Escape newlines so every logged message stays on a single line.
        input_text = [
            {"text": message.content.replace("\n", "\\n"), "agent": message.type}
            for message in messages
        ]
        result = self.llm._generate(messages, stop, run_manager, **kwargs)
        # Only one generation is expected for a single LLM call.
        generation = result.generations[0]
        log = {
            "input": input_text,
            # keep the output on a single log line as well
            "output": generation.message.content.replace("\n", "\\n"),
            "llm_name": llm_name,
        }
        self.logger.info(json.dumps(log))
        return result

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a chat result via the wrapped model.

        NOTE(review): the original was not ``async``, never returned the
        delegated call (async callers always received ``None``), and invoked
        the public batch API ``agenerate`` — which expects a list of message
        lists — instead of the single-call ``_agenerate``. Fixed to await and
        return the delegated result. Unlike ``_generate``, this path does not
        log the exchange (preserved from the original).
        """
        return await self.llm._agenerate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )