# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
# Requires the Azure Cosmos DB SDK (`pip install azure-cosmos`); the azure imports
# below are optional and only needed when this cache backend is used.
import pickle
from typing import Any, Optional, TypedDict, Union

from ..import_utils import optional_import_block, require_optional_import
from .abstract_cache_base import AbstractCache

with optional_import_block():
    from azure.cosmos import CosmosClient, PartitionKey
    from azure.cosmos.exceptions import CosmosResourceNotFoundError


@require_optional_import("azure", "cosmosdb")
class CosmosDBConfig(TypedDict, total=False):
    """Configuration for the Cosmos DB cache backend; all keys are optional at the type level."""

    connection_string: str
database_id: str
container_id: str
cache_seed: Optional[Union[str, int]]
client: Optional["CosmosClient"]


@require_optional_import("azure", "cosmosdb")
class CosmosDBCache(AbstractCache):
"""Synchronous implementation of AbstractCache using Azure Cosmos DB NoSQL API.
This class provides a concrete implementation of the AbstractCache
interface using Azure Cosmos DB for caching data, with synchronous operations.
Attributes:
seed (Union[str, int]): A seed or namespace used as a partition key.
client (CosmosClient): The Cosmos DB client used for caching.
container: The container instance used for caching.
"""

    def __init__(self, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
        """Initialize the CosmosDBCache instance.

        Args:
            seed: A seed or namespace for the cache, used as a partition key.
            cosmosdb_config: The configuration for the Cosmos DB cache.
        """
        self.seed = str(seed)
        # Reuse an injected CosmosClient when one is provided; otherwise build one
        # from the configured connection string.
        self.client = cosmosdb_config.get("client") or CosmosClient.from_connection_string(
            cosmosdb_config["connection_string"]
        )
        database_id = cosmosdb_config.get("database_id", "autogen_cache")
        self.database = self.client.get_database_client(database_id)
        # Default the container id as well, so an omitted "container_id" key does not
        # pass None to create_container_if_not_exists.
        container_id = cosmosdb_config.get("container_id", "autogen_cache_container")
        self.container = self.database.create_container_if_not_exists(
            id=container_id, partition_key=PartitionKey(path="/partitionKey")
        )

    @classmethod
    def create_cache(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
        """Factory method to create a CosmosDBCache instance based on the provided configuration.

        This method decides whether to use an existing CosmosClient or create a new one.
        """
        if "client" in cosmosdb_config and isinstance(cosmosdb_config["client"], CosmosClient):
            return cls.from_existing_client(seed, **cosmosdb_config)
        else:
            return cls.from_config(seed, cosmosdb_config)

    @classmethod
    def from_config(cls, seed: Union[str, int], cosmosdb_config: CosmosDBConfig):
        """Create a CosmosDBCache from a complete configuration dictionary."""
        return cls(str(seed), cosmosdb_config)

    @classmethod
    def from_connection_string(cls, seed: Union[str, int], connection_string: str, database_id: str, container_id: str):
        """Create a CosmosDBCache from a connection string plus database and container ids."""
        config = {"connection_string": connection_string, "database_id": database_id, "container_id": container_id}
        return cls(str(seed), config)

    @classmethod
    def from_existing_client(cls, seed: Union[str, int], client: "CosmosClient", database_id: str, container_id: str):
        """Create a CosmosDBCache that reuses an already constructed CosmosClient."""
        config = {"client": client, "database_id": database_id, "container_id": container_id}
        return cls(str(seed), config)
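
    # The factories above map to two setups (an illustrative sketch; `existing_client`
    # and the id values are placeholders, not package defaults):
    #
    #     # Build a new client from a connection string:
    #     cache = CosmosDBCache.from_connection_string(42, "<connection-string>", "autogen_cache", "autogen_cache_container")
    #
    #     # Reuse a CosmosClient constructed elsewhere; no connection string needed:
    #     cache = CosmosDBCache.from_existing_client(42, existing_client, "autogen_cache", "autogen_cache_container")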

    def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
        """Retrieve an item from the Cosmos DB cache.

        Args:
            key (str): The key identifying the item in the cache.
            default (optional): The default value to return if the key is not found.

        Returns:
            The deserialized value associated with the key if found, else the default value.
        """
try:
response = self.container.read_item(item=key, partition_key=str(self.seed))
return pickle.loads(response["data"])
except CosmosResourceNotFoundError:
return default
        except Exception:
            # Unexpected errors are re-raised unchanged; add logging here if needed.
            raise

    def set(self, key: str, value: Any) -> None:
        """Set an item in the Cosmos DB cache.

        Args:
            key (str): The key under which the item is to be stored.
            value: The value to be stored in the cache.

        Notes:
            The value is serialized using pickle before being stored.
        """
try:
serialized_value = pickle.dumps(value)
item = {"id": key, "partitionKey": str(self.seed), "data": serialized_value}
self.container.upsert_item(item)
        except Exception:
            # Upsert failures are re-raised unchanged; add logging here if needed.
            raise
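
    # Because values are pickled (see set/get above), any picklable Python object
    # round-trips through the cache. Illustrative sketch; `cache` is assumed to be an
    # initialized CosmosDBCache instance:
    #
    #     cache.set("run-42", {"attempts": 3, "tags": ("fast", "retry")})
    #     assert cache.get("run-42") == {"attempts": 3, "tags": ("fast", "retry")}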

    def close(self) -> None:
        """Close the Cosmos DB client.

        Perform any necessary cleanup, such as closing network connections.
        """
        # CosmosClient doesn't require an explicit close in the current SDK.
        # If the client was created inside this class, close it here if that changes.
        pass

    def __enter__(self):
        """Context management entry.

        Returns:
            self: The instance itself.
        """
        return self

    def __exit__(self, exc_type: Optional[type], exc_value: Optional[Exception], traceback: Optional[Any]) -> None:
        """Context management exit.

        Perform cleanup actions such as closing the Cosmos DB client.
        """
        self.close()
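

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: running it requires a real
    # Azure Cosmos DB account. The environment variable name and the database and
    # container ids below are placeholders chosen for this example.
    import os

    with CosmosDBCache.from_connection_string(
        seed=42,
        connection_string=os.environ["COSMOS_CONNECTION_STRING"],
        database_id="autogen_cache",
        container_id="autogen_cache_container",
    ) as cache:
        cache.set("example-key", {"answer": 42})
        print(cache.get("example-key"))  # -> {'answer': 42}
        print(cache.get("missing-key", default="not cached"))  # -> 'not cached'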