# Copyright (c) 2023 - 2025, AG2ai, Inc., AG2ai open-source projects maintainers and core contributors # # SPDX-License-Identifier: Apache-2.0 # # Portions derived from https://github.com/microsoft/autogen are under the MIT License. # SPDX-License-Identifier: MIT from __future__ import annotations import datetime import json import sys import uuid from dataclasses import dataclass from types import TracebackType from typing import Any, Optional, cast from ...doc_utils import export_module if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self import requests from requests.adapters import HTTPAdapter, Retry from ...import_utils import optional_import_block, require_optional_import from .base import JupyterConnectionInfo with optional_import_block(): import websocket from websocket import WebSocket @export_module("autogen.coding.jupyter") class JupyterClient: def __init__(self, connection_info: JupyterConnectionInfo): """(Experimental) A client for communicating with a Jupyter gateway server. Args: connection_info (JupyterConnectionInfo): Connection information """ self._connection_info = connection_info self._session = requests.Session() retries = Retry(total=5, backoff_factor=0.1) self._session.mount("http://", HTTPAdapter(max_retries=retries)) def _get_headers(self) -> dict[str, str]: if self._connection_info.token is None: return {} return {"Authorization": f"token {self._connection_info.token}"} def _get_api_base_url(self) -> str: protocol = "https" if self._connection_info.use_https else "http" port = f":{self._connection_info.port}" if self._connection_info.port else "" return f"{protocol}://{self._connection_info.host}{port}" def _get_ws_base_url(self) -> str: port = f":{self._connection_info.port}" if self._connection_info.port else "" return f"ws://{self._connection_info.host}{port}" def list_kernel_specs(self) -> dict[str, dict[str, str]]: response = self._session.get(f"{self._get_api_base_url()}/api/kernelspecs", headers=self._get_headers()) return cast(dict[str, dict[str, str]], response.json()) def list_kernels(self) -> list[dict[str, str]]: response = self._session.get(f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers()) return cast(list[dict[str, str]], response.json()) def start_kernel(self, kernel_spec_name: str) -> str: """Start a new kernel. Args: kernel_spec_name (str): Name of the kernel spec to start Returns: str: ID of the started kernel """ response = self._session.post( f"{self._get_api_base_url()}/api/kernels", headers=self._get_headers(), json={"name": kernel_spec_name}, ) return cast(str, response.json()["id"]) def delete_kernel(self, kernel_id: str) -> None: response = self._session.delete( f"{self._get_api_base_url()}/api/kernels/{kernel_id}", headers=self._get_headers() ) response.raise_for_status() def restart_kernel(self, kernel_id: str) -> None: response = self._session.post( f"{self._get_api_base_url()}/api/kernels/{kernel_id}/restart", headers=self._get_headers() ) response.raise_for_status() @require_optional_import("websocket", "jupyter-executor") def get_kernel_client(self, kernel_id: str) -> JupyterKernelClient: ws_url = f"{self._get_ws_base_url()}/api/kernels/{kernel_id}/channels" ws = websocket.create_connection(ws_url, header=self._get_headers()) return JupyterKernelClient(ws) @require_optional_import("websocket", "jupyter-executor") class JupyterKernelClient: """(Experimental) A client for communicating with a Jupyter kernel.""" @dataclass class ExecutionResult: @dataclass class DataItem: mime_type: str data: str is_ok: bool output: str data_items: list[DataItem] def __init__(self, websocket: WebSocket): # type: ignore[no-any-unimported] self._session_id: str = uuid.uuid4().hex self._websocket: WebSocket = websocket # type: ignore[no-any-unimported] def __enter__(self) -> Self: return self def __exit__( self, exc_type: Optional[type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] ) -> None: self.stop() def stop(self) -> None: self._websocket.close() def _send_message(self, *, content: dict[str, Any], channel: str, message_type: str) -> str: timestamp = datetime.datetime.now().isoformat() message_id = uuid.uuid4().hex message = { "header": { "username": "autogen", "version": "5.0", "session": self._session_id, "msg_id": message_id, "msg_type": message_type, "date": timestamp, }, "parent_header": {}, "channel": channel, "content": content, "metadata": {}, "buffers": {}, } self._websocket.send_text(json.dumps(message)) return message_id def _receive_message(self, timeout_seconds: Optional[float]) -> Optional[dict[str, Any]]: self._websocket.settimeout(timeout_seconds) try: data = self._websocket.recv() if isinstance(data, bytes): data = data.decode("utf-8") return cast(dict[str, Any], json.loads(data)) except websocket.WebSocketTimeoutException: return None def wait_for_ready(self, timeout_seconds: Optional[float] = None) -> bool: message_id = self._send_message(content={}, channel="shell", message_type="kernel_info_request") while True: message = self._receive_message(timeout_seconds) # This means we timed out with no new messages. if message is None: return False if ( message.get("parent_header", {}).get("msg_id") == message_id and message["msg_type"] == "kernel_info_reply" ): return True def execute(self, code: str, timeout_seconds: Optional[float] = None) -> ExecutionResult: message_id = self._send_message( content={ "code": code, "silent": False, "store_history": True, "user_expressions": {}, "allow_stdin": False, "stop_on_error": True, }, channel="shell", message_type="execute_request", ) text_output = [] data_output = [] while True: message = self._receive_message(timeout_seconds) if message is None: return JupyterKernelClient.ExecutionResult( is_ok=False, output="ERROR: Timeout waiting for output from code block.", data_items=[] ) # Ignore messages that are not for this execution. if message.get("parent_header", {}).get("msg_id") != message_id: continue msg_type = message["msg_type"] content = message["content"] if msg_type in ["execute_result", "display_data"]: for data_type, data in content["data"].items(): if data_type == "text/plain": text_output.append(data) elif data_type.startswith("image/") or data_type == "text/html": data_output.append(self.ExecutionResult.DataItem(mime_type=data_type, data=data)) else: text_output.append(json.dumps(data)) elif msg_type == "stream": text_output.append(content["text"]) elif msg_type == "error": # Output is an error. return JupyterKernelClient.ExecutionResult( is_ok=False, output=f"ERROR: {content['ename']}: {content['evalue']}\n{content['traceback']}", data_items=[], ) if msg_type == "status" and content["execution_state"] == "idle": break return JupyterKernelClient.ExecutionResult( is_ok=True, output="\n".join([str(output) for output in text_output]), data_items=data_output )