"""The Model Context Protocol Server implementation. The Model Context Protocol python sdk defines a Server API that provides the MCP message handling logic and error handling. The server implementation provided here is independent of the lower level transport protocol. See https://modelcontextprotocol.io/docs/concepts/architecture#implementation-example """ from collections.abc import Callable, Sequence import json import logging from typing import Any, cast from mcp import types from mcp.server import Server from mcp.server.lowlevel.helper_types import ReadResourceContents from pydantic import AnyUrl import voluptuous as vol from voluptuous_openapi import convert from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import llm from .const import STATELESS_LLM_API _LOGGER = logging.getLogger(__name__) SNAPSHOT_RESOURCE_URI = "homeassistant://assist/context-snapshot" SNAPSHOT_RESOURCE_URL = AnyUrl(SNAPSHOT_RESOURCE_URI) SNAPSHOT_RESOURCE_MIME_TYPE = "text/plain" LIVE_CONTEXT_TOOL_NAME = "GetLiveContext" def _has_live_context_tool(llm_api: llm.APIInstance) -> bool: """Return if the selected API exposes the live context tool.""" return any(tool.name == LIVE_CONTEXT_TOOL_NAME for tool in llm_api.tools) def _format_tool( tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None ) -> types.Tool: """Format tool specification.""" input_schema = convert(tool.parameters, custom_serializer=custom_serializer) return types.Tool( name=tool.name, description=tool.description or "", inputSchema={ "type": "object", "properties": input_schema["properties"], }, ) async def create_server( hass: HomeAssistant, llm_api_id: str | list[str], llm_context: llm.LLMContext ) -> Server: """Create a new Model Context Protocol Server. A Model Context Protocol Server object is associated with a single session. The MCP SDK handles the details of the protocol. """ if llm_api_id == STATELESS_LLM_API: llm_api_id = llm.LLM_API_ASSIST server = Server[Any]("home-assistant") async def get_api_instance() -> llm.APIInstance: """Get the LLM API selected.""" # Backwards compatibility with old MCP Server config return await llm.async_get_api(hass, llm_api_id, llm_context) @server.list_prompts() # type: ignore[no-untyped-call,untyped-decorator] async def handle_list_prompts() -> list[types.Prompt]: llm_api = await get_api_instance() return [ types.Prompt( name=llm_api.api.name, description=f"Default prompt for Home Assistant {llm_api.api.name} API", ) ] @server.get_prompt() # type: ignore[no-untyped-call,untyped-decorator] async def handle_get_prompt( name: str, arguments: dict[str, str] | None ) -> types.GetPromptResult: llm_api = await get_api_instance() if name != llm_api.api.name: raise ValueError(f"Unknown prompt: {name}") return types.GetPromptResult( description=f"Default prompt for Home Assistant {llm_api.api.name} API", messages=[ types.PromptMessage( role="assistant", content=types.TextContent( type="text", text=llm_api.api_prompt, ), ) ], ) @server.list_resources() # type: ignore[no-untyped-call,untyped-decorator] async def handle_list_resources() -> list[types.Resource]: llm_api = await get_api_instance() if not _has_live_context_tool(llm_api): return [] return [ types.Resource( uri=SNAPSHOT_RESOURCE_URL, name="assist_context_snapshot", title="Assist context snapshot", description=( "A snapshot of the current Assist context, matching the" " existing GetLiveContext tool output." ), mimeType=SNAPSHOT_RESOURCE_MIME_TYPE, ) ] @server.read_resource() # type: ignore[no-untyped-call,untyped-decorator] async def handle_read_resource(uri: AnyUrl) -> Sequence[ReadResourceContents]: if str(uri) != SNAPSHOT_RESOURCE_URI: raise ValueError(f"Unknown resource: {uri}") llm_api = await get_api_instance() if not _has_live_context_tool(llm_api): raise ValueError(f"Unknown resource: {uri}") tool_response = await llm_api.async_call_tool( llm.ToolInput(tool_name=LIVE_CONTEXT_TOOL_NAME, tool_args={}) ) if not tool_response.get("success"): raise HomeAssistantError(cast(str, tool_response["error"])) return [ ReadResourceContents( content=cast(str, tool_response["result"]), mime_type=SNAPSHOT_RESOURCE_MIME_TYPE, ) ] @server.list_tools() # type: ignore[no-untyped-call,untyped-decorator] async def list_tools() -> list[types.Tool]: """List available time tools.""" llm_api = await get_api_instance() return [_format_tool(tool, llm_api.custom_serializer) for tool in llm_api.tools] @server.call_tool() # type: ignore[untyped-decorator] async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]: """Handle calling tools.""" llm_api = await get_api_instance() tool_input = llm.ToolInput(tool_name=name, tool_args=arguments) _LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args) try: tool_response = await llm_api.async_call_tool(tool_input) except (HomeAssistantError, vol.Invalid) as e: raise HomeAssistantError(f"Error calling tool: {e}") from e return [ types.TextContent( type="text", text=json.dumps(tool_response, ensure_ascii=False), ) ] return server