# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
import functools
import inspect
import json
from logging import getLogger
from typing import Annotated, Any, Callable, ForwardRef, Optional, TypeVar, Union

from packaging.version import parse
from pydantic import BaseModel, Field, TypeAdapter
from pydantic import __version__ as pydantic_version
from pydantic.json_schema import JsonSchemaValue
from typing_extensions import Literal, get_args, get_origin

from ..doc_utils import export_module
from .dependency_injection import Field as AG2Field

if parse(pydantic_version) < parse("2.10.2"):
    from pydantic._internal._typing_extra import eval_type_lenient as try_eval_type

    # The older `eval_type_lenient` returns the (possibly unevaluated) value directly.
    _TRY_EVAL_TYPE_RETURNS_TUPLE = False
else:
    from pydantic._internal._typing_extra import try_eval_type

    # `try_eval_type` returns a `(value, evaluated)` pair.
    _TRY_EVAL_TYPE_RETURNS_TUPLE = True

__all__ = ["get_function_schema", "load_basemodels_if_needed", "serialize_to_str"]

logger = getLogger(__name__)

T = TypeVar("T")


def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
    """Get the type annotation of a parameter.

    An `AG2Field` used directly as an annotation is replaced by its description, and a
    string annotation is turned into a `ForwardRef` and evaluated leniently against
    `globalns` using pydantic's internal type-evaluation helper.

    Args:
        annotation: The annotation of the parameter
        globalns: The global namespace of the function

    Returns:
        The type annotation of the parameter
    """
    if isinstance(annotation, AG2Field):
        annotation = annotation.description
    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)
        result = try_eval_type(annotation, globalns, globalns)
        # Normalize across pydantic versions (tuple vs. bare value, see the import above).
        annotation = result[0] if _TRY_EVAL_TYPE_RETURNS_TUPLE else result
    return annotation


def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Get the signature of a function with type annotations.

    Args:
        call: The function to get the signature for

    Returns:
        The signature of the function with its parameter annotations resolved. Only the
        parameters are kept; the return annotation is resolved separately by
        `get_typed_return_annotation`.
    """
    signature = inspect.signature(call)
    # Callables without `__globals__` (e.g. some callable objects) resolve against an empty namespace.
    globalns = getattr(call, "__globals__", {})
    typed_params = [
        inspect.Parameter(
            name=param.name,
            kind=param.kind,
            default=param.default,
            annotation=get_typed_annotation(param.annotation, globalns),
        )
        for param in signature.parameters.values()
    ]
    typed_signature = inspect.Signature(typed_params)
    return typed_signature


def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Get the return annotation of a function.

    Args:
        call: The function to get the return annotation for

    Returns:
        The resolved return annotation of the function, or None if it has none. Note that an
        explicit `-> None` annotation also yields None.
    """
    signature = inspect.signature(call)
    annotation = signature.return_annotation

    if annotation is inspect.Signature.empty:
        return None

    globalns = getattr(call, "__globals__", {})
    return get_typed_annotation(annotation, globalns)


def get_param_annotations(typed_signature: inspect.Signature) -> dict[str, Union[Annotated[type[Any], str], type[Any]]]:
    """Get the type annotations of the parameters of a function

    Args:
        typed_signature: The signature of the function with type annotations

    Returns:
        A dictionary of the type annotations of the parameters of the function; parameters
        without an annotation are omitted.
    """
    return {
        k: v.annotation for k, v in typed_signature.parameters.items() if v.annotation is not inspect.Signature.empty
    }


class Parameters(BaseModel):
    """Parameters of a function as defined by the OpenAI API"""

    type: Literal["object"] = "object"
    properties: dict[str, JsonSchemaValue]
    required: list[str]


class Function(BaseModel):
    """A function as defined by the OpenAI API"""

    description: Annotated[str, Field(description="Description of the function")]
    name: Annotated[str, Field(description="Name of the function")]
    parameters: Annotated[Parameters, Field(description="Parameters of the function")]


class ToolFunction(BaseModel):
    """A function under tool as defined by the OpenAI API."""

    type: Literal["function"] = "function"
    function: Annotated[Function, Field(description="Function under tool")]


def get_parameter_json_schema(k: str, v: Any, default_values: dict[str, Any]) -> JsonSchemaValue:
    """Get a JSON schema for a parameter as defined by the OpenAI API

    Args:
        k: The name of the parameter
        v: The type of the parameter
        default_values: The default values of the parameters of the function

    Returns:
        The JSON schema of the parameter (as produced by `TypeAdapter.json_schema`), with a
        `description` entry and, when the parameter has a default value, a `default` entry.

    Raises:
        ValueError: If `v` is `Annotated` and its first metadata item is neither an
            `AG2Field` nor a string.
    """

    def type2description(k: str, v: Union[Annotated[type[Any], str], type[Any]]) -> str:
        # Parameters without Annotated metadata use their name as the description.
        if not hasattr(v, "__metadata__"):
            return k

        # handles Annotated: the first metadata item carries the description
        retval = v.__metadata__[0]
        if isinstance(retval, AG2Field):
            return retval.description  # type: ignore[return-value]
        if isinstance(retval, str):
            # Plain-string metadata, e.g. ``Annotated[str, "Parameter a"]``.
            return retval
        raise ValueError(f"Invalid {retval} for parameter {k}, should be a DescriptionField, got {type(retval)}")

    schema = TypeAdapter(v).json_schema()
    if k in default_values:
        schema["default"] = default_values[k]

    schema["description"] = type2description(k, v)

    return schema


def get_required_params(typed_signature: inspect.Signature) -> list[str]:
    """Get the required parameters of a function

    Args:
        typed_signature: The signature of the function as returned by inspect.signature

    Returns:
        A list of the names of the parameters that have no default value
    """
    # Identity check: `==` could invoke a custom `__eq__` on arbitrary default values.
    return [k for k, v in typed_signature.parameters.items() if v.default is inspect.Signature.empty]


def get_default_values(typed_signature: inspect.Signature) -> dict[str, Any]:
    """Get default values of parameters of a function

    Args:
        typed_signature: The signature of the function as returned by inspect.signature

    Returns:
        A dictionary mapping each parameter that has a default value to that default
    """
    return {k: v.default for k, v in typed_signature.parameters.items() if v.default is not inspect.Signature.empty}


def get_parameters(
    required: list[str],
    param_annotations: dict[str, Union[Annotated[type[Any], str], type[Any]]],
    default_values: dict[str, Any],
) -> Parameters:
    """Get the parameters of a function as defined by the OpenAI API

    Args:
        required: The required parameters of the function
        param_annotations: The type annotations of the parameters of the function
        default_values: The default values of the parameters of the function

    Returns:
        A Pydantic model for the parameters of the function
    """
    return Parameters(
        properties={
            k: get_parameter_json_schema(k, v, default_values)
            for k, v in param_annotations.items()
            if v is not inspect.Signature.empty
        },
        required=required,
    )


def get_missing_annotations(typed_signature: inspect.Signature, required: list[str]) -> tuple[set[str], set[str]]:
    """Get the missing annotations of a function

    Parameters with default values are not required to be annotated; they are reported
    separately so that the caller can warn about them instead of failing.

    Args:
        typed_signature: The signature of the function with type annotations
        required: The required parameters of the function

    Returns:
        A tuple ``(missing, unannotated_with_default)``: the names of required parameters
        without annotations, and the names of unannotated parameters that have defaults.
    """
    all_missing = {k for k, v in typed_signature.parameters.items() if v.annotation is inspect.Signature.empty}
    missing = all_missing.intersection(set(required))
    unannotated_with_default = all_missing.difference(missing)
    return missing, unannotated_with_default


@export_module("autogen.tools")
def get_function_schema(f: Callable[..., Any], *, name: Optional[str] = None, description: str) -> dict[str, Any]:
    """Get a JSON schema for a function as defined by the OpenAI API

    Args:
        f: The function to get the JSON schema for
        name: The name of the function; defaults to ``f.__name__``
        description: The description of the function

    Returns:
        A JSON schema for the function

    Raises:
        TypeError: If a parameter without a default value is not annotated
        ValueError: If an ``Annotated`` parameter carries unsupported metadata

    Examples:
        ```python
        def f(a: Annotated[str, "Parameter a"], b: int = 2, c: Annotated[float, "Parameter c"] = 0.1) -> None:
            pass


        get_function_schema(f, description="function f")

        #   {'type': 'function',
        #    'function': {'description': 'function f',
        #        'name': 'f',
        #        'parameters': {'type': 'object',
        #           'properties': {'a': {'type': 'string', 'description': 'Parameter a'},
        #               'b': {'type': 'integer', 'default': 2, 'description': 'b'},
        #               'c': {'type': 'number', 'default': 0.1, 'description': 'Parameter c'}},
        #           'required': ['a']}}}
        ```

    """
    typed_signature = get_typed_signature(f)
    required = get_required_params(typed_signature)
    default_values = get_default_values(typed_signature)
    param_annotations = get_param_annotations(typed_signature)
    missing, unannotated_with_default = get_missing_annotations(typed_signature, required)

    # Inspect the raw signature: an explicit `-> None` resolves to None as well, but it *is*
    # annotated and must not trigger the "not annotated" warning.
    if inspect.signature(f).return_annotation is inspect.Signature.empty:
        logger.warning(
            f"The return type of the function '{f.__name__}' is not annotated. Although annotating it is "
            + "optional, the function should return either a string, a subclass of 'pydantic.BaseModel'."
        )

    if unannotated_with_default:
        unannotated_with_default_s = [f"'{k}'" for k in sorted(unannotated_with_default)]
        logger.warning(
            f"The following parameters of the function '{f.__name__}' with default values are not annotated: "
            + f"{', '.join(unannotated_with_default_s)}."
        )

    if missing:
        missing_s = [f"'{k}'" for k in sorted(missing)]
        raise TypeError(
            f"All parameters of the function '{f.__name__}' without default values must be annotated. "
            + f"The annotations are missing for the following parameters: {', '.join(missing_s)}"
        )

    fname = name if name else f.__name__

    parameters = get_parameters(required, param_annotations, default_values=default_values)

    function = ToolFunction(
        function=Function(
            description=description,
            name=fname,
            parameters=parameters,
        )
    )

    return function.model_dump()


def get_load_param_if_needed_function(t: Any) -> Optional[Callable[[dict[str, Any], type[BaseModel]], BaseModel]]:
    """Get a function to load a parameter if it is a Pydantic model

    Args:
        t: The type annotation of the parameter

    Returns:
        A function to load the parameter if it is a Pydantic model, otherwise None
    """
    origin = get_origin(t)

    if origin is Annotated:
        args = get_args(t)
        # Unwrap ``Annotated[T, ...]`` and inspect ``T``; no arguments means invalid usage.
        return get_load_param_if_needed_function(args[0]) if args else None

    # Generic aliases (list[str], dict[str, Any], Union[...], ...) and non-type objects
    # can never be BaseModel subclasses.
    if origin is not None or not isinstance(t, type):
        return None

    if not issubclass(t, BaseModel):
        return None

    def load_base_model(v: dict[str, Any], model_type: type[BaseModel]) -> BaseModel:
        return model_type(**v)

    return load_base_model


@export_module("autogen.tools")
def load_basemodels_if_needed(func: Callable[..., Any]) -> Callable[..., Any]:
    """A decorator to load the parameters of a function if they are Pydantic models

    Keyword arguments whose parameter is annotated with a Pydantic model are converted
    from dicts to model instances before the call. Arguments that are not passed by
    keyword (omitted, or passed positionally) are forwarded unchanged.

    Args:
        func: The function with annotated parameters

    Returns:
        A function that loads the parameters before calling the original function; it is
        a coroutine function when `func` is one.
    """
    # get the type annotations of the parameters
    typed_signature = get_typed_signature(func)
    param_annotations = get_param_annotations(typed_signature)

    # get functions for loading BaseModels when needed based on the type annotations
    kwargs_mapping_with_nones = {k: get_load_param_if_needed_function(t) for k, t in param_annotations.items()}

    # remove the None values
    kwargs_mapping = {k: f for k, f in kwargs_mapping_with_nones.items() if f is not None}

    def _load_kwargs(kwargs: dict[str, Any]) -> None:
        # Convert in place; skipping absent keys avoids a KeyError for parameters that
        # were omitted (so their defaults apply) or passed positionally.
        for k, f in kwargs_mapping.items():
            if k in kwargs:
                kwargs[k] = f(kwargs[k], param_annotations[k])

    # a function that loads the parameters before calling the original function
    @functools.wraps(func)
    def _load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
        _load_kwargs(kwargs)
        # call the original function
        return func(*args, **kwargs)

    @functools.wraps(func)
    async def _a_load_parameters_if_needed(*args: Any, **kwargs: Any) -> Any:
        _load_kwargs(kwargs)
        # call the original function
        return await func(*args, **kwargs)

    if inspect.iscoroutinefunction(func):
        return _a_load_parameters_if_needed
    else:
        return _load_parameters_if_needed


class _SerializableResult(BaseModel):
    # Wrapper used by `serialize_to_str` to run arbitrary values through Pydantic serialization.
    result: Any


@export_module("autogen.tools")
def serialize_to_str(x: Any) -> str:
    """Serialize the return value of a tool function to a string.

    Strings are returned unchanged and Pydantic models are dumped to JSON. Any other value
    is wrapped in a Pydantic model, dumped in Python mode, and its ``result`` entry is
    converted with ``str()``. If that fails, ``json.dumps`` is tried, and ``str(x)`` is
    the last resort.

    Args:
        x: The value to serialize

    Returns:
        The string representation of `x`
    """
    if isinstance(x, str):
        return x

    if isinstance(x, BaseModel):
        return x.model_dump_json()

    retval_model = _SerializableResult(result=x)
    try:
        return str(retval_model.model_dump()["result"])
    except Exception:
        # Deliberate best-effort: fall through to the next strategy.
        pass

    # try json.dumps() and then just return str(x) if that fails too
    try:
        return json.dumps(x, ensure_ascii=False)
    except Exception:
        return str(x)