/
OS-Worldb968155
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import functools
import json
import re
from abc import ABC, abstractmethod
from collections.abc import Iterable
from contextvars import ContextVar
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Mapping, Optional, Type, TypeVar, Union
from httpx import Client as httpxClient
from pydantic import BaseModel, ConfigDict, Field, HttpUrl, SecretStr, ValidationInfo, field_serializer, field_validator
if TYPE_CHECKING:
from .oai.client import ModelClient
# Type variables (presumably key/value, judging by their names); not referenced in this chunk.
_KT, _VT = TypeVar("_KT"), TypeVar("_VT")

# Public names exported by ``from <module> import *``.
__all__ = ["LLMConfig", "LLMConfigEntry", "register_llm_config"]
def _add_default_api_type(d: dict[str, Any]) -> dict[str, Any]:
if "api_type" not in d:
d["api_type"] = "openai"
return d
# Metaclass that lets LLMConfig.current and LLMConfig.default be read as class-level properties
class MetaLLMConfig(type):
    """Metaclass exposing ``current`` and ``default`` as read-only class properties."""

    def __init__(cls, *args: Any, **kwargs: Any) -> None:
        # Deliberately a no-op: nothing extra is set up on classes using this metaclass.
        pass

    @property
    def current(cls) -> "LLMConfig":
        """The config made current by the enclosing ``with`` block.

        Raises:
            ValueError: if no ``LLMConfig`` is currently active.
        """
        active = LLMConfig.get_current_llm_config(llm_config=None)
        if active is not None:
            return active  # type: ignore[return-value]
        raise ValueError("No current LLMConfig set. Are you inside a context block?")

    @property
    def default(cls) -> "LLMConfig":
        """Alias of ``current``."""
        return cls.current
class LLMConfig(metaclass=MetaLLMConfig):
    """Dict-like, validated container for LLM settings.

    Top-level options (``temperature``, ``cache_seed``, ``tools``, ...) are fields of a
    dynamically built pydantic model. Provider settings live in ``config_list``, whose
    entries are validated against the classes registered with ``register_llm_config``
    (discriminated on ``api_type``). Attribute and item reads/writes are forwarded to
    that model. Using an instance in a ``with`` block makes it the current config for
    the duration of the block.
    """

    # Config made current by the innermost active ``with`` block in this context.
    _current_llm_config: ContextVar["LLMConfig"] = ContextVar("current_llm_config")

    def __init__(self, **kwargs: Any) -> None:
        """Build and validate the configuration.

        If ``config_list`` is omitted, every keyword that is not a top-level field is
        collected into a single ``config_list`` entry. A ``config_list`` given as one
        dict is wrapped in a list. Dict entries without ``api_type`` get ``"openai"``
        (the entry dict is updated in place). A top-level ``max_tokens``/``top_p`` is
        copied into every entry and removed from the top level.
        """
        outside_properties = list((self._get_base_model_class()).model_json_schema()["properties"].keys())
        outside_properties.remove("config_list")

        if "config_list" in kwargs and isinstance(kwargs["config_list"], dict):
            kwargs["config_list"] = [kwargs["config_list"]]

        modified_kwargs = (
            kwargs
            if "config_list" in kwargs
            else {
                **{
                    "config_list": [
                        {k: v for k, v in kwargs.items() if k not in outside_properties},
                    ]
                },
                **{k: v for k, v in kwargs.items() if k in outside_properties},
            }
        )

        modified_kwargs["config_list"] = [
            _add_default_api_type(v) if isinstance(v, dict) else v for v in modified_kwargs["config_list"]
        ]
        # These are per-entry settings: push a top-level value down into each entry.
        for x in ["max_tokens", "top_p"]:
            if x in modified_kwargs:
                modified_kwargs["config_list"] = [{**v, x: modified_kwargs[x]} for v in modified_kwargs["config_list"]]
                modified_kwargs.pop(x)

        self._model = self._get_base_model_class()(**modified_kwargs)

    def __enter__(self) -> "LLMConfig":
        """Make this config the current one until the matching ``__exit__``."""
        token = LLMConfig._current_llm_config.set(self)
        # Store tokens on the wrapper itself via ``__dict__``: ``__setattr__`` would
        # forward them into the pydantic model. Keeping a stack rather than one slot
        # lets the same instance be entered again while already active without the
        # inner block clobbering the outer block's token.
        self.__dict__.setdefault("_context_tokens", []).append(token)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Any,
    ) -> None:
        """Restore whichever config was current before the matching ``__enter__``."""
        LLMConfig._current_llm_config.reset(self.__dict__["_context_tokens"].pop())

    @classmethod
    def get_current_llm_config(cls, llm_config: "Optional[LLMConfig]" = None) -> "Optional[LLMConfig]":
        """Return ``llm_config`` if given, else a copy of the current config, else ``None``."""
        if llm_config is not None:
            return llm_config
        try:
            return (LLMConfig._current_llm_config.get()).copy()
        except LookupError:
            # No ``with LLMConfig(...)`` block is active in this context.
            return None

    def _satisfies_criteria(self, value: Any, criteria_values: Any) -> bool:
        """Return whether ``value`` (or, for a list, any of its items) is in ``criteria_values``.

        ``None`` never satisfies the criteria.
        """
        if value is None:
            return False

        if isinstance(value, list):
            return bool(set(value) & set(criteria_values))  # Non-empty intersection
        else:
            return value in criteria_values

    @classmethod
    def from_json(
        cls,
        *,
        env: Optional[str] = None,
        path: Optional[Union[str, Path]] = None,
        file_location: Optional[str] = None,
        **kwargs: Any,
    ) -> "LLMConfig":
        """Build a config whose ``config_list`` is loaded by ``config_list_from_json``.

        Exactly one of ``env`` or ``path`` must be given; it is forwarded as
        ``env_or_file`` together with ``file_location``. Remaining ``kwargs`` are passed
        to the ``LLMConfig`` constructor alongside the loaded ``config_list``.

        Raises:
            ValueError: if neither or both of ``env`` and ``path`` are provided.
        """
        from .oai.openai_utils import config_list_from_json

        if env is None and path is None:
            raise ValueError("Either 'env' or 'path' must be provided")
        if env is not None and path is not None:
            raise ValueError("Only one of 'env' or 'path' can be provided")

        config_list = config_list_from_json(
            env_or_file=env if env is not None else str(path), file_location=file_location
        )
        return LLMConfig(config_list=config_list, **kwargs)

    def where(self, *, exclude: bool = False, **kwargs: Any) -> "LLMConfig":
        """Return a new config keeping only the ``config_list`` entries chosen by ``filter_config``.

        ``kwargs`` is used as the filter dict and ``exclude`` is forwarded unchanged;
        all top-level options are carried over to the new config.

        Raises:
            ValueError: if no entry survives the filter.
        """
        from .oai.openai_utils import filter_config

        filtered_config_list = filter_config(config_list=self.config_list, filter_dict=kwargs, exclude=exclude)
        if len(filtered_config_list) == 0:
            raise ValueError(f"No config found that satisfies the filter criteria: {kwargs}")

        data = self.model_dump()
        data["config_list"] = filtered_config_list
        return LLMConfig(**data)

    def model_dump(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> dict[str, Any]:
        """Dump the wrapped model, omitting ``None`` values (by default) and empty lists."""
        d = self._model.model_dump(*args, exclude_none=exclude_none, **kwargs)
        return {k: v for k, v in d.items() if not (isinstance(v, list) and len(v) == 0)}

    def model_dump_json(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> str:
        """JSON-encode ``model_dump()`` so the same empty-list filtering applies."""
        d = self.model_dump(*args, exclude_none=exclude_none, **kwargs)
        return json.dumps(d)

    def model_validate(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate to the wrapped model's ``model_validate``."""
        return self._model.model_validate(*args, **kwargs)

    @functools.wraps(BaseModel.model_validate_json)
    def model_validate_json(self, *args: Any, **kwargs: Any) -> Any:
        return self._model.model_validate_json(*args, **kwargs)

    @functools.wraps(BaseModel.model_validate_strings)
    def model_validate_strings(self, *args: Any, **kwargs: Any) -> Any:
        return self._model.model_validate_strings(*args, **kwargs)

    def __eq__(self, value: Any) -> bool:
        """Compare by the wrapped pydantic models."""
        return hasattr(value, "_model") and self._model == value._model

    def _getattr(self, o: object, name: str) -> Any:
        val = getattr(o, name)
        return val

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        """Dict-style ``get`` backed by the wrapped model's attributes."""
        val = getattr(self._model, key, default)
        return val

    def __getitem__(self, key: str) -> Any:
        """Return the wrapped model's attribute ``key``; ``KeyError`` if it has none."""
        try:
            return self._getattr(self._model, key)
        except AttributeError:
            raise KeyError(f"Key '{key}' not found in {self.__class__.__name__}")

    def __setitem__(self, key: str, value: Any) -> None:
        """Set ``key`` on the wrapped model; ``ValueError`` if it is not a field."""
        try:
            setattr(self._model, key, value)
        except ValueError:
            raise ValueError(f"'{self.__class__.__name__}' object has no field '{key}'")

    def __getattr__(self, name: Any) -> Any:
        """Forward attribute reads that normal lookup could not satisfy to the wrapped model."""
        if name == "_model":
            # Only reached when ``_model`` was never set (e.g. an instance created
            # without ``__init__``); reading ``self._model`` below would otherwise
            # re-enter this method endlessly and raise RecursionError.
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
        try:
            return self._getattr(self._model, name)
        except AttributeError:
            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """Store ``_model`` on the wrapper; forward every other assignment to the model."""
        if name == "_model":
            object.__setattr__(self, name, value)
        else:
            setattr(self._model, name, value)

    def __contains__(self, key: str) -> bool:
        return hasattr(self._model, key)

    def __repr__(self) -> str:
        d = self.model_dump()
        r = [f"{k}={repr(v)}" for k, v in d.items()]

        s = f"LLMConfig({', '.join(r)})"
        # Mask quoted values whose quoted key ends with 'key' or 'token'.
        s = re.sub(
            r"(['\"])(\w*(key|token))\1:\s*(['\"])([^'\"]*)(?:\4)", r"\1\2\1: \4**********\4", s, flags=re.IGNORECASE
        )
        return s

    def __copy__(self) -> "LLMConfig":
        """Return a new config rebuilt from ``model_dump()``."""
        return LLMConfig(**self.model_dump())

    def __deepcopy__(self, memo: Optional[dict[int, Any]] = None) -> "LLMConfig":
        # Rebuilding from ``model_dump()`` already produces an independent object.
        return self.__copy__()

    def copy(self) -> "LLMConfig":
        return self.__copy__()

    def deepcopy(self, memo: Optional[dict[int, Any]] = None) -> "LLMConfig":
        return self.__deepcopy__(memo)

    def __str__(self) -> str:
        return repr(self)

    def items(self) -> Iterable[tuple[str, Any]]:
        d = self.model_dump()
        return d.items()

    def keys(self) -> Iterable[str]:
        d = self.model_dump()
        return d.keys()

    def values(self) -> Iterable[Any]:
        d = self.model_dump()
        return d.values()

    # Generated wrapper models, keyed by the tuple of registered entry classes.
    _base_model_classes: dict[tuple[Type["LLMConfigEntry"], ...], Type[BaseModel]] = {}

    @classmethod
    def _get_base_model_class(cls) -> Type["BaseModel"]:
        """Return the pydantic model matching the currently registered entry classes.

        The model is built on first use for a given tuple of registered classes and
        cached in ``_base_model_classes``. Registering another class changes the key,
        so the next call builds a new model.
        """

        def _get_cls(llm_config_classes: tuple[Type[LLMConfigEntry], ...]) -> Type[BaseModel]:
            if llm_config_classes in LLMConfig._base_model_classes:
                return LLMConfig._base_model_classes[llm_config_classes]

            class _LLMConfig(BaseModel):
                temperature: Optional[float] = None
                check_every_ms: Optional[int] = None
                max_new_tokens: Optional[int] = None
                seed: Optional[int] = None
                allow_format_str_template: Optional[bool] = None
                response_format: Optional[Union[str, dict[str, Any], BaseModel, Type[BaseModel]]] = None
                timeout: Optional[int] = None
                cache_seed: Optional[int] = None

                tools: list[Any] = Field(default_factory=list)
                functions: list[Any] = Field(default_factory=list)
                parallel_tool_calls: Optional[bool] = None

                # Each entry is validated as the registered class whose ``api_type`` matches.
                config_list: Annotated[  # type: ignore[valid-type]
                    list[Annotated[Union[llm_config_classes], Field(discriminator="api_type")]],
                    Field(default_factory=list, min_length=1),
                ]

                # Disallow unknown top-level fields.
                model_config = ConfigDict(extra="forbid")

            LLMConfig._base_model_classes[llm_config_classes] = _LLMConfig

            return _LLMConfig

        return _get_cls(tuple(_llm_config_classes))
class LLMConfigEntry(BaseModel, ABC):
    """Base schema for a single ``config_list`` entry of an ``LLMConfig``.

    Concrete subclasses implement ``create_client`` and are registered with
    ``register_llm_config``. ``api_type`` is the discriminator used when validating
    ``LLMConfig.config_list``. Unknown fields are rejected. Entries also support
    dict-style access, and ``SecretStr`` values are unwrapped by ``get``/``[]``.
    """

    api_type: str
    model: str = Field(..., min_length=1)
    api_key: Optional[SecretStr] = None
    api_version: Optional[str] = None
    max_tokens: Optional[int] = None
    base_url: Optional[HttpUrl] = None
    voice: Optional[str] = None
    model_client_cls: Optional[str] = None
    http_client: Optional[httpxClient] = None
    response_format: Optional[Union[str, dict[str, Any], BaseModel, Type[BaseModel]]] = None
    default_headers: Optional[Mapping[str, Any]] = None
    tags: list[str] = Field(default_factory=list)

    # Disallow extra fields; arbitrary types are allowed so ``httpx.Client`` can be a field type.
    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)

    @abstractmethod
    def create_client(self) -> "ModelClient": ...

    @field_validator("base_url", mode="before")
    @classmethod
    def check_base_url(cls, v: Any, info: ValidationInfo) -> Any:
        """Prefix a scheme-less ``base_url`` with ``http://``; pass ``None`` through."""
        # Without this guard an explicit ``base_url=None`` became "http://None",
        # which is itself a valid URL and so silently replaced "no base URL".
        if v is None:
            return v
        if not str(v).startswith(("https://", "http://")):
            v = f"http://{str(v)}"
        return v

    @field_serializer("base_url", when_used="unless-none")
    def serialize_base_url(self, v: Any) -> Any:
        """Serialize the URL as a plain string (``None`` is emitted unchanged, not as "None")."""
        return str(v)

    @field_serializer("api_key", when_used="unless-none")
    def serialize_api_key(self, v: SecretStr) -> Any:
        """Dump the real key value rather than the masked ``SecretStr`` form."""
        return v.get_secret_value()

    def model_dump(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> dict[str, Any]:
        """``BaseModel.model_dump`` with ``exclude_none`` defaulting to ``True``."""
        return BaseModel.model_dump(self, *args, exclude_none=exclude_none, **kwargs)

    def model_dump_json(self, *args: Any, exclude_none: bool = True, **kwargs: Any) -> str:
        """``BaseModel.model_dump_json`` with ``exclude_none`` defaulting to ``True``."""
        return BaseModel.model_dump_json(self, *args, exclude_none=exclude_none, **kwargs)

    def get(self, key: str, default: Optional[Any] = None) -> Any:
        """Dict-style ``get``; ``SecretStr`` values are returned unwrapped."""
        val = getattr(self, key, default)
        if isinstance(val, SecretStr):
            return val.get_secret_value()
        return val

    def __getitem__(self, key: str) -> Any:
        """Return attribute ``key`` (unwrapping ``SecretStr``); ``KeyError`` if absent."""
        try:
            val = getattr(self, key)
        except AttributeError:
            raise KeyError(f"Key '{key}' not found in {self.__class__.__name__}")
        if isinstance(val, SecretStr):
            return val.get_secret_value()
        return val

    def __setitem__(self, key: str, value: Any) -> None:
        setattr(self, key, value)

    def __contains__(self, key: str) -> bool:
        return hasattr(self, key)

    def items(self) -> Iterable[tuple[str, Any]]:
        d = self.model_dump()
        return d.items()

    def keys(self) -> Iterable[str]:
        d = self.model_dump()
        return d.keys()

    def values(self) -> Iterable[Any]:
        d = self.model_dump()
        return d.values()

    def __repr__(self) -> str:
        # Build from model_dump() so None values are left out of the repr.
        d = self.model_dump()
        r = [f"{k}={repr(v)}" for k, v in d.items()]

        s = f"{self.__class__.__name__}({', '.join(r)})"

        # Mask the quoted value of any `<name>_key=` / `<name>_token=` item, reusing
        # whichever quote character repr() chose.
        s = re.sub(r'(\w+_(key|token)\s*=\s*)([\'"]).*?\3', r"\1\3**********\3", s, flags=re.IGNORECASE)

        return s

    def __str__(self) -> str:
        return repr(self)
# Entry classes registered via ``register_llm_config``; ``LLMConfig`` builds the
# allowed ``config_list`` entry types from this list.
_llm_config_classes: list[Type[LLMConfigEntry]] = []


def register_llm_config(cls: Type[LLMConfigEntry]) -> Type[LLMConfigEntry]:
    """Class decorator that registers an ``LLMConfigEntry`` subclass.

    The class is returned unchanged, so this can be applied as ``@register_llm_config``.

    Raises:
        TypeError: if ``cls`` is not a subclass of ``LLMConfigEntry``.
    """
    is_entry_class = isinstance(cls, type) and issubclass(cls, LLMConfigEntry)
    if not is_entry_class:
        raise TypeError(f"Expected a subclass of LLMConfigEntry, got {cls}")
    _llm_config_classes.append(cls)
    return cls