# mirrored 13 minutes ago
# 0
# Linxin Song - CoACT initialize (#292) b968155
# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors
#
# SPDX-License-Identifier: Apache-2.0
#
# Portions derived from https://github.com/microsoft/autogen are under the MIT License.
# SPDX-License-Identifier: MIT
"""Create an OpenAI-compatible client for Gemini features.

Example:
    ```python
    llm_config = {
        "config_list": [
            {
                "api_type": "google",
                "model": "gemini-pro",
                "api_key": os.environ.get("GOOGLE_GEMINI_API_KEY"),
                "safety_settings": [
                    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
                    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
                    {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
                    {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
                ],
                "top_p": 0.5,
                "max_tokens": 2048,
                "temperature": 1.0,
                "top_k": 5,
            }
        ]
    }

    agent = autogen.AssistantAgent("my_agent", llm_config=llm_config)
    ```

Resources:
- https://ai.google.dev/docs
- https://cloud.google.com/vertex-ai/generative-ai/docs/migrate/migrate-from-azure-to-gemini
- https://blog.google/technology/ai/google-gemini-pro-imagen-duet-ai-update/
- https://ai.google.dev/api/python/google/generativeai/ChatSession
"""

from __future__ import annotations

import asyncio
import base64
import copy
import json
import logging
import os
import random
import re
import time
import warnings
from io import BytesIO
from typing import Any, Literal, Optional, Type, Union

import requests
from packaging import version
from pydantic import BaseModel, Field

from ..import_utils import optional_import_block, require_optional_import
from ..json_utils import resolve_json_references
from ..llm_config import LLMConfigEntry, register_llm_config
from .client_utils import FormatterProtocol
from .gemini_types import ToolConfig
from .oai_models import ChatCompletion, ChatCompletionMessage, ChatCompletionMessageToolCall, Choice, CompletionUsage

with optional_import_block():
    import google.genai as genai
    import vertexai
    from PIL import Image
    from google.auth.credentials import Credentials
    from google.genai.types import (
        Content,
        FinishReason,
        FunctionCall,
        FunctionDeclaration,
        FunctionResponse,
        GenerateContentConfig,
        GenerateContentResponse,
        GoogleSearch,
        Part,
        Schema,
        Tool,
        Type,
    )
    from jsonschema import ValidationError
    from vertexai.generative_models import Content as VertexAIContent
    from vertexai.generative_models import FunctionDeclaration as vaiFunctionDeclaration
    from vertexai.generative_models import GenerationConfig, GenerativeModel
    from vertexai.generative_models import (
        GenerationResponse as VertexAIGenerationResponse,
    )
    from vertexai.generative_models import HarmBlockThreshold as VertexAIHarmBlockThreshold
    from vertexai.generative_models import HarmCategory as VertexAIHarmCategory
    from vertexai.generative_models import Part as VertexAIPart
    from vertexai.generative_models import SafetySetting as VertexAISafetySetting
    from vertexai.generative_models import (
        Tool as vaiTool,
    )

logger = logging.getLogger(__name__)


# LLM configuration entry for the Gemini client; ``api_type`` is fixed to "google".
# Documentation is kept in ``#`` comments so the class body itself is unchanged.
@register_llm_config
class GeminiLLMConfigEntry(LLMConfigEntry):
    api_type: Literal["google"] = "google"
    # Google Cloud project and compute location. GeminiClient forwards ``project_id`` and
    # ``location`` to ``vertexai.init`` and rejects them when an API key is in use.
    project_id: Optional[str] = None
    location: Optional[str] = None
    # google_application_credentials points to the path of the JSON Keyfile
    google_application_credentials: Optional[str] = None
    # credentials is a google.auth.credentials.Credentials object
    credentials: Optional[Union[Any, str]] = None
    # NOTE: GeminiClient.create warns that streaming is not supported and forces it off.
    stream: bool = False
    # Either a list of setting dicts or a single dict; handed to the Gemini model as safety settings.
    safety_settings: Optional[Union[list[dict[str, Any]], dict[str, Any]]] = None
    # When given, must contain exactly two floats (min_length=2, max_length=2).
    # Presumably [prompt price, completion price] - TODO confirm against the cost calculation.
    price: Optional[list[float]] = Field(default=None, min_length=2, max_length=2)
    # Optional tool configuration (see ``gemini_types.ToolConfig``).
    tool_config: Optional[ToolConfig] = None

    def create_client(self):
        # Not supported through the config entry: always raises NotImplementedError.
        raise NotImplementedError("GeminiLLMConfigEntry.create_client() is not implemented.")


@require_optional_import(["google", "vertexai", "PIL", "jsonschema"], "gemini")
class GeminiClient:
    """Client for Google's Gemini API."""

    # Mapping, where Key is a term used by Autogen, and Value is a term used by Gemini.
    # ``create`` copies every key that is present in the request params into the Gemini
    # generation config under the mapped name. Both "max_tokens" and "max_output_tokens"
    # map to "max_output_tokens"; when both are supplied the later entry
    # ("max_output_tokens") overwrites the earlier one.
    PARAMS_MAPPING = {
        "max_tokens": "max_output_tokens",
        # "n": "candidate_count", # Gemini supports only `n=1`
        "stop_sequences": "stop_sequences",
        "temperature": "temperature",
        "top_p": "top_p",
        "top_k": "top_k",
        "max_output_tokens": "max_output_tokens",
    }

    def _initialize_vertexai(self, **params):
        if "google_application_credentials" in params:
            # Path to JSON Keyfile
            os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = params["google_application_credentials"]
        vertexai_init_args = {}
        if "project_id" in params:
            vertexai_init_args["project"] = params["project_id"]
        if "location" in params:
            vertexai_init_args["location"] = params["location"]
        if "credentials" in params:
            assert isinstance(params["credentials"], Credentials), (
                "Object type google.auth.credentials.Credentials is expected!"
            )
            vertexai_init_args["credentials"] = params["credentials"]
        if vertexai_init_args:
            vertexai.init(**vertexai_init_args)

    def __init__(self, **kwargs):
        """Uses either either api_key for authentication from the LLM config
        (specifying the GOOGLE_GEMINI_API_KEY environment variable also works),
        or follows the Google authentication mechanism for VertexAI in Google Cloud if no api_key is specified,
        where project_id and location can also be passed as parameters. Previously created credentials object can be provided,
        or a Service account key file can also be used. If neither a service account key file, nor the api_key are passed,
        then the default credentials will be used, which could be a personal account if the user is already authenticated in,
        like in Google Cloud Shell.

        Args:
            **kwargs: The keyword arguments to initialize the Gemini client.
        """
        self.api_key = kwargs.get("api_key")
        if not self.api_key:
            self.api_key = os.getenv("GOOGLE_GEMINI_API_KEY")
            if self.api_key is None:
                self.use_vertexai = True
                self._initialize_vertexai(**kwargs)
            else:
                self.use_vertexai = False
        else:
            self.use_vertexai = False
        if not self.use_vertexai:
            assert ("project_id" not in kwargs) and ("location" not in kwargs), (
                "Google Cloud project and compute location cannot be set when using an API Key!"
            )

        self.api_version = kwargs.get("api_version")

        # Store the response format, if provided (for structured outputs)
        self._response_format: Optional[type[BaseModel]] = None

    def message_retrieval(self, response) -> list:
        """Retrieve and return a list of strings or a list of Choice.Message from the response.

        NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
        since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
        """
        return [choice.message for choice in response.choices]

    def cost(self, response) -> float:
        return response.cost

    @staticmethod
    def get_usage(response) -> dict:
        """Return usage summary of the response using RESPONSE_USAGE_KEYS."""
        # ...  # pragma: no cover
        return {
            "prompt_tokens": response.usage.prompt_tokens,
            "completion_tokens": response.usage.completion_tokens,
            "total_tokens": response.usage.total_tokens,
            "cost": response.cost,
            "model": response.model,
        }

    def create(self, params: dict) -> ChatCompletion:
        """Send one chat request to Gemini and return it as an OpenAI-style ``ChatCompletion``.

        Args:
            params: Request parameters. Keys read here are ``model``, ``messages``,
                ``stream``, ``n``, ``response_validation``, ``tools``, ``tool_config``,
                ``safety_settings``, ``response_format`` and every key of
                ``PARAMS_MAPPING``. A leading system message is removed from
                ``params["messages"]`` in place while extracting the system instruction.

        Returns:
            A ``ChatCompletion`` holding a single choice. Its finish reason is
            ``"tool_calls"`` when function calls were returned, ``"content_filter"`` when
            the google-genai response finished with ``RECITATION``, and ``"stop"`` otherwise.

        Raises:
            ValueError: For the deprecated ``gemini-pro-vision`` model, an empty model
                name, a response whose candidate count is not 1, or an unexpected
                response type.
            AssertionError: If ``project_id`` or ``location`` is passed while an API key is used.
        """
        # When running in async context via run_in_executor from ConversableAgent.a_generate_oai_reply,
        # this method runs in a new thread that doesn't have an event loop by default. The Google Genai
        # client requires an event loop even for synchronous operations, so we need to ensure one exists.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No event loop exists in this thread (which happens when called from an executor)
            # Create a new event loop for this thread to satisfy Genai client requirements
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if self.use_vertexai:
            # Vertex AI settings may be overridden per request, so initialisation is repeated here.
            self._initialize_vertexai(**params)
        else:
            assert ("project_id" not in params) and ("location" not in params), (
                "Google Cloud project and compute location cannot be set when using an API Key!"
            )
        # Defaults to "gemini-pro" only when the key is absent; an explicitly empty value is rejected below.
        model_name = params.get("model", "gemini-pro")

        if model_name == "gemini-pro-vision":
            raise ValueError(
                "Gemini 1.0 Pro vision ('gemini-pro-vision') has been deprecated, please consider switching to a different model, for example 'gemini-1.5-flash'."
            )
        elif not model_name:
            raise ValueError(
                "Please provide a model name for the Gemini Client. "
                "You can configure it in the OAI Config List file. "
                "See this [LLM configuration tutorial](https://docs.ag2.ai/latest/docs/user-guide/basic-concepts/llm-configuration/) for more details."
            )

        params.get("api_type", "google")  # not used
        # The API version is only forwarded to the google-genai client (non-Vertex path).
        http_options = {"api_version": self.api_version} if self.api_version else None
        messages = params.get("messages", [])
        stream = params.get("stream", False)
        n_response = params.get("n", 1)
        # NOTE: this pops a leading system message out of `messages` (and thus out of params["messages"]).
        system_instruction = self._extract_system_instruction(messages)
        # Only used by the Vertex AI chat session below.
        response_validation = params.get("response_validation", True)
        tools = self._tools_to_gemini_tools(params["tools"]) if "tools" in params else None
        tool_config = params.get("tool_config")

        # Translate AutoGen parameter names into Gemini generation-config names.
        generation_config = {
            gemini_term: params[autogen_term]
            for autogen_term, gemini_term in self.PARAMS_MAPPING.items()
            if autogen_term in params
        }
        if self.use_vertexai:
            safety_settings = GeminiClient._to_vertexai_safety_settings(params.get("safety_settings", []))
        else:
            safety_settings = params.get("safety_settings", [])

        if stream:
            warnings.warn(
                "Streaming is not supported for Gemini yet, and it will have no effect. Please set stream=False.",
                UserWarning,
            )
            stream = False

        if n_response > 1:
            # Only a warning: the request still goes out and a single response is produced.
            warnings.warn("Gemini only supports `n=1` for now. We only generate one response.", UserWarning)

        autogen_tool_calls = []

        # Maps the function call ids to function names so we can inject it into FunctionResponse messages.
        # Reset on every call and filled while the messages are converted below.
        self.tool_call_function_map: dict[str, str] = {}

        # If response_format exists, we want structured outputs
        # Based on
        # https://ai.google.dev/gemini-api/docs/structured-output?lang=python#supply-schema-in-config
        # NOTE(review): self._response_format is only ever assigned here and is not cleared when a
        # later request omits response_format, so it persists across calls on this instance.
        if params.get("response_format"):
            self._response_format = params.get("response_format")
            generation_config["response_mime_type"] = "application/json"

            response_format_schema_raw = params.get("response_format")

            # A dict is taken as a ready JSON schema; anything else must provide model_json_schema().
            if isinstance(response_format_schema_raw, dict):
                response_schema = resolve_json_references(response_format_schema_raw)
            else:
                response_schema = resolve_json_references(params.get("response_format").model_json_schema())
            # The "$defs" section is dropped once references have been resolved.
            if "$defs" in response_schema:
                response_schema.pop("$defs")
            generation_config["response_schema"] = response_schema

        # A. create and call the chat model.
        # All messages but the last form the chat history; the last one is sent as the new message.
        gemini_messages = self._oai_messages_to_gemini_messages(messages)
        if self.use_vertexai:
            model = GenerativeModel(
                model_name,
                generation_config=GenerationConfig(**generation_config),
                safety_settings=safety_settings,
                system_instruction=system_instruction,
                tool_config=tool_config,
                tools=tools,
            )

            chat = model.start_chat(history=gemini_messages[:-1], response_validation=response_validation)
            response = chat.send_message(gemini_messages[-1].parts, stream=stream, safety_settings=safety_settings)
        else:
            client = genai.Client(api_key=self.api_key, http_options=http_options)
            generate_content_config = GenerateContentConfig(
                safety_settings=safety_settings,
                system_instruction=system_instruction,
                tools=tools,
                tool_config=tool_config,
                **generation_config,
            )
            chat = client.chats.create(model=model_name, config=generate_content_config, history=gemini_messages[:-1])
            response = chat.send_message(message=gemini_messages[-1].parts)

        # Extract text and tools from response
        ans = ""
        # Starting id for the tool calls of this response; incremented for each call emitted.
        random_id = random.randint(0, 10000)
        prev_function_calls = []
        error_finish_reason = None

        if isinstance(response, GenerateContentResponse):
            if len(response.candidates) != 1:
                raise ValueError(
                    f"Unexpected number of candidates in the response. Expected 1, got {len(response.candidates)}"
                )

            # Look at https://cloud.google.com/vertex-ai/generative-ai/docs/reference/python/latest/vertexai.generative_models.FinishReason
            if response.candidates[0].finish_reason and response.candidates[0].finish_reason == FinishReason.RECITATION:
                # Replace the content with a fixed explanatory text part and flag the choice as filtered.
                recitation_part = Part(text="Unsuccessful Finish Reason: RECITATION")
                parts = [recitation_part]
                error_finish_reason = "content_filter"  # As per available finish_reason in Choice
            else:
                parts = response.candidates[0].content.parts
        elif isinstance(response, VertexAIGenerationResponse):  # or hasattr(response, "candidates"):
            # google.generativeai also raises an error len(candidates) != 1:
            if len(response.candidates) != 1:
                raise ValueError(
                    f"Unexpected number of candidates in the response. Expected 1, got {len(response.candidates)}"
                )
            parts = response.candidates[0].content.parts
        else:
            raise ValueError(f"Unexpected response type: {type(response)}")

        for part in parts:
            # Function calls
            if fn_call := part.function_call:
                # If we have a repeated function call, ignore it
                if fn_call not in prev_function_calls:
                    autogen_tool_calls.append(
                        ChatCompletionMessageToolCall(
                            id=str(random_id),
                            function={
                                "name": fn_call.name,
                                # Arguments are serialised to a JSON string; missing args become "".
                                "arguments": (
                                    json.dumps({key: val for key, val in fn_call.args.items()})
                                    if fn_call.args is not None
                                    else ""
                                ),
                            },
                            type="function",
                        )
                    )

                    prev_function_calls.append(fn_call)
                    random_id += 1

            # Plain text content
            elif text := part.text:
                ans += text

        # If we have function calls, ignore the text
        # as it can be Gemini guessing the function response
        if len(autogen_tool_calls) != 0:
            ans = ""
        else:
            autogen_tool_calls = None

        if self._response_format and ans:
            try:
                parsed_response = self._convert_json_response(ans)
                # _format_json_response is defined outside this method (not shown here).
                ans = _format_json_response(parsed_response, ans)
            except ValueError as e:
                # A parsing/validation failure is reported as the message content instead of being raised.
                ans = str(e)

        # 3. convert output
        message = ChatCompletionMessage(
            role="assistant", content=ans, function_call=None, tool_calls=autogen_tool_calls
        )
        # Finish-reason precedence: tool calls, then an error finish reason, then "stop".
        choices = [
            Choice(
                finish_reason="tool_calls"
                if autogen_tool_calls is not None
                else error_finish_reason
                if error_finish_reason
                else "stop",
                index=0,
                message=message,
            )
        ]

        prompt_tokens = response.usage_metadata.prompt_token_count
        # A missing/zero candidate token count is reported as 0 completion tokens.
        completion_tokens = (
            response.usage_metadata.candidates_token_count if response.usage_metadata.candidates_token_count else 0
        )

        response_oai = ChatCompletion(
            id=str(random.randint(0, 1000)),
            model=model_name,
            created=int(time.time()),
            object="chat.completion",
            choices=choices,
            usage=CompletionUsage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
            # calculate_gemini_cost is defined outside this method (not shown here).
            cost=calculate_gemini_cost(self.use_vertexai, prompt_tokens, completion_tokens, model_name),
        )

        return response_oai

    def _extract_system_instruction(self, messages: list[dict]) -> str | None:
        """Extract system instruction if provided."""
        if messages is None or len(messages) == 0 or messages[0].get("role") != "system":
            return None

        message = messages.pop(0)
        content = message["content"]

        # Multi-model uses a list of dictionaries as content with text for the system message
        # Otherwise normal agents will have strings as content
        content = content[0].get("text", "").strip() if isinstance(content, list) else content.strip()

        content = content if len(content) > 0 else None
        return content

    def _oai_content_to_gemini_content(self, message: dict[str, Any]) -> tuple[list[Any], str]:
        """Convert AG2 content to Gemini parts, catering for text and tool calls"""
        rst = []

        if "role" in message and message["role"] == "tool":
            # Tool call recommendation

            function_name = self.tool_call_function_map[message["tool_call_id"]]

            if self.use_vertexai:
                rst.append(
                    VertexAIPart.from_function_response(
                        name=function_name, response={"result": self._to_json_or_str(message["content"])}
                    )
                )
            else:
                rst.append(
                    Part(
                        function_response=FunctionResponse(
                            name=function_name, response={"result": self._to_json_or_str(message["content"])}
                        )
                    )
                )

            return rst, "tool"
        elif "tool_calls" in message and len(message["tool_calls"]) != 0:
            for tool_call in message["tool_calls"]:
                function_id = tool_call["id"]
                function_name = tool_call["function"]["name"]
                self.tool_call_function_map[function_id] = function_name

                if self.use_vertexai:
                    rst.append(
                        VertexAIPart.from_dict({
                            "functionCall": {
                                "name": function_name,
                                "args": json.loads(tool_call["function"]["arguments"]),
                            }
                        })
                    )
                else:
                    rst.append(
                        Part(
                            function_call=FunctionCall(
                                name=function_name,
                                args=json.loads(tool_call["function"]["arguments"]),
                            )
                        )
                    )

            return rst, "tool_call"

        elif isinstance(message["content"], str):
            content = message["content"]
            if content == "":
                content = "empty"  # Empty content is not allowed.
            if self.use_vertexai:
                rst.append(VertexAIPart.from_text(content))
            else:
                rst.append(Part(text=content))

            return rst, "text"

        # For images the message contains a list of text items
        if isinstance(message["content"], list):
            has_image = False
            for msg in message["content"]:
                if isinstance(msg, dict):
                    assert "type" in msg, f"Missing 'type' field in message: {msg}"
                    if msg["type"] == "text":
                        if self.use_vertexai:
                            rst.append(VertexAIPart.from_text(text=msg["text"]))
                        else:
                            rst.append(Part(text=msg["text"]))
                    elif msg["type"] == "image_url":
                        if self.use_vertexai:
                            img_url = msg["image_url"]["url"]
                            img_part = VertexAIPart.from_uri(img_url, mime_type="image/png")
                            rst.append(img_part)
                        else:
                            b64_img = get_image_data(msg["image_url"]["url"])
                            rst.append(Part(inline_data={"mime_type": "image/png", "data": b64_img}))

                        has_image = True
                    else:
                        raise ValueError(f"Unsupported message type: {msg['type']}")
                else:
                    raise ValueError(f"Unsupported message type: {type(msg)}")
            return rst, "image" if has_image else "text"
        else:
            raise Exception("Unable to convert content to Gemini format.")

    def _concat_parts(self, parts: list[Part]) -> list:
        """Concatenate parts with the same type.
        If two adjacent parts both have the "text" attribute, then it will be joined into one part.
        """
        if not parts:
            return []

        concatenated_parts = []
        previous_part = parts[0]

        for current_part in parts[1:]:
            if previous_part.text != "":
                if self.use_vertexai:
                    previous_part = VertexAIPart.from_text(previous_part.text + current_part.text)
                else:
                    previous_part.text += current_part.text
            else:
                concatenated_parts.append(previous_part)
                previous_part = current_part

        if previous_part.text == "":
            if self.use_vertexai:
                previous_part = VertexAIPart.from_text("empty")
            else:
                previous_part.text = "empty"  # Empty content is not allowed.
        concatenated_parts.append(previous_part)

        return concatenated_parts

    def _oai_messages_to_gemini_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert messages from OAI format to Gemini format.

        Makes sure the first and the last item are from the "user" role by adding
        dummy "start chat" / "continue" messages where needed. When no message could
        be converted (for example ``messages`` is empty), the result consists of the
        single dummy "start chat" user message instead of raising ``IndexError``.

        Args:
            messages: OpenAI-style messages (without the system instruction).

        Returns:
            A list of Gemini content objects (Vertex AI or google-genai, depending on
            ``self.use_vertexai``).

        Raises:
            Exception: If an image message contains a part of an unknown class.
        """
        # Only the content class of the active backend is looked up.
        content_cls = VertexAIContent if self.use_vertexai else Content

        def uses_function_role() -> bool:
            # google-genai releases before 1.4.0 get the dedicated "function" role.
            return version.parse(genai.__version__) < version.parse("1.4.0")

        def dummy_user_content(text: str):
            # Build a plain-text user message used to satisfy Gemini's role ordering.
            text_part, _ = self._oai_content_to_gemini_content({"content": text})
            return content_cls(parts=text_part, role="user")

        rst = []
        for message in messages:
            parts, part_type = self._oai_content_to_gemini_content(message)
            role = "user" if message["role"] in ["user", "system"] else "model"

            if part_type == "text":
                rst.append(content_cls(parts=parts, role=role))
            elif part_type == "tool":
                # Function responses should be assigned "model" role to keep them separate from function calls
                role = "function" if uses_function_role() else "model"
                rst.append(content_cls(parts=parts, role=role))
            elif part_type == "tool_call":
                # Function calls should be assigned "user" role
                role = "function" if uses_function_role() else "user"
                rst.append(content_cls(parts=parts, role=role))
            elif part_type == "image":
                # Image has multiple parts, some can be text and some can be image based
                text_parts = []
                image_parts = []
                for part in parts:
                    if isinstance(part, Part):
                        # Text or non-Vertex AI image part
                        text_parts.append(part)
                    elif isinstance(part, VertexAIPart):
                        # Image
                        image_parts.append(part)
                    else:
                        raise Exception("Unable to process image part")

                if len(text_parts) > 0:
                    rst.append(content_cls(parts=text_parts, role=role))

                if len(image_parts) > 0:
                    rst.append(content_cls(parts=image_parts, role=role))

        # Gemini is strict about the order of roles, such that
        # 1. The first message must be from the user role.
        # 2. The last message must be from the user role.
        # 3. The messages should be interleaved between user and model.
        # We add a dummy message "start chat" if the first role is not the user
        # (or if there is no message at all).
        # We add a dummy message "continue" if the last role is not the user.
        if not rst or rst[0].role != "user":
            rst.insert(0, dummy_user_content("start chat"))

        if rst[-1].role != "user":
            rst.append(dummy_user_content("continue"))

        return rst

    def _convert_json_response(self, response: str) -> Any:
        """Extract and validate JSON response from the output for structured outputs.

        Args:
            response (str): The response from the API.

        Returns:
            Any: The parsed JSON response.
        """
        if not self._response_format:
            return response

        try:
            # Parse JSON and validate against the Pydantic model if Pydantic model was provided
            json_data = json.loads(response)
            if isinstance(self._response_format, dict):
                return json_data
            else:
                return self._response_format.model_validate(json_data)
        except Exception as e:
            raise ValueError(f"Failed to parse response as valid JSON matching the schema for Structured Output: {e!s}")

    @staticmethod
    def _convert_type_null_to_nullable(schema: Any) -> Any:
        """
        Recursively converts all occurrences of {"type": "null"} to {"nullable": True} in a schema.
        """
        if isinstance(schema, dict):
            # If schema matches {"type": "null"}, replace it
            if schema == {"type": "null"}:
                return {"nullable": True}
            # Otherwise, recursively process dictionary
            return {key: GeminiClient._convert_type_null_to_nullable(value) for key, value in schema.items()}
        elif isinstance(schema, list):
            # Recursively process list elements
            return [GeminiClient._convert_type_null_to_nullable(item) for item in schema]
        return schema

    @staticmethod
    def _check_if_prebuilt_google_search_tool_exists(tools: list[dict[str, Any]]) -> bool:
        """Check if the Google Search tool is present in the tools list."""
        exists = False
        for tool in tools:
            if tool["function"]["name"] == "prebuilt_google_search":
                exists = True
                break

        if exists and len(tools) > 1:
            raise ValueError(
                "Google Search tool can be used only by itself. Please remove other tools from the tools list."
            )

        return exists

    @staticmethod
    def _unwrap_references(function_parameters: dict[str, Any]) -> dict[str, Any]:
        if "properties" not in function_parameters:
            return function_parameters

        function_parameters_copy = copy.deepcopy(function_parameters)

        for property_name, property_value in function_parameters["properties"].items():
            if "$defs" in property_value:
                function_parameters_copy["properties"][property_name] = resolve_json_references(property_value)
                function_parameters_copy["properties"][property_name].pop("$defs")

        return function_parameters_copy

    def _tools_to_gemini_tools(self, tools: list[dict[str, Any]]) -> list[Tool]:
        """Create Gemini tools (as typically requires Callables)"""
        if self._check_if_prebuilt_google_search_tool_exists(tools) and not self.use_vertexai:
            return [Tool(google_search=GoogleSearch())]

        functions = []
        for tool in tools:
            if self.use_vertexai:
                tool["function"]["parameters"] = GeminiClient._convert_type_null_to_nullable(
                    tool["function"]["parameters"]
                )
                function_parameters = GeminiClient._unwrap_references(tool["function"]["parameters"])
                function = vaiFunctionDeclaration(
                    name=tool["function"]["name"],
                    description=tool["function"]["description"],
                    parameters=function_parameters,
                )
            else:
                function = GeminiClient._create_gemini_function_declaration(tool)
            functions.append(function)

        if self.use_vertexai:
            return [vaiTool(function_declarations=functions)]
        else:
            return [Tool(function_declarations=functions)]

    @staticmethod
    def _create_gemini_function_declaration(tool: dict) -> FunctionDeclaration:
        """Build a google-genai ``FunctionDeclaration`` from an OpenAI-style tool dict.

        Name and description are always copied. Parameters are converted (from a deep
        copy of the tool's parameter schema) only when the schema declares at least
        one property.
        """
        declaration = FunctionDeclaration()
        spec = tool["function"]
        declaration.name = spec["name"]
        declaration.description = spec["description"]

        has_properties = len(spec["parameters"]["properties"]) > 0
        if has_properties:
            parameters_copy = copy.deepcopy(spec["parameters"])
            declaration.parameters = GeminiClient._create_gemini_function_parameters(parameters_copy)

        return declaration

    @staticmethod
    def _create_gemini_function_declaration_schema(json_data) -> Schema:
        """Recursively create a ``Schema`` object for a ``FunctionDeclaration``.

        Args:
            json_data: JSON-schema-like dict; must contain ``"type"``. ``"items"``,
                ``"properties"`` and ``"description"`` are read when present.

        Returns:
            A ``Schema`` whose ``type`` mirrors ``json_data["type"]``. Arrays recurse
            into ``"items"``, objects recurse into each entry of ``"properties"``;
            ``"null"`` and ``"any"`` are mapped to ``Type.STRING``. For an unsupported
            type a warning is printed and ``type`` is left unset.

        Raises:
            KeyError: If ``json_data`` has no ``"type"`` key.
        """
        param_schema = Schema()
        param_type = json_data["type"]

        # Type enum values, for reference:
        #   TYPE_UNSPECIFIED = 0
        #   STRING = 1
        #   INTEGER = 2
        #   NUMBER = 3
        #   OBJECT = 4
        #   ARRAY = 5
        #   BOOLEAN = 6

        if param_type == "integer":
            param_schema.type = Type.INTEGER
        elif param_type == "number":
            param_schema.type = Type.NUMBER
        elif param_type == "string":
            param_schema.type = Type.STRING
        elif param_type == "boolean":
            param_schema.type = Type.BOOLEAN
        elif param_type == "array":
            param_schema.type = Type.ARRAY
            if "items" in json_data:
                param_schema.items = GeminiClient._create_gemini_function_declaration_schema(json_data["items"])
            else:
                print("Warning: Array schema missing 'items' definition.")
        elif param_type == "object":
            param_schema.type = Type.OBJECT
            param_schema.properties = {}
            if "properties" in json_data:
                for prop_name, prop_data in json_data["properties"].items():
                    param_schema.properties[prop_name] = GeminiClient._create_gemini_function_declaration_schema(
                        prop_data
                    )
            else:
                # This `else` belongs to the `if`, not the `for`: as a for-else it fired
                # after every completed loop and never when "properties" was absent.
                print("Warning: Object schema missing 'properties' definition.")
        elif param_type in ("null", "any"):
            param_schema.type = Type.STRING  # Treating these as strings for simplicity
        else:
            print(f"Warning: Unsupported parameter type '{param_type}'.")

        if "description" in json_data:
            param_schema.description = json_data["description"]

        return param_schema

    @staticmethod
    def _create_gemini_function_parameters(function_parameter: dict[str, any]) -> dict[str, any]:
        """Convert function parameters to Gemini format, recursive"""
        function_parameter = GeminiClient._unwrap_references(function_parameter)

        if "type" in function_parameter:
            function_parameter["type"] = function_parameter["type"].upper()
            # If the schema was created from pydantic BaseModel, it will "title" attribute which needs to be removed
            function_parameter.pop("title", None)

        # Parameter properties and items
        if "properties" in function_parameter:
            for key in function_parameter["properties"]:
                function_parameter["properties"][key] = GeminiClient._create_gemini_function_parameters(
                    function_parameter["properties"][key]
                )

        if "items" in function_parameter:
            function_parameter["items"] = GeminiClient._create_gemini_function_parameters(function_parameter["items"])

        # Remove any attributes not needed
        for attr in ["default"]:
            if attr in function_parameter:
                del function_parameter[attr]

        return function_parameter

    @staticmethod
    def _to_vertexai_safety_settings(safety_settings):
        """Convert safety settings to VertexAI format if needed,
        like when specifying them in the OAI_CONFIG_LIST
        """
        if isinstance(safety_settings, list) and all([
            isinstance(safety_setting, dict) and not isinstance(safety_setting, VertexAISafetySetting)
            for safety_setting in safety_settings
        ]):
            vertexai_safety_settings = []
            for safety_setting in safety_settings:
                if safety_setting["category"] not in VertexAIHarmCategory.__members__:
                    invalid_category = safety_setting["category"]
                    logger.error(f"Safety setting category {invalid_category} is invalid")
                elif safety_setting["threshold"] not in VertexAIHarmBlockThreshold.__members__:
                    invalid_threshold = safety_setting["threshold"]
                    logger.error(f"Safety threshold {invalid_threshold} is invalid")
                else:
                    vertexai_safety_setting = VertexAISafetySetting(
                        category=safety_setting["category"],
                        threshold=safety_setting["threshold"],
                    )
                    vertexai_safety_settings.append(vertexai_safety_setting)
            return vertexai_safety_settings
        else:
            return safety_settings

    @staticmethod
    def _to_json_or_str(data: str) -> dict | str:
        try:
            json_data = json.loads(data)
            return json_data
        except (json.JSONDecodeError, ValidationError):
            return data


@require_optional_import(["PIL"], "gemini")
def get_image_data(image_file: str, use_b64=True) -> Union[str, bytes]:
    """Load image data from a URL, a base64 data URI, or a local file.

    Args:
        image_file: One of
            * an ``http://`` / ``https://`` URL, whose response body is used as-is;
            * a ``data:image/png;base64,`` or ``data:image/jpeg;base64,`` URI;
            * anything else, which is opened with ``Image.open``, converted to RGB
              and re-encoded as PNG.
        use_b64: When true (default) return the data as a base64 ``str``; otherwise
            return the raw ``bytes``.

    Returns:
        Base64 text or raw bytes, depending on ``use_b64``.

    Raises:
        requests.HTTPError: If the URL answers with an error status.
    """
    if image_file.startswith(("http://", "https://")):
        # NOTE(review): no timeout is set, so a stalled server blocks this call; consider adding one.
        response = requests.get(image_file)
        # Fail loudly instead of encoding an HTTP error page as if it were image data.
        response.raise_for_status()
        content = response.content
    else:
        data_uri_prefix = re.match(r"data:image/(?:png|jpeg);base64,", image_file)
        if data_uri_prefix:
            payload = image_file[data_uri_prefix.end() :]
            # The payload is already base64; honor use_b64=False by decoding it to bytes.
            return payload if use_b64 else base64.b64decode(payload)

        image = Image.open(image_file).convert("RGB")
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        content = buffered.getvalue()

    if use_b64:
        return base64.b64encode(content).decode("utf-8")
    return content


def _format_json_response(response: Any, original_answer: str) -> str:
    """Return ``response.format()`` for ``FormatterProtocol`` instances, else ``original_answer``.

    Used for structured outputs: a response that satisfies ``FormatterProtocol``
    supplies its own formatted text; any other response falls back to the answer
    that was passed in.
    """
    if isinstance(response, FormatterProtocol):
        return response.format()
    return original_answer


def calculate_gemini_cost(use_vertexai: bool, input_tokens: int, output_tokens: int, model_name: str) -> float:
    """Compute the cost of a request from token counts and the hard-coded price tables.

    The lower-cased ``model_name`` is matched by substring against an ordered price
    table (Vertex AI or non-Vertex, chosen by ``use_vertexai``); the first matching
    row wins, so more specific names are listed before their prefixes.

    Args:
        use_vertexai: Select the Vertex AI table instead of the non-Vertex one.
        input_tokens: Number of input tokens; also selects the long-context rates
            for tiered models (above 128000 or 200000 input tokens).
        output_tokens: Number of output tokens.
        model_name: Model identifier, matched case-insensitively.

    Returns:
        ``input_rate * input_tokens / unit + output_rate * output_tokens / unit``
        for the matching row. For an unknown model a ``UserWarning`` is issued and
        ``0`` is returned.
    """
    model_name = model_name.lower()
    within_128k = input_tokens <= 128000
    within_200k = input_tokens <= 200000

    per_million = 1e6  # rates quoted per million tokens
    per_thousand = 1e3  # rates quoted per thousand tokens
    untiered = True  # single-rate models always use their base rates

    gemini_2_5_pro = (
        "gemini-2.5-pro-preview-03-25",
        "gemini-2.5-pro-exp-03-25",
        "gemini-2.5-pro-preview-05-06",
    )

    # Row layout: (name fragments, price unit, use base rates?, (input, output) base
    # rates, (input, output) long-context rates or None).
    if use_vertexai:
        # Vertex AI pricing - based on Text input
        # https://cloud.google.com/vertex-ai/generative-ai/pricing#vertex-ai-pricing
        price_table = (
            (gemini_2_5_pro, per_million, within_200k, (1.25, 10), (2.5, 15)),
            # NON-THINKING OUTPUT PRICE, $3 FOR THINKING!
            (("gemini-2.5-flash-preview-04-17",), per_million, untiered, (0.15, 0.6), None),
            (("gemini-2.0-flash-lite",), per_million, untiered, (0.075, 0.3), None),
            (("gemini-2.0-flash",), per_million, untiered, (0.15, 0.6), None),
            (("gemini-1.5-flash",), per_thousand, within_128k, (0.00001875, 0.000075), (0.0000375, 0.00015)),
            (("gemini-1.5-pro",), per_thousand, within_128k, (0.0003125, 0.00125), (0.000625, 0.0025)),
            # NOTE(review): output rate is lower than the input rate here - confirm against the pricing page.
            (("gemini-1.0-pro",), per_thousand, untiered, (0.000125, 0.00001875), None),
        )
    else:
        # Non-Vertex AI pricing
        price_table = (
            # https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-pro-preview
            (gemini_2_5_pro, per_million, within_200k, (1.25, 10), (2.5, 15)),
            # https://ai.google.dev/gemini-api/docs/pricing#gemini-2.5-flash
            (("gemini-2.5-flash-preview-04-17",), per_million, untiered, (0.15, 0.6), None),
            # https://ai.google.dev/gemini-api/docs/pricing#gemini-2.0-flash-lite
            (("gemini-2.0-flash-lite",), per_million, untiered, (0.075, 0.3), None),
            # https://ai.google.dev/gemini-api/docs/pricing#gemini-2.0-flash
            (("gemini-2.0-flash",), per_million, untiered, (0.1, 0.4), None),
            # https://ai.google.dev/pricing#1_5flash-8B
            (("gemini-1.5-flash-8b",), per_million, within_128k, (0.0375, 0.15), (0.075, 0.3)),
            # https://ai.google.dev/pricing#1_5flash
            (("gemini-1.5-flash",), per_million, within_128k, (0.075, 0.3), (0.15, 0.6)),
            # https://ai.google.dev/pricing#1_5pro
            (("gemini-1.5-pro",), per_million, within_128k, (1.25, 5.0), (2.50, 10.0)),
            (("gemini-1.0-pro",), per_million, untiered, (0.50, 1.5), None),
        )

    for fragments, unit, use_base_rates, base_rates, long_context_rates in price_table:
        if any(fragment in model_name for fragment in fragments):
            input_rate, output_rate = base_rates if use_base_rates else long_context_rates
            return input_rate * input_tokens / unit + output_rate * output_tokens / unit

    warnings.warn(
        f"Cost calculation is not implemented for model {model_name}. Cost will be calculated zero.",
        UserWarning,
    )
    return 0