diff --git a/supervisor/api/proxy.py b/supervisor/api/proxy.py index 03b8614f6..4bb5fd273 100644 --- a/supervisor/api/proxy.py +++ b/supervisor/api/proxy.py @@ -21,7 +21,13 @@ from ..utils.logging import AddonLoggerAdapter _LOGGER: logging.Logger = logging.getLogger(__name__) -FORWARD_HEADERS = ("X-Speech-Content",) +FORWARD_HEADERS = ( + "X-Speech-Content", + "Accept", + "Last-Event-ID", + "Mcp-Session-Id", + "MCP-Protocol-Version", +) HEADER_HA_ACCESS = "X-Ha-Access" # Maximum message size for websocket messages from Home Assistant. @@ -35,6 +41,38 @@ MAX_MESSAGE_SIZE_FROM_CORE = 64 * 1024 * 1024 class APIProxy(CoreSysAttributes): """API Proxy for Home Assistant.""" + async def _stream_client_response( + self, + request: web.Request, + client: aiohttp.ClientResponse, + *, + content_type: str, + headers_to_copy: tuple[str, ...] = (), + ) -> web.StreamResponse: + """Stream an upstream aiohttp response to the caller. + + Used for event streams (e.g. Home Assistant /api/stream) and for SSE endpoints + such as MCP (text/event-stream). + """ + response = web.StreamResponse(status=client.status) + response.content_type = content_type + + for header in headers_to_copy: + if header in client.headers: + response.headers[header] = client.headers[header] + + response.headers["X-Accel-Buffering"] = "no" + + try: + await response.prepare(request) + async for data in client.content: + await response.write(data) + except (aiohttp.ClientError, aiohttp.ClientPayloadError): + # Client disconnected or upstream closed + pass + + return response + def _check_access(self, request: web.Request): """Check the Supervisor token.""" if AUTHORIZATION in request.headers: @@ -95,16 +133,11 @@ class APIProxy(CoreSysAttributes): _LOGGER.info("Home Assistant EventStream start") async with self._api_client(request, "stream", timeout=None) as client: - response = web.StreamResponse() - response.content_type = request.headers.get(CONTENT_TYPE, "") - try: - response.headers["X-Accel-Buffering"] = "no" - await response.prepare(request) - async for data in client.content: - await response.write(data) - - except (aiohttp.ClientError, aiohttp.ClientPayloadError): - pass + response = await self._stream_client_response( + request, + client, + content_type=request.headers.get(CONTENT_TYPE, ""), + ) _LOGGER.info("Home Assistant EventStream close") return response @@ -118,10 +151,31 @@ class APIProxy(CoreSysAttributes): # Normal request path = request.match_info.get("path", "") async with self._api_client(request, path) as client: + # Check if this is a streaming response (e.g., MCP SSE endpoints) + if client.content_type == "text/event-stream": + return await self._stream_client_response( + request, + client, + content_type=client.content_type, + headers_to_copy=( + "Cache-Control", + "Mcp-Session-Id", + ), + ) + + # Non-streaming response data = await client.read() - return web.Response( + response = web.Response( body=data, status=client.status, content_type=client.content_type ) + # Copy selected headers from the upstream response + for header in ( + "Cache-Control", + "Mcp-Session-Id", + ): + if header in client.headers: + response.headers[header] = client.headers[header] + return response async def _websocket_client(self) -> ClientWebSocketResponse: """Initialize a WebSocket API connection.""" diff --git a/tests/api/test_proxy.py b/tests/api/test_proxy.py index 466b5ad12..56824ffcb 100644 --- a/tests/api/test_proxy.py +++ b/tests/api/test_proxy.py @@ -9,7 +9,7 @@ import logging from typing import Any, cast from unittest.mock import AsyncMock, patch -from aiohttp import ClientWebSocketResponse, WSCloseCode +from aiohttp import ClientPayloadError, ClientWebSocketResponse, WSCloseCode from aiohttp.http_websocket import WSMessage, WSMsgType from aiohttp.test_utils import TestClient import pytest @@ -326,3 +326,129 @@ async def test_api_proxy_delete_request( assert response.status == 200 assert await response.text() == '{"result": "ok"}' assert response.content_type == "application/json" + + +async def test_api_proxy_mcp_headers_forwarded( + api_client: TestClient, + install_addon_example: Addon, +): + """Test that MCP headers are forwarded to Home Assistant.""" + install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123" + install_addon_example.data["homeassistant_api"] = True + + with patch.object(HomeAssistantAPI, "make_request") as make_request: + # Mock the response from make_request + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.content_type = "application/json" + mock_response.read.return_value = b"mocked response" + mock_response.headers = {"Mcp-Session-Id": "test-session-123"} + make_request.return_value.__aenter__.return_value = mock_response + + response = await api_client.get( + "/core/api/mcp", + headers={ + "Authorization": "Bearer abc123", + "Accept": "text/event-stream", + "Last-Event-ID": "5", + "Mcp-Session-Id": "test-session-123", + }, + ) + + # Verify headers were forwarded in the request + assert make_request.call_args[1]["headers"]["Accept"] == "text/event-stream" + assert make_request.call_args[1]["headers"]["Last-Event-ID"] == "5" + assert ( + make_request.call_args[1]["headers"]["Mcp-Session-Id"] == "test-session-123" + ) + + # Verify response headers are preserved + assert response.status == 200 + assert response.headers.get("Mcp-Session-Id") == "test-session-123" + + +async def test_api_proxy_streaming_response( + api_client: TestClient, + install_addon_example: Addon, +): + """Test that streaming responses (text/event-stream) are handled properly.""" + install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123" + install_addon_example.data["homeassistant_api"] = True + + async def mock_content_iter(): + """Mock async iterator for streaming content.""" + yield b"data: event1\n\n" + yield b"data: event2\n\n" + yield b"data: event3\n\n" + + with patch.object(HomeAssistantAPI, "make_request") as make_request: + # Mock the response from make_request + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.content_type = "text/event-stream" + mock_response.headers = { + "Cache-Control": "no-cache", + "Mcp-Session-Id": "session-456", + } + mock_response.content = mock_content_iter() + make_request.return_value.__aenter__.return_value = mock_response + + response = await api_client.get( + "/core/api/mcp", + headers={ + "Authorization": "Bearer abc123", + "Accept": "text/event-stream", + }, + ) + + # Verify it's a streaming response + assert response.status == 200 + assert response.content_type == "text/event-stream" + assert response.headers.get("X-Accel-Buffering") == "no" + assert response.headers.get("Mcp-Session-Id") == "session-456" + + # Read the streamed content + content = await response.read() + assert b"data: event1\n\n" in content + assert b"data: event2\n\n" in content + assert b"data: event3\n\n" in content + + +async def test_api_proxy_streaming_response_client_payload_error( + api_client: TestClient, + install_addon_example: Addon, +): + """Test that client payload errors during streaming are handled gracefully.""" + install_addon_example.persist[ATTR_ACCESS_TOKEN] = "abc123" + install_addon_example.data["homeassistant_api"] = True + + async def mock_content_iter_error(): + yield b"data: event1\n\n" + raise ClientPayloadError("boom") + + with patch.object(HomeAssistantAPI, "make_request") as make_request: + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.content_type = "text/event-stream" + mock_response.headers = { + "Cache-Control": "no-cache", + "Mcp-Session-Id": "session-789", + } + mock_response.content = mock_content_iter_error() + make_request.return_value.__aenter__.return_value = mock_response + + response = await api_client.get( + "/core/api/mcp", + headers={ + "Authorization": "Bearer abc123", + "Accept": "text/event-stream", + }, + ) + + assert response.status == 200 + assert response.content_type == "text/event-stream" + assert response.headers.get("X-Accel-Buffering") == "no" + assert response.headers.get("Mcp-Session-Id") == "session-789" + + content = await response.read() + assert b"data: event1\n\n" in content