mirror of
https://github.com/home-assistant/supervisor.git
synced 2026-02-14 23:19:37 +00:00
Fix MCP API proxy support for streaming and headers (#6461)
* Fix MCP API proxy support for streaming and headers This commit fixes two issues with using the core API core/api/mcp through the API proxy: 1. **Streaming support**: The proxy now detects text/event-stream responses and properly streams them instead of buffering all data. This is required for MCP's Server-Sent Events (SSE) transport. 2. **Header forwarding**: Added MCP-required headers to the forwarded headers: - Accept: Required for content negotiation - Last-Event-ID: Required for resuming broken SSE connections - Mcp-Session-Id: Required for session management across requests The proxy now also preserves MCP-related response headers (Mcp-Session-Id) and sets X-Accel-Buffering to "no" for streaming responses to prevent buffering by intermediate proxies. Tests added to verify: - MCP headers are properly forwarded to Home Assistant - Streaming responses (text/event-stream) are handled correctly - Response headers are preserved * Refactor: reuse stream logic for SSE responses (#3) * Fix ruff format + cover streaming payload error * Fix merge error * Address review comments (headers / streaming proxy) (#4) * Address review: header handling for streaming/non-streaming * Forward MCP-Protocol-Version and Origin headers * Do not forward Origin header through API proxy (#5) --------- Co-authored-by: Stefan Agner <stefan@agner.ch>
This commit is contained in:
@@ -21,7 +21,13 @@ from ..utils.logging import AddonLoggerAdapter
|
||||
_LOGGER: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
FORWARD_HEADERS = ("X-Speech-Content",)
|
||||
FORWARD_HEADERS = (
|
||||
"X-Speech-Content",
|
||||
"Accept",
|
||||
"Last-Event-ID",
|
||||
"Mcp-Session-Id",
|
||||
"MCP-Protocol-Version",
|
||||
)
|
||||
HEADER_HA_ACCESS = "X-Ha-Access"
|
||||
|
||||
# Maximum message size for websocket messages from Home Assistant.
|
||||
@@ -35,6 +41,38 @@ MAX_MESSAGE_SIZE_FROM_CORE = 64 * 1024 * 1024
|
||||
class APIProxy(CoreSysAttributes):
|
||||
"""API Proxy for Home Assistant."""
|
||||
|
||||
async def _stream_client_response(
|
||||
self,
|
||||
request: web.Request,
|
||||
client: aiohttp.ClientResponse,
|
||||
*,
|
||||
content_type: str,
|
||||
headers_to_copy: tuple[str, ...] = (),
|
||||
) -> web.StreamResponse:
|
||||
"""Stream an upstream aiohttp response to the caller.
|
||||
|
||||
Used for event streams (e.g. Home Assistant /api/stream) and for SSE endpoints
|
||||
such as MCP (text/event-stream).
|
||||
"""
|
||||
response = web.StreamResponse(status=client.status)
|
||||
response.content_type = content_type
|
||||
|
||||
for header in headers_to_copy:
|
||||
if header in client.headers:
|
||||
response.headers[header] = client.headers[header]
|
||||
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
|
||||
try:
|
||||
await response.prepare(request)
|
||||
async for data in client.content:
|
||||
await response.write(data)
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError):
|
||||
# Client disconnected or upstream closed
|
||||
pass
|
||||
|
||||
return response
|
||||
|
||||
def _check_access(self, request: web.Request):
|
||||
"""Check the Supervisor token."""
|
||||
if AUTHORIZATION in request.headers:
|
||||
@@ -95,16 +133,11 @@ class APIProxy(CoreSysAttributes):
|
||||
|
||||
_LOGGER.info("Home Assistant EventStream start")
|
||||
async with self._api_client(request, "stream", timeout=None) as client:
|
||||
response = web.StreamResponse()
|
||||
response.content_type = request.headers.get(CONTENT_TYPE, "")
|
||||
try:
|
||||
response.headers["X-Accel-Buffering"] = "no"
|
||||
await response.prepare(request)
|
||||
async for data in client.content:
|
||||
await response.write(data)
|
||||
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError):
|
||||
pass
|
||||
response = await self._stream_client_response(
|
||||
request,
|
||||
client,
|
||||
content_type=request.headers.get(CONTENT_TYPE, ""),
|
||||
)
|
||||
|
||||
_LOGGER.info("Home Assistant EventStream close")
|
||||
return response
|
||||
@@ -118,10 +151,31 @@ class APIProxy(CoreSysAttributes):
|
||||
# Normal request
|
||||
path = request.match_info.get("path", "")
|
||||
async with self._api_client(request, path) as client:
|
||||
# Check if this is a streaming response (e.g., MCP SSE endpoints)
|
||||
if client.content_type == "text/event-stream":
|
||||
return await self._stream_client_response(
|
||||
request,
|
||||
client,
|
||||
content_type=client.content_type,
|
||||
headers_to_copy=(
|
||||
"Cache-Control",
|
||||
"Mcp-Session-Id",
|
||||
),
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
data = await client.read()
|
||||
return web.Response(
|
||||
response = web.Response(
|
||||
body=data, status=client.status, content_type=client.content_type
|
||||
)
|
||||
# Copy selected headers from the upstream response
|
||||
for header in (
|
||||
"Cache-Control",
|
||||
"Mcp-Session-Id",
|
||||
):
|
||||
if header in client.headers:
|
||||
response.headers[header] = client.headers[header]
|
||||
return response
|
||||
|
||||
async def _websocket_client(self) -> ClientWebSocketResponse:
|
||||
"""Initialize a WebSocket API connection."""
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
from typing import Any, cast
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from aiohttp import ClientWebSocketResponse, WSCloseCode
|
||||
from aiohttp import ClientPayloadError, ClientWebSocketResponse, WSCloseCode
|
||||
from aiohttp.http_websocket import WSMessage, WSMsgType
|
||||
from aiohttp.test_utils import TestClient
|
||||
import pytest
|
||||
@@ -326,3 +326,129 @@ async def test_api_proxy_delete_request(
|
||||
assert response.status == 200
|
||||
assert await response.text() == '{"result": "ok"}'
|
||||
assert response.content_type == "application/json"
|
||||
|
||||
|
||||
async def test_api_proxy_mcp_headers_forwarded(
|
||||
api_client: TestClient,
|
||||
install_addon_example: Addon,
|
||||
):
|
||||
"""Test that MCP headers are forwarded to Home Assistant."""
|
||||
install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123"
|
||||
install_addon_example.data["homeassistant_api"] = True
|
||||
|
||||
with patch.object(HomeAssistantAPI, "make_request") as make_request:
|
||||
# Mock the response from make_request
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.content_type = "application/json"
|
||||
mock_response.read.return_value = b"mocked response"
|
||||
mock_response.headers = {"Mcp-Session-Id": "test-session-123"}
|
||||
make_request.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
response = await api_client.get(
|
||||
"/core/api/mcp",
|
||||
headers={
|
||||
"Authorization": "Bearer abc123",
|
||||
"Accept": "text/event-stream",
|
||||
"Last-Event-ID": "5",
|
||||
"Mcp-Session-Id": "test-session-123",
|
||||
},
|
||||
)
|
||||
|
||||
# Verify headers were forwarded in the request
|
||||
assert make_request.call_args[1]["headers"]["Accept"] == "text/event-stream"
|
||||
assert make_request.call_args[1]["headers"]["Last-Event-ID"] == "5"
|
||||
assert (
|
||||
make_request.call_args[1]["headers"]["Mcp-Session-Id"] == "test-session-123"
|
||||
)
|
||||
|
||||
# Verify response headers are preserved
|
||||
assert response.status == 200
|
||||
assert response.headers.get("Mcp-Session-Id") == "test-session-123"
|
||||
|
||||
|
||||
async def test_api_proxy_streaming_response(
|
||||
api_client: TestClient,
|
||||
install_addon_example: Addon,
|
||||
):
|
||||
"""Test that streaming responses (text/event-stream) are handled properly."""
|
||||
install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123"
|
||||
install_addon_example.data["homeassistant_api"] = True
|
||||
|
||||
async def mock_content_iter():
|
||||
"""Mock async iterator for streaming content."""
|
||||
yield b"data: event1\n\n"
|
||||
yield b"data: event2\n\n"
|
||||
yield b"data: event3\n\n"
|
||||
|
||||
with patch.object(HomeAssistantAPI, "make_request") as make_request:
|
||||
# Mock the response from make_request
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.content_type = "text/event-stream"
|
||||
mock_response.headers = {
|
||||
"Cache-Control": "no-cache",
|
||||
"Mcp-Session-Id": "session-456",
|
||||
}
|
||||
mock_response.content = mock_content_iter()
|
||||
make_request.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
response = await api_client.get(
|
||||
"/core/api/mcp",
|
||||
headers={
|
||||
"Authorization": "Bearer abc123",
|
||||
"Accept": "text/event-stream",
|
||||
},
|
||||
)
|
||||
|
||||
# Verify it's a streaming response
|
||||
assert response.status == 200
|
||||
assert response.content_type == "text/event-stream"
|
||||
assert response.headers.get("X-Accel-Buffering") == "no"
|
||||
assert response.headers.get("Mcp-Session-Id") == "session-456"
|
||||
|
||||
# Read the streamed content
|
||||
content = await response.read()
|
||||
assert b"data: event1\n\n" in content
|
||||
assert b"data: event2\n\n" in content
|
||||
assert b"data: event3\n\n" in content
|
||||
|
||||
|
||||
async def test_api_proxy_streaming_response_client_payload_error(
|
||||
api_client: TestClient,
|
||||
install_addon_example: Addon,
|
||||
):
|
||||
"""Test that client payload errors during streaming are handled gracefully."""
|
||||
install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123"
|
||||
install_addon_example.data["homeassistant_api"] = True
|
||||
|
||||
async def mock_content_iter_error():
|
||||
yield b"data: event1\n\n"
|
||||
raise ClientPayloadError("boom")
|
||||
|
||||
with patch.object(HomeAssistantAPI, "make_request") as make_request:
|
||||
mock_response = AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.content_type = "text/event-stream"
|
||||
mock_response.headers = {
|
||||
"Cache-Control": "no-cache",
|
||||
"Mcp-Session-Id": "session-789",
|
||||
}
|
||||
mock_response.content = mock_content_iter_error()
|
||||
make_request.return_value.__aenter__.return_value = mock_response
|
||||
|
||||
response = await api_client.get(
|
||||
"/core/api/mcp",
|
||||
headers={
|
||||
"Authorization": "Bearer abc123",
|
||||
"Accept": "text/event-stream",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
assert response.content_type == "text/event-stream"
|
||||
assert response.headers.get("X-Accel-Buffering") == "no"
|
||||
assert response.headers.get("Mcp-Session-Id") == "session-789"
|
||||
|
||||
content = await response.read()
|
||||
assert b"data: event1\n\n" in content
|
||||
|
||||
Reference in New Issue
Block a user