/
OS-Worldb968155
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
import functools
import inspect
import sys
from abc import ABC
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, Union, get_type_hints
from ..agentchat import Agent
from ..doc_utils import export_module
from ..fast_depends import Depends as FastDepends
from ..fast_depends import inject
from ..fast_depends.dependencies import model
if TYPE_CHECKING:
from ..agentchat.conversable_agent import ConversableAgent
# Public API of this module: the names exported by a star-import.
__all__ = [
    "BaseContext",
    "ChatContext",
    "Depends",
    "Field",
    "get_context_params",
    "inject_params",
    "on",
    "remove_params",
]
@export_module("autogen.tools")
class BaseContext(ABC):
    """Common parent for every context type.

    It declares no members of its own. Concrete contexts (such as ``ChatContext``)
    derive from it, and the dependency-injection helpers in this module use
    ``isinstance``/``issubclass`` checks against it to recognise context values
    and context-typed parameters.
    """
@export_module("autogen.tools")
class ChatContext(BaseContext):
    """Context that exposes the conversation state of a ``ConversableAgent``.

    Extends ``BaseContext`` with two read-only properties that delegate to the
    wrapped agent: ``chat_messages`` (the agent's full chat history) and
    ``last_message`` (the agent's most recent message, or ``None``).
    """

    def __init__(self, agent: "ConversableAgent") -> None:
        """Initialize the ChatContext with an agent.

        Args:
            agent: The agent whose chat messages this context exposes.
        """
        self._agent = agent

    @property
    def chat_messages(self) -> dict[Agent, list[dict[Any, Any]]]:
        """The messages in the chat, read from the wrapped agent's ``chat_messages``.

        Returns:
            A dictionary of agents and their messages.
        """
        return self._agent.chat_messages

    @property
    def last_message(self) -> Optional[dict[str, Any]]:
        """The last message in the chat, as returned by the agent's ``last_message()``.

        Returns:
            The last message in the chat, or ``None`` if the agent returns ``None``.
        """
        return self._agent.last_message()
T = TypeVar("T")


def on(x: T) -> Callable[[], T]:
    """Wrap a value in a zero-argument provider function.

    The returned callable has a single optional parameter, ``ag2_x``, whose
    default is ``x``. Calling it without arguments therefore yields ``x`` itself.

    Args:
        x: The value the provider should return.

    Returns:
        A callable that returns ``x`` (or whatever is passed explicitly as ``ag2_x``).
    """
    # ``x`` is captured as a default argument value, not through a closure lookup.
    def inner(ag2_x: T = x) -> T:
        return ag2_x

    return inner
@export_module("autogen.tools")
def Depends(x: Any) -> Any:  # noqa: N802
    """Build a FastDepends dependency marker for ``x``.

    Args:
        x: Either a ``BaseContext`` instance to inject as-is, or any other
            dependency (typically a callable) understood by FastDepends.

    Returns:
        A FastDepends object that will resolve the dependency for injection.
    """
    # A context *instance* is wrapped in a zero-argument provider so that the
    # very same object is injected; anything else goes to FastDepends unchanged.
    provider = (lambda: x) if isinstance(x, BaseContext) else x
    return FastDepends(provider)
def get_context_params(func: Callable[..., Any], subclass: Union[type[BaseContext], type[ChatContext]]) -> list[str]:
    """List the parameters of ``func`` that are annotated with a context type.

    Args:
        func: The function whose signature is inspected.
        subclass: The context class to look for; subclasses of it also match.

    Returns:
        The matching parameter names, in signature order.
    """
    parameters = inspect.signature(func).parameters
    return [name for name, param in parameters.items() if _is_context_param(param, subclass=subclass)]
def _is_context_param(
    param: inspect.Parameter, subclass: Union[type[BaseContext], type[ChatContext]] = BaseContext
) -> bool:
    """Check whether a parameter is annotated with a context type.

    Args:
        param: The parameter to inspect.
        subclass: The context class to match; subclasses of it also match.

    Returns:
        True if the parameter's annotation is a class deriving from ``subclass``.
        For annotations carrying type arguments, the first type argument is
        checked instead. False otherwise, including for annotations that
        ``issubclass`` rejects.
    """
    annotation = param.annotation
    # Unwrap the first type argument to handle Annotated[MyContext, Depends(MyContext(b=2))].
    # Fall back to the annotation itself when ``__args__`` is missing, None or empty
    # (e.g. ``tuple[()]``); indexing it directly would raise instead of returning False.
    type_args = getattr(annotation, "__args__", None)
    candidate = type_args[0] if type_args else annotation
    try:
        return isinstance(candidate, type) and issubclass(candidate, subclass)
    except TypeError:
        # issubclass() can raise TypeError for some annotation objects.
        return False
def _is_depends_param(param: inspect.Parameter) -> bool:
    """Check whether a parameter is supplied through a FastDepends ``Depends`` marker.

    Two forms are recognised:

    * as the default value: ``x: int = Depends(...)``
    * inside ``Annotated`` metadata: ``x: Annotated[int, Depends(...)]``

    Args:
        param: The parameter to inspect.

    Returns:
        True if a ``Depends`` marker is found in either place, otherwise False.
    """
    if isinstance(param.default, model.Depends):
        return True
    metadata = getattr(param.annotation, "__metadata__", None)
    # Check every metadata entry rather than only the first, so that e.g.
    # Annotated[int, "description", Depends(...)] is detected too; an empty
    # tuple simply yields False instead of raising IndexError.
    return isinstance(metadata, tuple) and any(isinstance(item, model.Depends) for item in metadata)
def remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iterable[str]) -> None:
    """Set ``func.__signature__`` to ``sig`` without the named parameters.

    Args:
        func: The callable whose ``__signature__`` attribute is overwritten.
        sig: The signature to start from (typically ``inspect.signature(func)``).
        params: Names of the parameters to drop. Names absent from ``sig`` are ignored.
    """
    # Materialise the names once: a one-shot iterator (e.g. a generator) would be
    # consumed by repeated ``in`` checks, leaving later parameters un-removed.
    # A set also gives O(1) membership tests.
    to_remove = set(params)
    new_signature = sig.replace(parameters=[p for p in sig.parameters.values() if p.name not in to_remove])
    func.__signature__ = new_signature  # type: ignore[attr-defined]
def _remove_injected_params_from_signature(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hide context-typed and ``Depends``-injected parameters from ``func``'s signature.

    Args:
        func: The function to process; a ``staticmethod`` object is unwrapped first.

    Returns:
        The (possibly unwrapped) function, with its ``__signature__`` updated.
    """
    # _fix_staticmethod applies exactly the same staticmethod/version check and
    # returns its argument untouched otherwise, so it can be called unconditionally.
    func = _fix_staticmethod(func)
    signature = inspect.signature(func)
    injected = [
        name
        for name, param in signature.parameters.items()
        if _is_context_param(param) or _is_depends_param(param)
    ]
    remove_params(func, signature, injected)
    return func
class Field:
    """Holder for a human-readable description attached to an annotated value.

    Instances are placed in ``Annotated[...]`` metadata to document a field,
    for example a function parameter or an entry of a context/data model.
    """

    def __init__(self, description: str) -> None:
        """Store the description text.

        Args:
            description: The description text for the field.
        """
        self._description = description

    @property
    def description(self) -> str:
        """The description text given at construction (read-only)."""
        return self._description
def _string_metadata_to_description_field(func: Callable[..., Any]) -> Callable[..., Any]:
    """Replace plain-string ``Annotated`` metadata in ``func``'s type hints with ``Field``.

    For every hint returned by ``get_type_hints(func, include_extras=True)``
    (parameters and return value), if its first ``Annotated`` metadata item is a
    ``str``, that item becomes ``Field(description=<the string>)``. Any further
    metadata items are kept. The annotation objects are mutated in place.

    Args:
        func: The function whose annotations should be normalized.

    Returns:
        The same ``func`` object, unchanged apart from its annotation objects.
    """

    def _convert(annotated: Any) -> None:
        # Swap a leading description string for a Field while preserving the rest
        # of the metadata; replacing the whole tuple would silently drop any extra
        # items (e.g. a Depends(...) marker placed after the description).
        metadata = annotated.__metadata__
        if metadata and isinstance(metadata[0], str):
            annotated.__metadata__ = (Field(description=metadata[0]), *metadata[1:])

    type_hints = get_type_hints(func, include_extras=True)
    for annotation in type_hints.values():
        if hasattr(annotation, "__metadata__"):
            _convert(annotation)
            continue
        # Optional[Annotated[X, "desc"]] is a Union whose first argument carries the
        # metadata. Before Python 3.11, get_type_hints also adds this Optional wrapper
        # implicitly for parameters that default to None.
        # Skip annotations whose ``__args__`` is missing or empty (e.g. ``tuple[()]``),
        # where indexing ``[0]`` would raise IndexError.
        type_args = getattr(annotation, "__args__", None)
        if type_args and hasattr(type_args[0], "__metadata__"):
            _convert(type_args[0])
    return func
def _fix_staticmethod(f: Callable[..., Any]) -> Callable[..., Any]:
# This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
if sys.version_info >= (3, 9) and isinstance(f, staticmethod) and hasattr(f, "__func__"):
@wraps(f.__func__)
def wrapper(*args: Any, **kwargs: Any) -> Any:
return f.__func__(*args, **kwargs) # type: ignore[attr-defined]
wrapper.__name__ = f.__func__.__name__
f = wrapper
return f
def _set_return_annotation_to_any(f: Callable[..., Any]) -> Callable[..., Any]:
if inspect.iscoroutinefunction(f):
@functools.wraps(f)
async def _a_wrapped_func(*args: Any, **kwargs: Any) -> Any:
return await f(*args, **kwargs)
wrapped_func = _a_wrapped_func
else:
@functools.wraps(f)
def _wrapped_func(*args: Any, **kwargs: Any) -> Any:
return f(*args, **kwargs)
wrapped_func = _wrapped_func
sig = inspect.signature(f)
# Change the return annotation directly on the signature of the wrapper
wrapped_func.__signature__ = sig.replace(return_annotation=Any) # type: ignore[attr-defined]
return wrapped_func
def inject_params(f: Callable[..., Any]) -> Callable[..., Any]:
    """Apply dependency injection to ``f`` and hide the injected parameters.

    ``f`` is passed through these steps, in this order:

    1. ``_fix_staticmethod``: unwrap a ``staticmethod`` object (no-op otherwise).
    2. ``_string_metadata_to_description_field``: turn string ``Annotated`` metadata
       into ``Field`` objects.
    3. ``_set_return_annotation_to_any``: wrap so the return annotation is ``Any``.
    4. ``inject``: apply FastDepends dependency injection.
    5. ``_remove_injected_params_from_signature``: drop context/``Depends``
       parameters from the visible signature.

    Args:
        f: The function to modify with dependency injection.

    Returns:
        The modified function with injected dependencies and updated signature.
    """
    steps: Iterable[Callable[..., Any]] = (
        _fix_staticmethod,
        _string_metadata_to_description_field,
        _set_return_annotation_to_any,
        inject,
        _remove_injected_params_from_signature,
    )
    for step in steps:
        f = step(f)
    return f