mirror of
https://github.com/home-assistant/core.git
synced 2026-05-17 22:10:57 +01:00
171 lines
6.0 KiB
Python
171 lines
6.0 KiB
Python
"""The Model Context Protocol Server implementation.
|
|
|
|
The Model Context Protocol Python SDK defines a Server API that provides the
MCP message handling logic and error handling. The server implementation provided
here is independent of the lower level transport protocol.
|
|
|
|
See https://modelcontextprotocol.io/docs/concepts/architecture#implementation-example
|
|
"""
|
|
|
|
from collections.abc import Callable, Sequence
|
|
import json
|
|
import logging
|
|
from typing import Any, cast
|
|
|
|
from mcp import types
|
|
from mcp.server import Server
|
|
from mcp.server.lowlevel.helper_types import ReadResourceContents
|
|
from pydantic import AnyUrl
|
|
import voluptuous as vol
|
|
from voluptuous_openapi import convert
|
|
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.exceptions import HomeAssistantError
|
|
from homeassistant.helpers import llm
|
|
|
|
from .const import STATELESS_LLM_API
|
|
|
|
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# URI under which the Assist context snapshot resource is published.
SNAPSHOT_RESOURCE_URI = "homeassistant://assist/context-snapshot"
# The same URI wrapped in an AnyUrl, used when the resource is listed.
SNAPSHOT_RESOURCE_URL = AnyUrl(SNAPSHOT_RESOURCE_URI)
# MIME type reported for the snapshot resource and its contents.
SNAPSHOT_RESOURCE_MIME_TYPE = "text/plain"
# Name of the LLM API tool whose output backs the snapshot resource.
LIVE_CONTEXT_TOOL_NAME = "GetLiveContext"
|
|
|
|
|
|
def _has_live_context_tool(llm_api: llm.APIInstance) -> bool:
    """Tell whether the given API instance offers the live context tool.

    Scans ``llm_api.tools`` and reports True as soon as a tool named
    ``LIVE_CONTEXT_TOOL_NAME`` is found, False when none matches.
    """
    for candidate in llm_api.tools:
        if candidate.name == LIVE_CONTEXT_TOOL_NAME:
            return True
    return False
|
|
|
|
|
|
def _format_tool(
    tool: llm.Tool, custom_serializer: Callable[[Any], Any] | None
) -> types.Tool:
    """Translate an LLM tool into its MCP tool specification.

    The tool's parameter schema is run through ``convert`` (with the given
    ``custom_serializer``); only the resulting ``properties`` are carried into
    the MCP input schema, which is always declared as an object. A missing
    description is replaced by an empty string.
    """
    converted = convert(tool.parameters, custom_serializer=custom_serializer)
    schema: dict[str, Any] = {
        "type": "object",
        "properties": converted["properties"],
    }
    description = tool.description or ""
    return types.Tool(name=tool.name, description=description, inputSchema=schema)
|
|
|
|
|
|
async def create_server(
    hass: HomeAssistant, llm_api_id: str | list[str], llm_context: llm.LLMContext
) -> Server:
    """Build a Model Context Protocol server for the selected LLM API.

    A new MCP SDK ``Server`` named "home-assistant" is created and handlers for
    prompts, resources and tools are registered on it. Each handler resolves
    the LLM API instance again, on every request, via ``llm.async_get_api``.

    A Model Context Protocol Server object is associated with a single session.
    The MCP SDK handles the details of the protocol.
    """
    # The stateless API id is served by the regular Assist API.
    # NOTE(review): this looks like the backwards-compatibility path for old
    # MCP Server config entries -- confirm.
    selected_api_id = (
        llm.LLM_API_ASSIST if llm_api_id == STATELESS_LLM_API else llm_api_id
    )

    mcp_server = Server[Any]("home-assistant")

    async def get_api_instance() -> llm.APIInstance:
        """Resolve the LLM API instance for the selected API id."""
        return await llm.async_get_api(hass, selected_api_id, llm_context)

    @mcp_server.list_prompts()  # type: ignore[no-untyped-call,untyped-decorator]
    async def handle_list_prompts() -> list[types.Prompt]:
        """Advertise the single default prompt, named after the API."""
        api_instance = await get_api_instance()
        default_prompt = types.Prompt(
            name=api_instance.api.name,
            description=(
                f"Default prompt for Home Assistant {api_instance.api.name} API"
            ),
        )
        return [default_prompt]

    @mcp_server.get_prompt()  # type: ignore[no-untyped-call,untyped-decorator]
    async def handle_get_prompt(
        name: str, arguments: dict[str, str] | None
    ) -> types.GetPromptResult:
        """Return the API prompt; any name other than the API name is rejected."""
        api_instance = await get_api_instance()
        if name != api_instance.api.name:
            raise ValueError(f"Unknown prompt: {name}")

        prompt_text = types.TextContent(type="text", text=api_instance.api_prompt)
        prompt_message = types.PromptMessage(role="assistant", content=prompt_text)
        return types.GetPromptResult(
            description=(
                f"Default prompt for Home Assistant {api_instance.api.name} API"
            ),
            messages=[prompt_message],
        )

    @mcp_server.list_resources()  # type: ignore[no-untyped-call,untyped-decorator]
    async def handle_list_resources() -> list[types.Resource]:
        """List the context snapshot resource when the live context tool exists."""
        api_instance = await get_api_instance()
        if _has_live_context_tool(api_instance):
            return [
                types.Resource(
                    uri=SNAPSHOT_RESOURCE_URL,
                    name="assist_context_snapshot",
                    title="Assist context snapshot",
                    description=(
                        "A snapshot of the current Assist context,"
                        " matching the existing GetLiveContext tool output."
                    ),
                    mimeType=SNAPSHOT_RESOURCE_MIME_TYPE,
                )
            ]
        return []

    @mcp_server.read_resource()  # type: ignore[no-untyped-call,untyped-decorator]
    async def handle_read_resource(uri: AnyUrl) -> Sequence[ReadResourceContents]:
        """Serve the context snapshot by invoking the live context tool."""
        # Reject foreign URIs before the API instance is even resolved.
        if str(uri) != SNAPSHOT_RESOURCE_URI:
            raise ValueError(f"Unknown resource: {uri}")

        api_instance = await get_api_instance()
        # Without the backing tool the resource is treated as nonexistent.
        if not _has_live_context_tool(api_instance):
            raise ValueError(f"Unknown resource: {uri}")

        snapshot_request = llm.ToolInput(
            tool_name=LIVE_CONTEXT_TOOL_NAME, tool_args={}
        )
        snapshot = await api_instance.async_call_tool(snapshot_request)
        if snapshot.get("success"):
            return [
                ReadResourceContents(
                    content=cast(str, snapshot["result"]),
                    mime_type=SNAPSHOT_RESOURCE_MIME_TYPE,
                )
            ]
        raise HomeAssistantError(cast(str, snapshot["error"]))

    @mcp_server.list_tools()  # type: ignore[no-untyped-call,untyped-decorator]
    async def list_tools() -> list[types.Tool]:
        """List the tools of the selected LLM API as MCP tool specifications."""
        api_instance = await get_api_instance()
        formatted: list[types.Tool] = []
        for tool in api_instance.tools:
            formatted.append(_format_tool(tool, api_instance.custom_serializer))
        return formatted

    @mcp_server.call_tool()  # type: ignore[untyped-decorator]
    async def call_tool(name: str, arguments: dict) -> Sequence[types.TextContent]:
        """Run the named tool and return its response as JSON text."""
        api_instance = await get_api_instance()
        tool_input = llm.ToolInput(tool_name=name, tool_args=arguments)
        _LOGGER.debug("Tool call: %s(%s)", tool_input.tool_name, tool_input.tool_args)

        try:
            tool_response = await api_instance.async_call_tool(tool_input)
        except (HomeAssistantError, vol.Invalid) as err:
            raise HomeAssistantError(f"Error calling tool: {err}") from err

        # ensure_ascii=False keeps non-ASCII characters unescaped in the output.
        payload = json.dumps(tool_response, ensure_ascii=False)
        return [types.TextContent(type="text", text=payload)]

    return mcp_server
|