Source code for sherpa_ai.models.sherpa_base_chat_model

import typing
from typing import Any, List, Optional

from langchain_openai import ChatOpenAI 
from langchain_core.callbacks import ( 
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel 
from langchain_core.messages import BaseMessage 
from langchain_core.outputs import ChatResult 

from sherpa_ai.database.user_usage_tracker import UserUsageTracker
from sherpa_ai.verbose_loggers.base import BaseVerboseLogger


class SherpaBaseChatModel(BaseChatModel):
    """Chat model base class that records per-user token usage.

    After a synchronous generation, the prompt token count and the token
    count of every returned generation are summed and, when ``user_id`` is
    set, stored through ``UserUsageTracker``.

    NOTE(review): ``_generate`` and ``_llm_type`` both defer to the next
    class in the MRO, so this class presumably needs a concrete chat model
    behind it to be usable -- confirm against the intended usage.
    """

    # Identifier of the user whose usage is recorded; a falsy value disables
    # usage tracking entirely.
    user_id: typing.Optional[str] = None
    # Logger handed to ``UserUsageTracker``; ``None`` is passed through as-is.
    verbose_logger: typing.Optional[BaseVerboseLogger] = None

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a chat result via the parent implementation.

        The previous stub was a plain function returning ``None``, which
        cannot be awaited and therefore broke every async entry point.

        Usage is deliberately not recorded here: LangChain's default
        ``BaseChatModel._agenerate`` is expected to run ``self._generate`` in
        an executor, and ``_generate`` already records usage. Recording again
        would count the same call twice.
        TODO(review): confirm for parents that provide a native async
        implementation, which would bypass ``_generate``.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Optional async callback manager for this run.
            **kwargs: Extra arguments forwarded to the parent implementation.

        Returns:
            The ``ChatResult`` produced by the parent implementation.
        """
        return await super()._agenerate(messages, stop, run_manager, **kwargs)

    @property
    def _llm_type(self):
        """Return the parent class's model type identifier unchanged."""
        return super()._llm_type

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result and record the tokens it consumed.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Optional callback manager for this run.
            **kwargs: Extra arguments forwarded to the parent implementation.

        Returns:
            The ``ChatResult`` produced by the parent implementation.
        """
        response = super()._generate(messages, stop, run_manager, **kwargs)
        self._record_token_usage(messages, response)
        return response

    def _record_token_usage(
        self, messages: List[BaseMessage], response: ChatResult
    ) -> None:
        """Store the prompt + completion token total for ``user_id``."""
        # Guard first: without a user there is nothing to store, so skip the
        # token counting too and never let it fail an otherwise good response.
        if not self.user_id:
            return

        prompt_tokens = self.get_num_tokens_from_messages(messages)
        completion_tokens = sum(
            self.get_num_tokens(generation.text)
            for generation in response.generations
        )

        tracker = UserUsageTracker(verbose_logger=self.verbose_logger)
        try:
            tracker.add_data(
                user_id=self.user_id,
                token=prompt_tokens + completion_tokens,
            )
        finally:
            # Release the tracker's connection even when the write fails.
            tracker.close_connection()
class SherpaChatOpenAI(ChatOpenAI):
    """``ChatOpenAI`` variant that records per-user token usage.

    After each generation, the prompt token count and the token count of
    every returned generation are summed and, when ``user_id`` is set,
    stored through ``UserUsageTracker``.
    """

    # Identifier of the user whose usage is recorded; a falsy value disables
    # usage tracking entirely.
    user_id: typing.Optional[str] = None
    # Logger handed to ``UserUsageTracker``; ``None`` is passed through as-is.
    verbose_logger: typing.Optional[BaseVerboseLogger] = None

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Asynchronously generate a chat result and record its token usage.

        The previous stub was a plain function returning ``None``, which
        cannot be awaited and therefore broke every async entry point.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Optional async callback manager for this run.
            **kwargs: Extra arguments forwarded to ``ChatOpenAI``.

        Returns:
            The ``ChatResult`` produced by ``ChatOpenAI``.
        """
        response = await super()._agenerate(messages, stop, run_manager, **kwargs)
        # NOTE: token counting and the tracker write are synchronous, so they
        # briefly occupy the event loop.
        self._record_token_usage(messages, response)
        return response

    @property
    def _llm_type(self):
        """Return the parent class's model type identifier unchanged."""
        return super()._llm_type

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result and record the tokens it consumed.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences.
            run_manager: Optional callback manager for this run.
            **kwargs: Extra arguments forwarded to ``ChatOpenAI``.

        Returns:
            The ``ChatResult`` produced by ``ChatOpenAI``.
        """
        response = super()._generate(messages, stop, run_manager, **kwargs)
        self._record_token_usage(messages, response)
        return response

    def _record_token_usage(
        self, messages: List[BaseMessage], response: ChatResult
    ) -> None:
        """Store the prompt + completion token total for ``user_id``."""
        # Guard first: without a user there is nothing to store, so skip the
        # token counting too and never let it fail an otherwise good response.
        if not self.user_id:
            return

        prompt_tokens = self.get_num_tokens_from_messages(messages)
        completion_tokens = sum(
            self.get_num_tokens(generation.text)
            for generation in response.generations
        )

        tracker = UserUsageTracker(verbose_logger=self.verbose_logger)
        try:
            tracker.add_data(
                user_id=self.user_id,
                token=prompt_tokens + completion_tokens,
            )
        finally:
            # Release the tracker's connection even when the write fails.
            tracker.close_connection()