# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0

import functools
import inspect
import sys
from abc import ABC
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, Union, get_type_hints

from ..agentchat import Agent
from ..doc_utils import export_module
from ..fast_depends import Depends as FastDepends
from ..fast_depends import inject
from ..fast_depends.dependencies import model

if TYPE_CHECKING:
    from ..agentchat.conversable_agent import ConversableAgent

__all__ = [
    "BaseContext",
    "ChatContext",
    "Depends",
    "Field",
    "get_context_params",
    "inject_params",
    "on",
    "remove_params",
]


@export_module("autogen.tools")
class BaseContext(ABC):
    """Base class for context classes.

    This is the base class for defining various context types that may be used
    throughout the application. It serves as a parent for specific context classes.
    """

    pass


@export_module("autogen.tools")
class ChatContext(BaseContext):
    """ChatContext class that extends BaseContext.

    This class gives access to the chat state of an agent. It inherits from
    `BaseContext` and exposes the `chat_messages` and `last_message` properties,
    both of which delegate to the wrapped agent.
    """

    def __init__(self, agent: "ConversableAgent") -> None:
        """Initializes the ChatContext with an agent.

        Args:
            agent: The agent to use for retrieving chat messages.
        """
        self._agent = agent

    @property
    def chat_messages(self) -> dict[Agent, list[dict[Any, Any]]]:
        """The messages in the chat.

        Returns:
            A dictionary of agents and their messages, read from `agent.chat_messages`.
        """
        return self._agent.chat_messages

    @property
    def last_message(self) -> Optional[dict[str, Any]]:
        """The last message in the chat.

        Returns:
            The last message in the chat, as returned by `agent.last_message()`.
        """
        return self._agent.last_message()


T = TypeVar("T")


def on(x: T) -> Callable[[], T]:
    """Wraps a fixed value in a provider callable.

    The returned callable has a single optional parameter `ag2_x` whose default is `x`,
    so calling it without arguments returns `x`. This lets a plain value be passed to
    `Depends`, which forwards anything that is not a `BaseContext` to FastDepends as a
    provider.

    Args:
        x: The value the provider should return.

    Returns:
        A callable that returns `x` when called without arguments.
    """

    def inner(ag2_x: T = x) -> T:
        return ag2_x

    return inner


@export_module("autogen.tools")
def Depends(x: Any) -> Any:  # noqa: N802
    """Creates a dependency for injection based on the provided context or type.

    Args:
        x: The context or dependency to be injected. A `BaseContext` instance is wrapped
            in a provider that always returns that same instance; anything else (e.g. a
            provider callable) is passed to FastDepends unchanged.

    Returns:
        A FastDepends object that will resolve the dependency for injection.
    """
    if isinstance(x, BaseContext):
        # Capture the instance in a closure so every resolution yields this exact object.
        return FastDepends(lambda: x)

    return FastDepends(x)


def get_context_params(func: Callable[..., Any], subclass: Union[type[BaseContext], type[ChatContext]]) -> list[str]:
    """Gets the names of the context parameters in a function signature.

    Args:
        func: The function to inspect for context parameters.
        subclass: The context class to search for.

    Returns:
        The names of the parameters whose annotation is `subclass` (or a subclass of it).
        For generic annotations such as `Annotated[...]` or `Optional[...]`, the first
        type argument is the one checked.
    """
    sig = inspect.signature(func)
    return [p.name for p in sig.parameters.values() if _is_context_param(p, subclass=subclass)]


def _is_context_param(
    param: inspect.Parameter, subclass: Union[type[BaseContext], type[ChatContext]] = BaseContext
) -> bool:
    """Checks whether a parameter is annotated with a context class.

    Args:
        param: The parameter to check.
        subclass: The context class to look for. Defaults to `BaseContext`.

    Returns:
        True if the annotation (or its first type argument) is a subclass of `subclass`.
    """
    param_annotation = param.annotation
    # The first type argument handles Annotated[MyContext, Depends(MyContext(b=2))].
    # An empty `__args__` (e.g. `tuple[()]`) has no first argument, so keep the annotation as-is
    # instead of raising IndexError.
    type_args = getattr(param_annotation, "__args__", None)
    if type_args:
        param_annotation = type_args[0]
    try:
        return isinstance(param_annotation, type) and issubclass(param_annotation, subclass)
    except TypeError:
        # Generic aliases can pass the isinstance check yet be rejected by issubclass.
        return False


def _is_depends_param(param: inspect.Parameter) -> bool:
    """Checks whether a parameter is resolved through a `Depends` marker.

    A parameter qualifies when its default value is a `Depends` marker, or when its
    annotation is `Annotated[...]` and any metadata entry is a `Depends` marker.

    Args:
        param: The parameter to check.

    Returns:
        True if the parameter is injected via `Depends`, False otherwise.
    """
    if isinstance(param.default, model.Depends):
        return True

    metadata = getattr(param.annotation, "__metadata__", None)
    # NOTE(review): fast_depends scans every Annotated metadata entry for a Depends marker,
    # not only the first one, so every entry is checked to keep the exposed signature in sync.
    return isinstance(metadata, tuple) and any(isinstance(item, model.Depends) for item in metadata)


def remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iterable[str]) -> None:
    """Overwrites `func.__signature__` with `sig` minus the named parameters.

    Args:
        func: The function whose `__signature__` attribute is replaced.
        sig: The signature to filter.
        params: Names of the parameters to drop. Any iterable works, including one-shot
            iterators such as generators.
    """
    # Materialize once: `in` on a generator consumes it, so repeated membership tests against
    # the raw iterable would silently miss names. A set also makes each lookup O(1).
    names_to_remove = set(params)
    new_signature = sig.replace(parameters=[p for p in sig.parameters.values() if p.name not in names_to_remove])
    func.__signature__ = new_signature  # type: ignore[attr-defined]


def _remove_injected_params_from_signature(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hides context parameters and `Depends`-injected parameters from `func`'s signature.

    Args:
        func: The function to modify. A `staticmethod` is first converted into a plain function.

    Returns:
        The function, with a `__signature__` that no longer lists the injected parameters.
    """
    # This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
    if sys.version_info >= (3, 9) and isinstance(func, staticmethod) and hasattr(func, "__func__"):
        func = _fix_staticmethod(func)

    sig = inspect.signature(func)
    params_to_remove = [p.name for p in sig.parameters.values() if _is_context_param(p) or _is_depends_param(p)]
    remove_params(func, sig, params_to_remove)
    return func


class Field:
    """Represents a description field for use in type annotations.

    This class is used to store a description for an annotated field, often used for
    documenting or validating fields in a context or data model.
    """

    def __init__(self, description: str) -> None:
        """Initializes the Field with a description.

        Args:
            description: The description text for the field.
        """
        self._description = description

    @property
    def description(self) -> str:
        """The description text for the field."""
        return self._description


def _string_metadata_to_description_field(func: Callable[..., Any]) -> Callable[..., Any]:
    """Converts a leading string in `Annotated` metadata on `func`'s hints into a `Field`.

    For example, `Annotated[int, "some text"]` becomes
    `Annotated[int, Field(description="some text")]`. Only a string in the first metadata
    slot is converted; any further metadata entries (such as a `Depends` marker) are kept.
    The annotation objects are mutated in place.

    Args:
        func: The function whose type hints are rewritten.

    Returns:
        The same function object.
    """
    type_hints = get_type_hints(func, include_extras=True)
    for annotation in type_hints.values():
        # Check if the annotation itself has metadata (using __metadata__)
        if hasattr(annotation, "__metadata__"):
            target = annotation
        # For Python < 3.11, annotations like `Optional` are stored as `Union`, so metadata
        # would be in the first element of __args__ (e.g., `__args__[0]` for `int` in `Optional[int]`).
        # An empty `__args__` (e.g. `tuple[()]`) has nothing to inspect.
        elif getattr(annotation, "__args__", None) and hasattr(annotation.__args__[0], "__metadata__"):
            target = annotation.__args__[0]
        else:
            continue

        metadata = target.__metadata__
        if metadata and isinstance(metadata[0], str):
            # Replace the string with a Field. Keep the remaining metadata (e.g. a Depends marker)
            # instead of silently discarding it.
            target.__metadata__ = (Field(description=metadata[0]), *metadata[1:])
    return func


def _fix_staticmethod(f: Callable[..., Any]) -> Callable[..., Any]:
    """Converts a `staticmethod` object into a plain function that forwards to it.

    Args:
        f: The callable to normalize.

    Returns:
        A wrapper around `f.__func__` if `f` is a `staticmethod`, otherwise `f` unchanged.
        The wrapper is an `async def` when the underlying function is a coroutine function.
    """
    # This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
    if sys.version_info >= (3, 9) and isinstance(f, staticmethod) and hasattr(f, "__func__"):
        # Bind the underlying function to its own name. The wrapper must not close over `f`:
        # `f` used to be rebound to the wrapper itself, which then failed at call time with an
        # AttributeError on `__func__`.
        target = f.__func__

        if inspect.iscoroutinefunction(target):

            @wraps(target)
            async def _a_wrapper(*args: Any, **kwargs: Any) -> Any:
                return await target(*args, **kwargs)

            return _a_wrapper

        @wraps(target)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return target(*args, **kwargs)

        return wrapper
    return f


def _set_return_annotation_to_any(f: Callable[..., Any]) -> Callable[..., Any]:
    """Wraps `f` so that its reported signature has a return annotation of `Any`.

    Coroutine functions get an async wrapper, so the result is still a coroutine function.
    NOTE(review): presumably done so the injection layer does not validate/cast the return
    value; confirm.

    Args:
        f: The function to wrap.

    Returns:
        A wrapper forwarding to `f` whose `__signature__` is `f`'s signature with return `Any`.
    """
    if inspect.iscoroutinefunction(f):

        @functools.wraps(f)
        async def _a_wrapped_func(*args: Any, **kwargs: Any) -> Any:
            return await f(*args, **kwargs)

        wrapped_func = _a_wrapped_func

    else:

        @functools.wraps(f)
        def _wrapped_func(*args: Any, **kwargs: Any) -> Any:
            return f(*args, **kwargs)

        wrapped_func = _wrapped_func

    sig = inspect.signature(f)

    # Change the return annotation directly on the signature of the wrapper
    wrapped_func.__signature__ = sig.replace(return_annotation=Any)  # type: ignore[attr-defined]

    return wrapped_func


def inject_params(f: Callable[..., Any]) -> Callable[..., Any]:
    """Injects parameters into a function, removing injected dependencies from its signature.

    The steps are applied in this order:

    1. Convert a `staticmethod` into a plain function.
    2. Turn string `Annotated` metadata into `Field` objects.
    3. Set the return annotation to `Any`.
    4. Wrap the result with fast_depends' `inject`.
    5. Hide context and `Depends` parameters from the resulting signature.

    Args:
        f: The function to modify with dependency injection.

    Returns:
        The modified function with injected dependencies and updated signature.
    """
    # This is a workaround for Python 3.9+ where staticmethod.__func__ is accessible
    if sys.version_info >= (3, 9) and isinstance(f, staticmethod) and hasattr(f, "__func__"):
        f = _fix_staticmethod(f)

    f = _string_metadata_to_description_field(f)
    f = _set_return_annotation_to_any(f)
    f = inject(f)
    f = _remove_injected_params_from_signature(f)

    return f