# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
# Install Azure Cosmos DB SDK if not already

import pickle
from typing import Any, Optional, TypedDict, Union

from ..import_utils import optional_import_block, require_optional_import
from .abstract_cache_base import AbstractCache

with optional_import_block():
    from azure.cosmos import CosmosClient, PartitionKey
    from azure.cosmos.exceptions import CosmosResourceNotFoundError


@require_optional_import("azure", "cosmosdb")
class CosmosDBConfig(TypedDict, total=False):
    # Configuration keys read by CosmosDBCache. All keys are optional at the
    # type level (total=False); see CosmosDBCache.__init__ for which ones are
    # actually looked up.
    connection_string: str
    database_id: str
    container_id: str
    cache_seed: Optional[Union[str, int]]
    client: Optional["CosmosClient"]


@require_optional_import("azure", "cosmosdb")
class CosmosDBCache(AbstractCache):
    """Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.

    This class provides a concrete implementation of the AbstractCache
    interface using Azure Cosmos DB for caching data, with synchronous operations.

    Attributes:
        seed (str): The seed or namespace, stored as a string and used as the
            partition key value for every item.
        client (CosmosClient): The Cosmos DB client used for caching.
        database: The database client obtained from ``client``.
        container: The container instance used for caching.
    """

    def __init__(self, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
        """Initialize the CosmosDBCache instance.

        Args:
            seed: A seed or namespace for the cache, used as a partition key.
            cosmosdb_config: The configuration for the Cosmos DB cache. If it
                holds a truthy ``client`` that client is used; otherwise a
                client is built from ``connection_string`` (which must then be
                present). ``database_id`` defaults to ``"autogen_cache"``;
                ``container_id`` is passed through as given.

        Raises:
            KeyError: If neither a ``client`` nor a ``connection_string`` is
                provided in ``cosmosdb_config``.
        """
        self.seed = str(seed)
        self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string(
            cosmosdb_config["connection_string"]
        )
        database_id = cosmosdb_config.get("database_id", "autogen_cache")
        self.database = self.client.get_database_client(database_id)
        container_id = cosmosdb_config.get("container_id")
        self.container = self.database.create_container_if_not_exists(
            id=container_id, partition_key=PartitionKey(path="/partitionKey")
        )

    @classmethod
    def create_cache(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig) -> "CosmosDBCache":
        """Factory method to create a CosmosDBCache instance based on the provided configuration.

        This method decides whether to use an existing CosmosClient or create a new one.

        Args:
            seed: A seed or namespace for the cache.
            cosmosdb_config: The configuration for the Cosmos DB cache.

        Returns:
            A new cache instance.
        """
        existing_client = cosmosdb_config.get("client")
        if isinstance(existing_client, CosmosClient):
            # Forward only the arguments from_existing_client accepts. Splatting
            # the whole config would raise TypeError whenever it also carries
            # other valid CosmosDBConfig keys (connection_string, cache_seed)
            # or omits database_id.
            return cls.from_existing_client(
                seed,
                client=existing_client,
                database_id=cosmosdb_config.get("database_id", "autogen_cache"),
                container_id=cosmosdb_config.get("container_id"),
            )
        return cls.from_config(seed, cosmosdb_config)

    @classmethod
    def from_config(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig) -> "CosmosDBCache":
        """Create a cache directly from a configuration mapping."""
        return cls(str(seed), cosmosdb_config)

    @classmethod
    def from_connection_string(
        cls, seed: Union[str, int], connection_string: str, database_id: str, container_id: str
    ) -> "CosmosDBCache":
        """Create a cache that builds its own client from a connection string."""
        config = {"connection_string": connection_string, "database_id": database_id, "container_id": container_id}
        return cls(str(seed), config)

    @classmethod
    def from_existing_client(
        cls, seed: Union[str, int], client: "CosmosClient", database_id: str, container_id: str
    ) -> "CosmosDBCache":
        """Create a cache that reuses an already constructed CosmosClient."""
        config = {"client": client, "database_id": database_id, "container_id": container_id}
        return cls(str(seed), config)

    def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
        """Retrieve an item from the Cosmos DB cache.

        Args:
            key (str): The key identifying the item in the cache.
            default (optional): The default value to return if the key is not found.

        Returns:
            The deserialized value associated with the key if found, else the default value.

        Raises:
            Exception: Any error other than ``CosmosResourceNotFoundError``
                raised while reading or unpickling propagates to the caller.
        """
        try:
            response = self.container.read_item(item=key, partition_key=str(self.seed))
        except CosmosResourceNotFoundError:
            return default
        # SECURITY NOTE: pickle.loads executes arbitrary code embedded in the
        # payload. Only point this cache at containers whose contents are trusted.
        return pickle.loads(response["data"])

    def set(self, key: str, value: Any) -> None:
        """Set an item in the Cosmos DB cache.

        Args:
            key (str): The key under which the item is to be stored.
            value: The value to be stored in the cache.

        Notes:
            The value is serialized using pickle before being stored.

        Raises:
            Exception: Any serialization or upsert error propagates to the caller.
        """
        serialized_value = pickle.dumps(value)
        item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value}
        self.container.upsert_item(item)

    def close(self) -> None:
        """Close the Cosmos DB client.

        Perform any necessary cleanup, such as closing network connections.
        Currently a no-op.
        """
        # CosmosClient doesn't require explicit close in the current SDK
        # If you created the client inside this class, you should close it if necessary
        pass

    def __enter__(self) -> "CosmosDBCache":
        """Context management entry.

        Returns:
            self: The instance itself.
        """
        return self

    def __exit__(self, exc_type: Optional[type], exc_value: Optional[Exception], traceback: Optional[Any]) -> None:
        """Context management exit.

        Perform cleanup actions such as closing the Cosmos DB client.
        """
        self.close()