/
OS-Worldb968155
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Callable, Literal, Optional, Union
from uuid import UUID
from pydantic import BaseModel
from ..events import deprecated_by
from ..events.client_events import StreamEvent, UsageSummaryEvent
from .base_message import BaseMessage, wrap_message
# Names exported by `from <this module> import *`. Only UsageSummaryMessage is
# listed; any other name defined in this module has to be imported explicitly.
__all__ = ["UsageSummaryMessage"]
class ModelUsageSummary(BaseModel):
    """Model usage summary."""

    # One record per model. None of the fields declares a default, so all five
    # must be supplied when the record is built.

    # Filled from the dictionary key by `_change_usage_summary_format`.
    model: str
    """Model name."""
    completion_tokens: int
    """Number of tokens used for completion."""
    cost: float
    """Cost of the completion."""
    prompt_tokens: int
    """Number of tokens used for prompt."""
    total_tokens: int
    """Total number of tokens used."""
class ActualUsageSummary(BaseModel):
    """Actual usage summary."""

    # Reported by `UsageSummaryMessage._print_usage` as usage "excluding cached usage".
    # Both fields default to None, which `_change_usage_summary_format` emits
    # when no actual usage dictionary was provided.
    usages: Optional[list[ModelUsageSummary]] = None
    """List of model usage summaries."""
    total_cost: Optional[float] = None
    """Total cost."""
class TotalUsageSummary(BaseModel):
    """Total usage summary."""

    # Reported by `UsageSummaryMessage._print_usage` as usage "including cached usage".
    # Both fields default to None, which `_change_usage_summary_format` emits
    # when no total usage dictionary was provided.
    usages: Optional[list[ModelUsageSummary]] = None
    """List of model usage summaries."""
    total_cost: Optional[float] = None
    """Total cost."""
# Display modes understood by `UsageSummaryMessage.print`: "actual" and "total"
# print the corresponding summary only, "both" prints the actual summary
# followed by the total one (or a note when the two are identical).
Mode = Literal["both", "total", "actual"]
def _change_usage_summary_format(
actual_usage_summary: Optional[dict[str, Any]] = None, total_usage_summary: Optional[dict[str, Any]] = None
) -> dict[str, dict[str, Any]]:
summary: dict[str, Any] = {}
for usage_type, usage_summary in {"actual": actual_usage_summary, "total": total_usage_summary}.items():
if usage_summary is None:
summary[usage_type] = {"usages": None, "total_cost": None}
continue
usage_summary_altered_format: dict[str, list[dict[str, Any]]] = {"usages": []}
for k, v in usage_summary.items():
if isinstance(k, str) and isinstance(v, dict):
current_usage = {key: value for key, value in v.items()}
current_usage["model"] = k
usage_summary_altered_format["usages"].append(current_usage)
else:
usage_summary_altered_format[k] = v
summary[usage_type] = usage_summary_altered_format
return summary
@deprecated_by(UsageSummaryEvent)
@wrap_message
class UsageSummaryMessage(BaseMessage):
    """Usage summary message."""

    actual: ActualUsageSummary
    """Actual usage summary."""
    total: TotalUsageSummary
    """Total usage summary."""
    mode: Mode
    """Mode to display the usage summary."""

    def __init__(
        self,
        *,
        uuid: Optional[UUID] = None,
        actual_usage_summary: Optional[dict[str, Any]] = None,
        total_usage_summary: Optional[dict[str, Any]] = None,
        mode: Mode = "both",
    ):
        """Build the message from raw usage dictionaries.

        Args:
            uuid: Optional identifier forwarded to the base message.
            actual_usage_summary: Raw "actual" usage dictionary, or None.
            total_usage_summary: Raw "total" usage dictionary, or None.
            mode: Which summaries `print` displays: "both", "total" or "actual".
        """
        # The raw dictionaries are keyed by model name; reshape them into the
        # {"usages": [...], "total_cost": ...} layout of the summary models.
        summary_dict = _change_usage_summary_format(actual_usage_summary, total_usage_summary)
        super().__init__(uuid=uuid, **summary_dict, mode=mode)

    def _print_usage(
        self,
        usage_summary: Union[ActualUsageSummary, TotalUsageSummary],
        usage_type: str = "total",
        f: Optional[Callable[..., Any]] = None,
    ) -> None:
        """Print a single usage summary.

        Args:
            usage_summary: The summary to report.
            usage_type: "total" labels the report as including cached usage;
                any other value labels it as excluding cached usage.
            f: Print-like callable that accepts a ``flush`` keyword argument.
                Defaults to the built-in ``print``.
        """
        f = f or print
        word_from_type = "including" if usage_type == "total" else "excluding"
        # Covers both "no usages recorded" (None) and an empty list.
        if not usage_summary.usages:
            f("No actual cost incurred (all completions are using cache).", flush=True)
            return

        f(f"Usage summary {word_from_type} cached usage: ", flush=True)
        # NOTE: total_cost is Optional; it is expected to be set whenever usages
        # are present, otherwise round() raises TypeError here.
        f(f"Total cost: {round(usage_summary.total_cost, 5)}", flush=True)  # type: ignore [arg-type]
        for usage in usage_summary.usages:
            f(
                f"* Model '{usage.model}': cost: {round(usage.cost, 5)}, prompt_tokens: {usage.prompt_tokens}, completion_tokens: {usage.completion_tokens}, total_tokens: {usage.total_tokens}",
                flush=True,
            )

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Print the usage summary selected by ``self.mode``.

        Args:
            f: Print-like callable that accepts a ``flush`` keyword argument
                and can be called without arguments. Defaults to the built-in
                ``print``.

        Raises:
            ValueError: If ``self.mode`` is not "both", "total" or "actual".
        """
        f = f or print
        if self.total.usages is None:
            f('No usage summary. Please call "create" first.', flush=True)
            return

        f("-" * 100, flush=True)
        if self.mode == "both":
            self._print_usage(self.actual, "actual", f)
            f()
            # Only repeat the report when caching made the two summaries differ.
            if self.total.model_dump_json() != self.actual.model_dump_json():
                self._print_usage(self.total, "total", f)
            else:
                f(
                    "All completions are non-cached: the total cost with cached completions is the same as actual cost.",
                    flush=True,
                )
        elif self.mode == "total":
            self._print_usage(self.total, "total", f)
        elif self.mode == "actual":
            self._print_usage(self.actual, "actual", f)
        else:
            # List the literals that Mode actually accepts.
            raise ValueError(f'Invalid mode: {self.mode}, choose from "both", "total", "actual"')
        f("-" * 100, flush=True)
@deprecated_by(StreamEvent)
@wrap_message
class StreamMessage(BaseMessage):
    """Stream message."""

    content: str
    """Content of the message."""

    def __init__(self, *, uuid: Optional[UUID] = None, content: str) -> None:
        """Create a stream message.

        Args:
            uuid: Optional identifier forwarded to the base message.
            content: Text carried by the message.
        """
        super().__init__(uuid=uuid, content=content)

    def print(self, f: Optional[Callable[..., Any]] = None) -> None:
        """Write the content in green, then reset the terminal color.

        Args:
            f: Print-like callable that accepts ``end`` and ``flush`` keyword
                arguments. Defaults to the built-in ``print``.
        """
        emit = f if f else print

        green_on = "\033[32m"
        color_off_and_newline = "\033[0m\n"

        # Switch to green without starting a new line.
        emit(green_on, end="")
        # The content itself, flushed immediately and with no trailing newline.
        emit(self.content, end="", flush=True)
        # Restore the default color and terminate the line.
        emit(color_off_and_newline)