import time
from loguru import logger
from sqlalchemy import Boolean, Column, Integer, String, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base, sessionmaker
import sherpa_ai.config as cfg
from sherpa_ai.verbose_loggers.base import BaseVerboseLogger
from sherpa_ai.verbose_loggers.verbose_loggers import DummyVerboseLogger
# Shared declarative base: every ORM model in this module subclasses it, and
# Base.metadata.create_all(...) is what later creates their tables.
Base = declarative_base()
[docs]
class UsageTracker(Base):
    """SQLAlchemy model holding LLM token-usage records, one row per event.

    Rows are keyed by ``user_id``; besides plain usage rows, a row can also
    act as a marker (see the two boolean flags below).
    """

    __tablename__ = "usage_tracker"

    # Auto-incrementing surrogate key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Identifier of the user the row belongs to (not declared unique: a user
    # has many rows).
    user_id = Column(String)
    # Number of tokens recorded by this row.
    token = Column(Integer)
    # Integer timestamp; this module writes int(time.time()) into it.
    timestamp = Column(Integer)
    # Despite the name this is a boolean flag, not a time: this module sets it
    # to True on the row that records a usage reset.
    reset_timestamp = Column(Boolean)
    # Also a boolean flag: this module sets it to True on the row written when
    # the user is reminded about their usage.
    reminded_timestamp = Column(Boolean)
[docs]
class Whitelist(Base):
    """SQLAlchemy model listing trusted users whose usage is not tracked."""

    __tablename__ = "whitelist"

    # Auto-incrementing surrogate key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Whitelisted user's identifier; the unique constraint means inserting the
    # same user twice raises an IntegrityError at commit time.
    user_id = Column(String, unique=True)
[docs]
class UserUsageTracker:
    """Enables an app to track LLM token usage on a per-user basis.

    Usage events are stored as ``UsageTracker`` rows through a SQLAlchemy
    session, trusted users are stored as ``Whitelist`` rows, and the local
    database file ``./<db_name>`` can be copied to and from an Amazon S3
    bucket with boto3.

    A user's quota window starts at their most recent "reset" row. Usage is
    summed from that row onwards and compared with ``max_daily_token``; once
    the window is older than ``limit_time_size_in_hours`` the next
    ``check_usage`` call starts a new window.
    """

    def __init__(
        self,
        db_name=cfg.DB_NAME,
        db_url=cfg.DB_URL,
        s3_file_key="token_counter.db",
        bucket_name="sherpa-sqlight",
        # NOTE: this default is evaluated once, at definition time, so every
        # tracker built without an explicit logger shares the same instance.
        verbose_logger: BaseVerboseLogger = DummyVerboseLogger(),
        engine=None,
        session=None,
    ):
        """
        Initialize the UserUsageTracker.

        Args:
            db_name (str): Name of the database file; also used to build the
                local path ``./<db_name>`` that is synced with S3.
            db_url (str): SQLAlchemy URL used to create the engine when
                ``engine`` is not given.
            s3_file_key (str): Key of the database file in the S3 bucket.
            bucket_name (str): Name of the S3 bucket.
            verbose_logger (BaseVerboseLogger): Receives the usage reminder
                message.
            engine: Optional pre-built SQLAlchemy engine.
            session: Optional pre-built SQLAlchemy session.

        Raises:
            ImportError: If boto3 is not installed.
        """
        # Fail fast when boto3 is missing, and (re)bind it on the class.
        self._import_boto3()

        self.db_name = db_name
        self.db_url = db_url
        self.engine = engine or create_engine(self.db_url)
        if session:
            self.session = session
        else:
            Session = sessionmaker(bind=self.engine)
            self.session = Session()
        self.create_table()

        self.max_daily_token = cfg.DAILY_TOKEN_LIMIT
        self.verbose_logger = verbose_logger
        self.is_reminded = False
        # Usage percentage above which the user gets a one-time reminder.
        self.usage_percentage_allowed = 75
        # Length of the quota window; falls back to 24 hours when unset.
        self.limit_time_size_in_hours = float(cfg.LIMIT_TIME_SIZE_IN_HOURS or 24)
        self.bucket_name = bucket_name
        self.s3_file_key = s3_file_key
        self.local_file_path = f"./{self.db_name}"

    @classmethod
    def _import_boto3(cls):
        """Import boto3, bind it to ``UserUsageTracker.boto3`` and return it.

        Raises:
            ImportError: If boto3 is not installed.
        """
        try:
            import boto3
        except ImportError as err:
            raise ImportError(
                "Could not import boto3 python package. "
                "This is needed in order to use the UserUsageTracker. "
                "Please install boto3 with 'pip install boto3'."
            ) from err
        UserUsageTracker.boto3 = boto3
        return boto3

    @classmethod
    def _s3_client(cls):
        """Return a boto3 S3 client, importing boto3 first if necessary.

        An already-bound ``boto3`` class attribute is reused as-is, so the
        module is only imported here when no tracker has been built yet.
        """
        boto3 = getattr(cls, "boto3", None) or cls._import_boto3()
        return boto3.client("s3")

    @classmethod
    def download_from_s3(
        cls,
        db_name=cfg.DB_NAME,
        db_url=cfg.DB_URL,
        s3_file_key="token_counter.db",
        bucket_name="sherpa-sqlight",
        verbose_logger: BaseVerboseLogger = DummyVerboseLogger(),
    ):
        """
        Download the usage database from Amazon S3 and build a tracker.

        The file is saved to ``./<db_name>``. A failed download is logged and
        otherwise ignored, so a tracker is returned either way.

        Args:
            db_name (str): Name of the database file (download destination).
            db_url (str): SQLAlchemy URL passed on to the new tracker.
            s3_file_key (str): Key of the file in the S3 bucket.
            bucket_name (str): Name of the S3 bucket.
            verbose_logger (BaseVerboseLogger): Passed on to the new tracker.

        Returns:
            UserUsageTracker: A new tracker instance.

        Raises:
            ImportError: If boto3 is not installed.
        """
        local_file_path = f"./{db_name}"
        # This is an alternate constructor, so it may run before any instance
        # exists; the client helper imports boto3 on demand in that case.
        s3 = cls._s3_client()
        try:
            s3.download_file(bucket_name, s3_file_key, local_file_path)
        except Exception as e:
            # Best effort: keep going with whatever local database exists.
            logger.error(f"Error download from s3: {str(e)}")
        return cls(
            db_name=db_name,
            db_url=db_url,
            s3_file_key=s3_file_key,
            bucket_name=bucket_name,
            verbose_logger=verbose_logger,
        )

    def upload_to_s3(self):
        """
        Upload the local usage database file to Amazon S3.

        Uses ``self.local_file_path``, ``self.bucket_name`` and
        ``self.s3_file_key``. A failed upload is logged and otherwise ignored.
        """
        s3 = self._s3_client()
        try:
            s3.upload_file(self.local_file_path, self.bucket_name, self.s3_file_key)
        except Exception as e:
            logger.error(f"Error uploading file to S3: {str(e)}")

    def create_table(self):
        """Create the necessary tables in the database."""
        Base.metadata.create_all(self.engine)

    def add_to_whitelist(self, user_id):
        """
        Add a user to the whitelist table.

        A user that is already whitelisted is skipped with a warning.

        Args:
            user_id (str): ID of the user to be added to the whitelist.
        """
        user = Whitelist(user_id=user_id)
        try:
            self.session.add(user)
            self.session.commit()
        except IntegrityError:
            # Raised by the unique constraint on Whitelist.user_id.
            logger.warning(f"Ignoring user ID {user_id}, already whitelisted")
            self.session.rollback()
        if not cfg.FLASK_DEBUG:
            self.upload_to_s3()

    def get_all_whitelisted_ids(self):
        """Get a list of all user IDs in the whitelist."""
        return [user.user_id for user in self.session.query(Whitelist).all()]

    def get_whitelist_by_user_id(self, user_id):
        """
        Get whitelist information for a specific user.

        Args:
            user_id (str): ID of the user.

        Returns:
            list: List of dictionaries (keys ``id`` and ``user_id``)
                containing whitelist information; empty if the user is not
                whitelisted.
        """
        data = self.session.query(Whitelist).filter_by(user_id=user_id).all()
        return [{"id": item.id, "user_id": item.user_id} for item in data]

    def is_in_whitelist(self, user_id):
        """
        Check if a user is in the whitelist.

        Args:
            user_id (str): ID of the user.

        Returns:
            bool: True if the user is in the whitelist, False otherwise.
        """
        return bool(self.get_whitelist_by_user_id(user_id))

    def add_and_check_data(
        self, user_id, token, reset_timestamp=False, reminded_timestamp=False
    ):
        """
        Add usage data for a user, then send a usage reminder if one is due.

        Args:
            user_id (str): ID of the user.
            token (int): Number of tokens used.
            reset_timestamp (bool): Whether this row marks a usage reset.
            reminded_timestamp (bool): Whether this row marks a reminder.
        """
        self.add_data(
            user_id=user_id,
            token=token,
            reset_timestamp=reset_timestamp,
            reminded_timestamp=reminded_timestamp,
        )
        self.remind_user_of_daily_token_limit(user_id=user_id)

    def add_data(self, user_id, token, reset_timestamp=False, reminded_timestamp=False):
        """
        Add usage data for a user.

        The row is stamped with the current time. Afterwards the database is
        uploaded to S3 unless ``cfg.FLASK_DEBUG`` is set.

        Args:
            user_id (str): ID of the user.
            token (int): Number of tokens used.
            reset_timestamp (bool): Whether this row marks a usage reset.
            reminded_timestamp (bool): Whether this row marks a reminder.
        """
        data = UsageTracker(
            user_id=user_id,
            token=token,
            timestamp=int(time.time()),
            reset_timestamp=reset_timestamp,
            reminded_timestamp=reminded_timestamp,
        )
        self.session.add(data)
        self.session.commit()
        # Same debug guard as add_to_whitelist: no S3 traffic in debug mode.
        if not cfg.FLASK_DEBUG:
            self.upload_to_s3()

    def percentage_used(self, user_id):
        """
        Calculate the percentage of the token quota used by a user.

        Args:
            user_id (str): ID of the user.

        Returns:
            float: Percentage of ``max_daily_token`` used since last reset.
        """
        total_token_since_last_reset = self.get_sum_of_tokens_since_last_reset(
            user_id=user_id
        )
        return (total_token_since_last_reset * 100) / self.max_daily_token

    def remind_user_of_daily_token_limit(self, user_id):
        """
        Remind the user when their token usage exceeds a certain percentage.

        Nothing is sent to whitelisted users or to users already reminded
        since their last reset. Sending a reminder also stores a zero-token
        row flagged as a reminder, which is what prevents a repeat.

        Args:
            user_id (str): ID of the user.
        """
        user_is_whitelisted = self.is_in_whitelist(user_id)
        # Refreshed on every call, whitelisted or not.
        self.is_reminded = self.check_if_reminded(user_id=user_id)
        if user_is_whitelisted or self.is_reminded:
            return

        if self.percentage_used(user_id=user_id) > self.usage_percentage_allowed:
            self.add_data(user_id=user_id, token=0, reminded_timestamp=True)
            self.verbose_logger.log(
                f"Hi friend, you have used up {self.usage_percentage_allowed}% of your "
                "daily token limit. once you go over the limit there will be a 24 hour "
                "cool down period after which you can continue using Sherpa! be awesome!"
            )

    @staticmethod
    def _usage_row_to_dict(item):
        """Convert one ``UsageTracker`` row into a plain dictionary."""
        return {
            "id": item.id,
            "user_id": item.user_id,
            "token": item.token,
            "timestamp": item.timestamp,
            "reset_timestamp": item.reset_timestamp,
            "reminded_timestamp": item.reminded_timestamp,
        }

    def get_data_since_last_reset(self, user_id):
        """
        Get usage since the user's usage data was last reset.

        If the user has no reset row, all of their rows are returned;
        otherwise the reset row itself and every later row are returned.

        Args:
            user_id (str): ID of the user.

        Returns:
            list: List of dictionaries containing usage data.
        """
        last_reset_info = self.get_last_reset_info(user_id)
        query = self.session.query(UsageTracker).filter(
            UsageTracker.user_id == user_id
        )
        if last_reset_info is not None and last_reset_info["id"] is not None:
            # ids auto-increment, so "at or after the reset row" is exactly
            # "id >= reset id".
            query = query.filter(UsageTracker.id >= last_reset_info["id"])
        return [self._usage_row_to_dict(item) for item in query.all()]

    def check_if_reminded(self, user_id):
        """
        Check whether the user was already reminded since their last reset.

        Args:
            user_id (str): ID of the user.

        Returns:
            bool: True if any row since the last reset is flagged as a
                reminder.
        """
        data_list = self.get_data_since_last_reset(user_id)
        return any(item.get("reminded_timestamp", False) for item in data_list)

    def get_sum_of_tokens_since_last_reset(self, user_id):
        """
        Calculate the sum of tokens used since the last reset for a user.

        Args:
            user_id (str): ID of the user.

        Returns:
            int: Sum of tokens used since the last reset (0 with no rows).
        """
        data_since_last_reset = self.get_data_since_last_reset(user_id)
        # The token column is nullable; count a NULL as zero tokens.
        return sum(item["token"] or 0 for item in data_since_last_reset)

    def reset_usage(self, user_id, token_amount):
        """
        Start a new usage window for a user.

        Stores a row flagged as a reset that carries ``token_amount``, so the
        new window starts with that many tokens already used.

        Args:
            user_id (str): ID of the user.
            token_amount (int): Number of tokens recorded on the reset row.
        """
        self.add_and_check_data(
            user_id=user_id, token=token_amount, reset_timestamp=True
        )

    def get_last_reset_info(self, user_id):
        """
        Get information about the most recent usage data reset for a user.

        Args:
            user_id (str): ID of the user.

        Returns:
            dict or None: Dictionary with the ``id`` and ``timestamp`` of the
                last reset row, or None if not found.
        """
        data = (
            self.session.query(UsageTracker.id, UsageTracker.timestamp)
            .filter(
                UsageTracker.user_id == user_id,
                UsageTracker.reset_timestamp.is_(True),
            )
            # Timestamps have one-second resolution, so break ties on the
            # auto-incrementing id to always pick the latest reset row.
            .order_by(UsageTracker.timestamp.desc(), UsageTracker.id.desc())
            .first()
        )
        if not data:
            return None
        last_reset_id, last_reset_timestamp = data
        return {"id": last_reset_id, "timestamp": last_reset_timestamp}

    def seconds_to_hms(self, seconds):
        """
        Format the time left in the quota window.

        Despite the name, ``seconds`` is the time already elapsed; the result
        describes ``limit_time_size_in_hours`` minus that, never below zero.

        Args:
            seconds (int): Number of seconds elapsed since the window began.

        Returns:
            str: Formatted string "<h> hours : <m> min : <s> sec".
        """
        window_seconds = float(self.limit_time_size_in_hours) * 3600
        # Clamp: an overdue window would otherwise floor-divide to negatives.
        remaining_seconds = max(0, int(window_seconds - seconds))
        hours, remainder = divmod(remaining_seconds, 3600)
        minutes, secs = divmod(remainder, 60)
        return f"{hours} hours : {minutes} min : {secs} sec"

    def check_usage(self, user_id, token_amount):
        """
        Check user usage and determine whether user is allowed to consume more tokens.

        When the request is allowed for a non-whitelisted user it is also
        recorded; ``token-left`` then reports the balance from before this
        request was recorded.

        Args:
            user_id (str): ID of the user.
            token_amount (int): Number of tokens to check.

        Returns:
            dict: Result containing information about tokens remaining
                (``token-left``), whether more tokens can be consumed
                (``can_execute``), any associated ``message``, and the
                ``time_left``.
        """
        # Whitelisted users are never limited and nothing is recorded.
        if self.is_in_whitelist(user_id):
            return {
                "token-left": self.max_daily_token,
                "can_execute": True,
                "message": "",
                "time_left": "",
            }

        # A single request larger than the whole quota can never succeed.
        if int(token_amount) > self.max_daily_token:
            return {
                "token-left": 0,
                "can_execute": False,
                "message": "your request exceeds token limit. try using smaller context.",
                "time_left": "",
            }

        last_reset_info = self.get_last_reset_info(user_id=user_id)
        time_since_last_reset = None
        if last_reset_info is not None and last_reset_info["timestamp"] is not None:
            time_since_last_reset = int(time.time()) - last_reset_info["timestamp"]

        # No window yet, or the window has expired: open a new one that
        # starts with this request's tokens.
        window_seconds = 3600 * float(self.limit_time_size_in_hours)
        if time_since_last_reset is None or time_since_last_reset > window_seconds:
            self.reset_usage(user_id=user_id, token_amount=token_amount)
            return {
                "token-left": self.max_daily_token,
                "can_execute": True,
                "message": "",
                "time_left": "",
            }

        total_token_since_last_reset = self.get_sum_of_tokens_since_last_reset(
            user_id=user_id
        )
        remaining_tokens = self.max_daily_token - total_token_since_last_reset
        if remaining_tokens <= 0:
            return {
                "token-left": remaining_tokens,
                "can_execute": False,
                "message": "daily usage limit exceeded. you can try after 24 hours",
                "time_left": self.seconds_to_hms(time_since_last_reset),
            }

        self.add_and_check_data(user_id=user_id, token=token_amount)
        return {
            "token-left": remaining_tokens,
            "current_token": token_amount,
            "can_execute": True,
            "message": "",
            "time_left": self.seconds_to_hms(time_since_last_reset),
        }

    def get_all_data(self):
        """Get every usage row as a list of ``UsageTracker`` objects."""
        return list(self.session.query(UsageTracker).all())

    def close_connection(self):
        """Close the database connection."""
        self.session.close()