diff --git a/src/mcp_api/config.py b/src/mcp_api/config.py index ed5de37..ab7ec43 100644 --- a/src/mcp_api/config.py +++ b/src/mcp_api/config.py @@ -5,14 +5,19 @@ from __future__ import annotations import json from dataclasses import dataclass, field from pathlib import Path +from typing import Literal @dataclass class ServerConfig: name: str - command: str + transport: Literal["stdio", "streamable_http"] = "stdio" + # stdio + command: str | None = None args: list[str] = field(default_factory=list) env: dict[str, str] | None = None + # streamable_http + url: str | None = None def load_claude_json(path: Path | None = None) -> dict[str, ServerConfig]: @@ -27,10 +32,19 @@ def from_dict(servers: dict) -> dict[str, ServerConfig]: """Parse a dict in the claude.json mcpServers format.""" result: dict[str, ServerConfig] = {} for name, cfg in servers.items(): - result[name] = ServerConfig( - name=name, - command=cfg["command"], - args=cfg.get("args", []), - env=cfg.get("env"), - ) + transport = cfg.get("transport", "stdio") + if transport == "streamable_http": + result[name] = ServerConfig( + name=name, + transport="streamable_http", + url=cfg["url"], + ) + else: + result[name] = ServerConfig( + name=name, + transport="stdio", + command=cfg["command"], + args=cfg.get("args", []), + env=cfg.get("env"), + ) return result diff --git a/src/mcp_api/session.py b/src/mcp_api/session.py index 12364bf..a79b2ec 100644 --- a/src/mcp_api/session.py +++ b/src/mcp_api/session.py @@ -7,6 +7,7 @@ from contextlib import AsyncExitStack from mcp import ClientSession, StdioServerParameters from mcp import types from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamable_http_client from .config import ServerConfig @@ -20,12 +21,18 @@ class ServerSession: self._session: ClientSession | None = None async def connect(self) -> None: - params = StdioServerParameters( - command=self._config.command, - args=self._config.args, - env=self._config.env, - ) - read, write = await self._stack.enter_async_context(stdio_client(params)) + if self._config.transport == "streamable_http": + read, write, _ = await self._stack.enter_async_context( + streamable_http_client(self._config.url) + ) + else: + params = StdioServerParameters( + command=self._config.command, + args=self._config.args, + env=self._config.env, + ) + read, write = await self._stack.enter_async_context(stdio_client(params)) + self._session = await self._stack.enter_async_context( ClientSession(read, write) )