# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/Lancetnik/FastDepends are under the MIT License.
# SPDX-License-Identifier: MIT
"""Sync/async execution helpers and annotation-resolution utilities."""

import asyncio
import functools
import inspect
from contextlib import AsyncExitStack, ExitStack, asynccontextmanager, contextmanager
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterable,
    Awaitable,
    Callable,
    ContextManager,
    Dict,
    ForwardRef,
    List,
    Tuple,
    TypeVar,
    Union,
    cast,
)

import anyio
from typing_extensions import (
    Annotated,
    ParamSpec,
    get_args,
    get_origin,
)

from ._compat import evaluate_forwardref

if TYPE_CHECKING:
    from types import FrameType

P = ParamSpec("P")
T = TypeVar("T")


async def run_async(
    func: Union[
        Callable[P, T],
        Callable[P, Awaitable[T]],
    ],
    *args: P.args,
    **kwargs: P.kwargs,
) -> T:
    """Call ``func`` with the given arguments and return its result, whether it is sync or async.

    Coroutine callables (as detected by :func:`is_coroutine_callable`) are awaited directly.
    Any other callable is executed in a worker thread via :func:`run_in_threadpool`.
    """
    if is_coroutine_callable(func):
        return await cast(Callable[P, Awaitable[T]], func)(*args, **kwargs)
    return await run_in_threadpool(cast(Callable[P, T], func), *args, **kwargs)


async def run_in_threadpool(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
    """Run the synchronous ``func`` in a worker thread and await its result.

    ``anyio.to_thread.run_sync`` forwards only positional arguments to ``func``.
    Keyword arguments are therefore bound up front with :func:`functools.partial`.
    """
    if kwargs:
        func = functools.partial(func, **kwargs)
    return await anyio.to_thread.run_sync(func, *args)


async def solve_generator_async(
    *sub_args: Any, call: Callable[..., Any], stack: AsyncExitStack, **sub_values: Any
) -> Any:
    """Enter the generator dependency ``call`` on ``stack`` and return the value it yields.

    How ``call`` is handled depends on its kind:

    * Sync generator callables are wrapped with :func:`contextlib.contextmanager`.
      Their enter/exit run in worker threads via :func:`contextmanager_in_threadpool`.
    * Async generator callables are wrapped with :func:`contextlib.asynccontextmanager`.

    In both cases the generator's teardown runs when ``stack`` is closed.

    Raises:
        TypeError: if ``call`` is neither a sync nor an async generator callable.
    """
    if is_gen_callable(call):
        # Forward positional arguments as well, exactly like the async branch below
        # and ``solve_generator_sync`` do.
        cm = contextmanager_in_threadpool(contextmanager(call)(*sub_args, **sub_values))
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(*sub_args, **sub_values)
    else:
        raise TypeError(f"{call!r} is neither a generator nor an async generator callable")
    return await stack.enter_async_context(cm)


def solve_generator_sync(*sub_args: Any, call: Callable[..., Any], stack: ExitStack, **sub_values: Any) -> Any:
    """Enter the sync generator dependency ``call`` on ``stack`` and return the value it yields.

    The code after the generator's ``yield`` runs when ``stack`` is closed.
    """
    cm = contextmanager(call)(*sub_args, **sub_values)
    return stack.enter_context(cm)


def get_typed_signature(call: Callable[..., Any]) -> Tuple[inspect.Signature, Any]:
    """Return ``call``'s signature with its annotations resolved.

    String and ``ForwardRef`` annotations are evaluated against two namespaces:

    * the ``__globals__`` of the unwrapped function, and
    * the merged locals of the calling frames (see :func:`collect_outer_stack_locals`).

    Returns:
        A ``(signature, return_annotation)`` tuple. The signature holds the resolved
        parameter annotations but no return annotation. The resolved return annotation
        is returned separately as the second element.
    """
    signature = inspect.signature(call)
    localns = collect_outer_stack_locals()

    # Unwrap decorated callables so annotations are resolved against the module
    # globals of the original function rather than those of the wrapper.
    call = inspect.unwrap(call)
    globalns = getattr(call, "__globals__", {})

    typed_params = [
        inspect.Parameter(
            name=param.name,
            kind=param.kind,
            default=param.default,
            annotation=get_typed_annotation(param.annotation, globalns, localns),
        )
        for param in signature.parameters.values()
    ]

    return inspect.Signature(typed_params), get_typed_annotation(
        signature.return_annotation,
        globalns,
        localns,
    )


def collect_outer_stack_locals() -> Dict[str, Any]:
    """Merge the local namespaces of the frames on the current call stack.

    Frames are skipped when their source file path contains ``"fast_depends"``
    (presumably this package's own frames).

    The remaining frames are merged from the outermost to the innermost. When a name
    is bound in several frames, the innermost binding wins.
    """
    frame = inspect.currentframe()

    frames: List["FrameType"] = []
    while frame is not None:
        if "fast_depends" not in frame.f_code.co_filename:
            frames.append(frame)
        frame = frame.f_back

    namespace: Dict[str, Any] = {}
    for outer_frame in reversed(frames):
        namespace.update(outer_frame.f_locals)

    return namespace


def get_typed_annotation(
    annotation: Any,
    globalns: Dict[str, Any],
    locals: Dict[str, Any],
) -> Any:
    """Resolve string and ``ForwardRef`` annotations to the objects they name.

    ``Annotated[...]`` annotations are handled recursively:

    * The wrapped type and any non-string metadata items are resolved.
    * The results are written back onto the ``Annotated`` object in place.
    * String metadata items are left untouched, e.g. free-form descriptions such as
      ``Annotated[str, "The user name"]``. PEP 593 metadata is arbitrary data, not a
      forward reference, and evaluating such text as code would typically fail.

    Args:
        annotation: The annotation to resolve. Anything that is not a string, a
            ``ForwardRef`` or an ``Annotated`` alias is returned unchanged.
        globalns: Global namespace used to evaluate forward references.
        locals: Local namespace used to evaluate forward references.

    Returns:
        The resolved annotation.
    """
    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)

    if isinstance(annotation, ForwardRef):
        annotation = evaluate_forwardref(annotation, globalns, locals)

    if get_origin(annotation) is Annotated and (args := get_args(annotation)):
        origin, *metadata = args
        solved_origin = get_typed_annotation(origin, globalns, locals)
        solved_metadata = tuple(
            item if isinstance(item, str) else get_typed_annotation(item, globalns, locals) for item in metadata
        )
        # NOTE(review): this mutates the alias object in place, as the original code did,
        # in case callers rely on it. Be aware that typing may cache and share identical
        # aliases.
        annotation.__origin__, annotation.__metadata__ = solved_origin, solved_metadata

    return annotation


@asynccontextmanager
async def contextmanager_in_threadpool(
    cm: ContextManager[T],
) -> AsyncGenerator[T, None]:
    """Drive the synchronous context manager ``cm`` from async code.

    Both ``__enter__`` and ``__exit__`` run in worker threads. If ``__enter__`` raises,
    ``__exit__`` is not called, which matches the semantics of a ``with`` statement.

    When the body raises:

    * ``__exit__`` receives the exception type, value and traceback.
    * The exception is re-raised unless ``__exit__`` returns a truthy value.
    """
    # ``__exit__`` runs under its own single-slot limiter instead of anyio's default
    # thread limiter, presumably so teardown never waits for a free default slot
    # (which could deadlock if the context manager holds pooled resources).
    exit_limiter = anyio.CapacityLimiter(1)
    value = await run_in_threadpool(cm.__enter__)
    try:
        yield value
    except Exception as e:
        ok = bool(
            await anyio.to_thread.run_sync(cm.__exit__, type(e), e, e.__traceback__, limiter=exit_limiter)
        )
        if not ok:  # pragma: no branch
            raise e
    else:
        await anyio.to_thread.run_sync(cm.__exit__, None, None, None, limiter=exit_limiter)


def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Return True if ``call`` or its ``__call__`` method is a generator function."""
    if inspect.isgeneratorfunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.isgeneratorfunction(dunder_call)


def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Return True if ``call`` or its ``__call__`` method is an async generator function."""
    if inspect.isasyncgenfunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.isasyncgenfunction(dunder_call)


def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Return True if calling ``call`` produces a coroutine.

    Classes always return False: calling a class constructs an instance, not a coroutine.
    """
    if inspect.isclass(call):
        return False
    # NOTE: ``asyncio.iscoroutinefunction`` is deprecated in Python 3.14 in favour of
    # ``inspect.iscoroutinefunction``. It is kept here because it also recognises the
    # legacy ``_is_coroutine`` marker.
    if asyncio.iscoroutinefunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return asyncio.iscoroutinefunction(dunder_call)


async def async_map(func: Callable[..., T], async_iterable: AsyncIterable[Any]) -> AsyncIterable[T]:
    """Lazily yield ``func(item)`` for every item of ``async_iterable``."""
    async for item in async_iterable:
        yield func(item)