Source code for sherpa_ai.models.sherpa_base_model
import typing
from typing import Any, List, Optional
from langchain_openai import ChatOpenAI
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatResult
from pydantic import BaseModel
from sherpa_ai.database.user_usage_tracker import UserUsageTracker


class SherpaOpenAI(ChatOpenAI):
    """ChatOpenAI subclass that records per-user token usage via UserUsageTracker."""

    user_id: typing.Optional[str] = None

    def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ):
        # Async generation is not implemented; token tracking only hooks _generate.
        pass

    @property
    def _llm_type(self):
        return super()._llm_type

    def _generate(
        self,
        prompts: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        response = super()._generate(prompts, stop, run_manager, **kwargs)

        # Record the total tokens consumed by this call against the user, if one is set.
        total_token = response.llm_output["token_usage"]["total_tokens"]

        if self.user_id:
            user_db = UserUsageTracker()
            user_db.add_data(user_id=self.user_id, token=total_token)
            user_db.close_connection()

        return response
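

# Illustrative usage sketch (not part of the original module): setting user_id
# makes each synchronous _generate call log its total token count to
# UserUsageTracker for that user. The model name, user id, and prompt below are
# placeholder assumptions, and an OPENAI_API_KEY is expected in the environment.
if __name__ == "__main__":
    llm = SherpaOpenAI(model="gpt-3.5-turbo", user_id="slack-user-42")
    reply = llm.invoke("Say hello in one short sentence.")
    print(reply.content)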