diff --git a/homeassistant/components/ollama/__init__.py b/homeassistant/components/ollama/__init__.py index 805724b82e3..f95f8c8881f 100644 --- a/homeassistant/components/ollama/__init__.py +++ b/homeassistant/components/ollama/__init__.py @@ -10,9 +10,13 @@ import httpx import ollama from homeassistant.config_entries import ConfigEntry, ConfigSubentry -from homeassistant.const import CONF_URL, Platform +from homeassistant.const import CONF_API_KEY, CONF_URL, Platform from homeassistant.core import HomeAssistant -from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.exceptions import ( + ConfigEntryAuthFailed, + ConfigEntryError, + ConfigEntryNotReady, +) from homeassistant.helpers import ( config_validation as cv, device_registry as dr, @@ -62,10 +66,28 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool: """Set up Ollama from a config entry.""" settings = {**entry.data, **entry.options} - client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context()) + api_key = settings.get(CONF_API_KEY) + stripped_api_key = api_key.strip() if isinstance(api_key, str) else None + client = ollama.AsyncClient( + host=settings[CONF_URL], + headers=( + {"Authorization": f"Bearer {stripped_api_key}"} + if stripped_api_key + else None + ), + verify=get_default_context(), + ) try: async with asyncio.timeout(DEFAULT_TIMEOUT): await client.list() + except ollama.ResponseError as err: + if err.status_code in (401, 403): + raise ConfigEntryAuthFailed from err + if err.status_code >= 500 or err.status_code == 429: + raise ConfigEntryNotReady(err) from err + # If the response is a 4xx error other than 401 or 403, it likely means the URL is valid but not an Ollama instance, + # so we raise ConfigEntryError to show an error in the UI, instead of ConfigEntryNotReady which would just keep retrying. + raise ConfigEntryError(err) from err except (TimeoutError, httpx.ConnectError) as err: raise ConfigEntryNotReady(err) from err diff --git a/homeassistant/components/ollama/config_flow.py b/homeassistant/components/ollama/config_flow.py index 84f56d966f4..5209208b9f0 100644 --- a/homeassistant/components/ollama/config_flow.py +++ b/homeassistant/components/ollama/config_flow.py @@ -20,7 +20,7 @@ from homeassistant.config_entries import ( ConfigSubentryFlow, SubentryFlowResult, ) -from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME, CONF_URL from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import config_validation as cv, llm from homeassistant.helpers.selector import ( @@ -68,6 +68,17 @@ STEP_USER_DATA_SCHEMA = vol.Schema( vol.Required(CONF_URL): TextSelector( TextSelectorConfig(type=TextSelectorType.URL) ), + vol.Optional(CONF_API_KEY): TextSelector( + TextSelectorConfig(type=TextSelectorType.PASSWORD) + ), + }, +) + +STEP_REAUTH_DATA_SCHEMA = vol.Schema( + { + vol.Optional(CONF_API_KEY): TextSelector( + TextSelectorConfig(type=TextSelectorType.PASSWORD) + ), } ) @@ -78,9 +89,40 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 3 MINOR_VERSION = 3 - def __init__(self) -> None: - """Initialize config flow.""" - self.url: str | None = None + async def _async_validate_connection( + self, url: str, api_key: str | None + ) -> dict[str, str]: + """Validate connection and credentials against the Ollama server.""" + errors: dict[str, str] = {} + + try: + client = ollama.AsyncClient( + host=url, + headers={"Authorization": f"Bearer {api_key}"} if api_key else None, + verify=get_default_context(), + ) + + async with asyncio.timeout(DEFAULT_TIMEOUT): + await client.list() + + except ollama.ResponseError as err: + if err.status_code in (401, 403): + errors["base"] = "invalid_auth" + else: + _LOGGER.warning( + "Error response from Ollama server at %s: status %s, detail: %s", + url, + err.status_code, + str(err), + ) + errors["base"] = "unknown" + except TimeoutError, httpx.ConnectError: + errors["base"] = "cannot_connect" + except Exception: + _LOGGER.exception("Unexpected exception") + errors["base"] = "unknown" + + return errors async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -92,9 +134,10 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): ) errors = {} - url = user_input[CONF_URL] - - self._async_abort_entries_match({CONF_URL: url}) + url = user_input[CONF_URL].strip() + api_key = user_input.get(CONF_API_KEY) + if api_key: + api_key = api_key.strip() try: url = cv.url(url) @@ -108,15 +151,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): errors=errors, ) - try: - client = ollama.AsyncClient(host=url, verify=get_default_context()) - async with asyncio.timeout(DEFAULT_TIMEOUT): - await client.list() - except TimeoutError, httpx.ConnectError: - errors["base"] = "cannot_connect" - except Exception: - _LOGGER.exception("Unexpected exception") - errors["base"] = "unknown" + self._async_abort_entries_match({CONF_URL: url}) + errors = await self._async_validate_connection(url, api_key) if errors: return self.async_show_form( @@ -127,9 +163,65 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN): errors=errors, ) - return self.async_create_entry( - title=url, - data={CONF_URL: url}, + entry_data: dict[str, str] = {CONF_URL: url} + if api_key: + entry_data[CONF_API_KEY] = api_key + + return self.async_create_entry(title=url, data=entry_data) + + async def async_step_reauth( + self, entry_data: Mapping[str, Any] + ) -> ConfigFlowResult: + """Handle reauthentication when existing credentials are invalid.""" + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> ConfigFlowResult: + """Handle reauthentication confirmation.""" + reauth_entry = self._get_reauth_entry() + + if user_input is None: + return self.async_show_form( + step_id="reauth_confirm", + data_schema=STEP_REAUTH_DATA_SCHEMA, + ) + + api_key = user_input.get(CONF_API_KEY) + if api_key: + api_key = api_key.strip() + + errors = await self._async_validate_connection( + reauth_entry.data[CONF_URL], api_key + ) + if errors: + return self.async_show_form( + step_id="reauth_confirm", + data_schema=self.add_suggested_values_to_schema( + STEP_REAUTH_DATA_SCHEMA, user_input + ), + errors=errors, + ) + + updated_data = { + **reauth_entry.data, + CONF_URL: reauth_entry.data[CONF_URL], + } + if api_key: + updated_data[CONF_API_KEY] = api_key + else: + updated_data.pop(CONF_API_KEY, None) + + updated_options = { + key: value + for key, value in reauth_entry.options.items() + if key != CONF_API_KEY + } + + return self.async_update_reload_and_abort( + reauth_entry, + data=updated_data, + options=updated_options, ) @classmethod diff --git a/homeassistant/components/ollama/strings.json b/homeassistant/components/ollama/strings.json index f8388fb5dd0..b4aaa7d75e1 100644 --- a/homeassistant/components/ollama/strings.json +++ b/homeassistant/components/ollama/strings.json @@ -1,16 +1,26 @@ { "config": { "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_service%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_service%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" }, "error": { "cannot_connect": "[%key:common::config_flow::error::cannot_connect%]", + "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "invalid_url": "[%key:common::config_flow::error::invalid_host%]", "unknown": "[%key:common::config_flow::error::unknown%]" }, "step": { + "reauth_confirm": { + "data": { + "api_key": "[%key:common::config_flow::data::api_key%]" + }, + "description": "The Ollama integration needs to re-authenticate with your Ollama API key.", + "title": "[%key:common::config_flow::title::reauth%]" + }, "user": { "data": { + "api_key": "[%key:common::config_flow::data::api_key%]", "url": "[%key:common::config_flow::data::url%]" } } diff --git a/tests/components/ollama/conftest.py b/tests/components/ollama/conftest.py index 4b6b9a41fa3..39fe6f0b134 100644 --- a/tests/components/ollama/conftest.py +++ b/tests/components/ollama/conftest.py @@ -1,12 +1,13 @@ """Tests Ollama integration.""" +from copy import deepcopy from typing import Any from unittest.mock import patch import pytest from homeassistant.components import ollama -from homeassistant.const import CONF_LLM_HASS_API +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API from homeassistant.core import HomeAssistant from homeassistant.helpers import llm from homeassistant.setup import async_setup_component @@ -22,14 +23,31 @@ def mock_config_entry_options() -> dict[str, Any]: return TEST_OPTIONS +@pytest.fixture +def has_token() -> bool: + """Fixture to indicate if the config entry has a token.""" + return False + + +@pytest.fixture +def mock_config_entry_data(has_token: bool) -> dict[str, Any]: + """Fixture for configuration entry data.""" + res = deepcopy(TEST_USER_DATA) + if has_token: + res[CONF_API_KEY] = "test_token" + return res + + @pytest.fixture def mock_config_entry( - hass: HomeAssistant, mock_config_entry_options: dict[str, Any] + hass: HomeAssistant, + mock_config_entry_options: dict[str, Any], + mock_config_entry_data: dict[str, Any], ) -> MockConfigEntry: """Mock a config entry.""" entry = MockConfigEntry( domain=ollama.DOMAIN, - data=TEST_USER_DATA, + data=mock_config_entry_data, version=3, minor_version=2, subentries_data=[ diff --git a/tests/components/ollama/test_config_flow.py b/tests/components/ollama/test_config_flow.py index ce52cfaf9ec..72390f27432 100644 --- a/tests/components/ollama/test_config_flow.py +++ b/tests/components/ollama/test_config_flow.py @@ -1,14 +1,17 @@ """Test the Ollama config flow.""" import asyncio -from unittest.mock import patch +from unittest.mock import ANY, AsyncMock, patch from httpx import ConnectError +from ollama import ResponseError import pytest from homeassistant import config_entries from homeassistant.components import ollama -from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME +from homeassistant.components.ollama.const import DOMAIN +from homeassistant.config_entries import SOURCE_USER +from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME, CONF_URL from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -20,14 +23,14 @@ TEST_MODEL = "test_model:latest" async def test_form(hass: HomeAssistant) -> None: """Test flow when configuring URL only.""" # Pretend we already set up a config entry. - hass.config.components.add(ollama.DOMAIN) + hass.config.components.add(DOMAIN) MockConfigEntry( - domain=ollama.DOMAIN, + domain=DOMAIN, state=config_entries.ConfigEntryState.LOADED, ).add_to_hass(hass) result = await hass.config_entries.flow.async_init( - ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + DOMAIN, context={"source": SOURCE_USER} ) assert result["type"] is FlowResultType.FORM assert result["errors"] is None @@ -48,18 +51,18 @@ async def test_form(hass: HomeAssistant) -> None: await hass.async_block_till_done() assert result2["type"] is FlowResultType.CREATE_ENTRY - assert result2["data"] == { - ollama.CONF_URL: "http://localhost:11434", - } + assert result2["data"] == {ollama.CONF_URL: "http://localhost:11434"} + # No subentries created by default assert len(result2.get("subentries", [])) == 0 assert len(mock_setup_entry.mock_calls) == 1 + assert CONF_API_KEY not in result2["data"] async def test_duplicate_entry(hass: HomeAssistant) -> None: """Test we abort on duplicate config entry.""" MockConfigEntry( - domain=ollama.DOMAIN, + domain=DOMAIN, data={ ollama.CONF_URL: "http://localhost:11434", ollama.CONF_MODEL: "test_model", @@ -67,7 +70,7 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None: ).add_to_hass(hass) result = await hass.config_entries.flow.async_init( - ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + DOMAIN, context={"source": SOURCE_USER} ) assert result["type"] is FlowResultType.FORM assert not result["errors"] @@ -141,7 +144,7 @@ async def test_creating_new_conversation_subentry( ): new_flow = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "conversation"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert new_flow["type"] is FlowResultType.FORM @@ -181,7 +184,7 @@ async def test_creating_conversation_subentry_not_loaded( await hass.config_entries.async_unload(mock_config_entry.entry_id) result = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "conversation"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert result["type"] is FlowResultType.ABORT @@ -209,7 +212,7 @@ async def test_subentry_need_download( ): new_flow = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "conversation"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert new_flow["type"] is FlowResultType.FORM, new_flow @@ -271,7 +274,7 @@ async def test_subentry_download_error( ): new_flow = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "conversation"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert new_flow["type"] is FlowResultType.FORM @@ -307,6 +310,130 @@ async def test_subentry_download_error( assert result["reason"] == "download_failed" +@pytest.mark.parametrize( + ("init_data", "input_data", "expected_data"), + [ + ( + { + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "old-api-key", + }, + { + CONF_API_KEY: "new-api-key", + }, + { + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "new-api-key", + }, + ), + ( + { + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "old-api-key", + }, + { + # Reconfigure without api_key to test that it gets removed from data + }, + { + CONF_URL: "http://localhost:11434", + }, + ), + ], +) +async def test_reauth_flow_success( + hass: HomeAssistant, init_data, input_data, expected_data +) -> None: + """Test successful reauthentication flow.""" + entry = MockConfigEntry( + domain=DOMAIN, + data=init_data, + options={CONF_API_KEY: "stale-options-api-key"}, + version=3, + minor_version=3, + ) + entry.add_to_hass(hass) + + result = await entry.start_reauth_flow(hass) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + return_value={"models": [{"model": TEST_MODEL}]}, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + input_data, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + + assert entry.data == expected_data + assert entry.options == {} + + +@pytest.mark.parametrize( + ("side_effect", "error"), + [ + (ResponseError(error="Unauthorized", status_code=401), "invalid_auth"), + (ConnectError(message="Connection failed"), "cannot_connect"), + ], +) +async def test_reauth_flow_errors(hass: HomeAssistant, side_effect, error) -> None: + """Test reauthentication flow when authentication fails.""" + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "old-api-key", + }, + version=3, + minor_version=3, + ) + entry.add_to_hass(hass) + + result = await entry.start_reauth_flow(hass) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + side_effect=side_effect, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_API_KEY: "other-api-key", + }, + ) + + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + assert result["errors"] == {"base": error} + + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + return_value={"models": [{"model": TEST_MODEL}]}, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_API_KEY: "new-api-key", + }, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + + assert entry.data == { + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "new-api-key", + } + + @pytest.mark.parametrize( ("side_effect", "error"), [ @@ -317,7 +444,7 @@ async def test_subentry_download_error( async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None: """Test we handle errors.""" result = await hass.config_entries.flow.async_init( - ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + DOMAIN, context={"source": SOURCE_USER} ) with patch( @@ -332,10 +459,50 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None: assert result2["errors"] == {"base": error} +@pytest.mark.parametrize( + ("side_effect", "error"), + [ + (ConnectError(message=""), "cannot_connect"), + (RuntimeError(), "unknown"), + ], +) +async def test_form_errors_recovery(hass: HomeAssistant, side_effect, error) -> None: + """Test that the user flow recovers after an error and completes successfully.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + + # First attempt fails + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + side_effect=side_effect, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"} + ) + + assert result["type"] is FlowResultType.FORM + assert result["errors"] == {"base": error} + + # Second attempt succeeds + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient.list", + return_value={"models": [{"model": TEST_MODEL}]}, + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {ollama.CONF_URL: "http://localhost:11434"}, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["data"] == {ollama.CONF_URL: "http://localhost:11434"} + + async def test_form_invalid_url(hass: HomeAssistant) -> None: """Test we handle invalid URL.""" result = await hass.config_entries.flow.async_init( - ollama.DOMAIN, context={"source": config_entries.SOURCE_USER} + DOMAIN, context={"source": SOURCE_USER} ) result2 = await hass.config_entries.flow.async_configure( @@ -358,7 +525,7 @@ async def test_subentry_connection_error( ): new_flow = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "conversation"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert new_flow["type"] is FlowResultType.ABORT @@ -380,7 +547,7 @@ async def test_subentry_model_check_exception( ): new_flow = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "conversation"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert new_flow["type"] is FlowResultType.FORM @@ -500,7 +667,7 @@ async def test_creating_ai_task_subentry( ): result = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "ai_task_data"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert result.get("type") is FlowResultType.FORM @@ -552,8 +719,167 @@ async def test_ai_task_subentry_not_loaded( # Don't call mock_init_component to simulate not loaded state result = await hass.config_entries.subentries.async_init( (mock_config_entry.entry_id, "ai_task_data"), - context={"source": config_entries.SOURCE_USER}, + context={"source": SOURCE_USER}, ) assert result.get("type") is FlowResultType.ABORT assert result.get("reason") == "entry_not_loaded" + + +@pytest.mark.parametrize( + ("user_input", "expected_headers", "expected_data"), + [ + ( + {CONF_URL: "http://localhost:11434", CONF_API_KEY: "my-secret-token"}, + {"Authorization": "Bearer my-secret-token"}, + {CONF_URL: "http://localhost:11434", CONF_API_KEY: "my-secret-token"}, + ), + ( + {CONF_URL: "http://localhost:11434", CONF_API_KEY: ""}, + None, + {CONF_URL: "http://localhost:11434"}, + ), + ( + {CONF_URL: "http://localhost:11434", CONF_API_KEY: " "}, + None, + {CONF_URL: "http://localhost:11434"}, + ), + ( + {CONF_URL: "http://localhost:11434"}, + None, + {CONF_URL: "http://localhost:11434"}, + ), + ], +) +async def test_user_step_async_client_headers( + hass: HomeAssistant, + user_input: dict[str, str], + expected_headers: dict[str, str] | None, + expected_data: dict[str, str], +) -> None: + """Test Authorization header passed to AsyncClient with/without api_key.""" + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient", + ) as mock_async_client: + mock_async_client.return_value.list = AsyncMock(return_value={"models": []}) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["data"] == expected_data + mock_async_client.assert_called_with( + host="http://localhost:11434", + headers=expected_headers, + verify=ANY, + ) + + +@pytest.mark.parametrize( + ("status_code", "error", "error_message", "user_input"), + [ + ( + 400, + "unknown", + "Bad Request", + { + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "my-secret-token", + }, + ), + ( + 401, + "invalid_auth", + "Unauthorized", + { + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "my-secret-token", + }, + ), + ( + 403, + "invalid_auth", + "Unauthorized", + { + CONF_URL: "http://localhost:11434", + CONF_API_KEY: "my-secret-token", + }, + ), + ( + 403, + "invalid_auth", + "Forbidden", + { + CONF_URL: "http://localhost:11434", + }, + ), + ], +) +async def test_user_step_errors( + hass: HomeAssistant, + status_code: int, + error: str, + error_message: str, + user_input: dict[str, str], +) -> None: + """Test error handling when ollama returns HTTP 4xx.""" + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient" + ) as mock_async_client: + mock_client_instance = AsyncMock() + mock_async_client.return_value = mock_client_instance + + mock_client_instance.list.side_effect = ResponseError( + error=error_message, status_code=status_code + ) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.FORM + assert result.get("errors") == {"base": error} + + +async def test_user_step_trim_url(hass: HomeAssistant) -> None: + """Test URL is trimmed before validation and persistence.""" + with patch( + "homeassistant.components.ollama.config_flow.ollama.AsyncClient", + ) as mock_async_client: + mock_async_client.return_value.list = AsyncMock(return_value={"models": []}) + + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert result["type"] is FlowResultType.FORM + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={ + CONF_URL: " http://localhost:11434 ", + }, + ) + await hass.async_block_till_done() + + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["data"] == {CONF_URL: "http://localhost:11434"} + mock_async_client.assert_called_with( + host="http://localhost:11434", + headers=None, + verify=ANY, + ) diff --git a/tests/components/ollama/test_init.py b/tests/components/ollama/test_init.py index 25e41daf276..598f5ae7d54 100644 --- a/tests/components/ollama/test_init.py +++ b/tests/components/ollama/test_init.py @@ -1,14 +1,19 @@ """Tests for the Ollama integration.""" from typing import Any -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from httpx import ConnectError +from ollama import ResponseError import pytest from homeassistant.components import ollama from homeassistant.components.ollama.const import DOMAIN -from homeassistant.config_entries import ConfigEntryDisabler, ConfigSubentryData +from homeassistant.config_entries import ( + ConfigEntryDisabler, + ConfigEntryState, + ConfigSubentryData, +) from homeassistant.const import CONF_URL from homeassistant.core import HomeAssistant from homeassistant.helpers import device_registry as dr, entity_registry as er, llm @@ -21,7 +26,7 @@ from . import TEST_OPTIONS from tests.common import MockConfigEntry V1_TEST_USER_DATA = { - ollama.CONF_URL: "http://localhost:11434", + CONF_URL: "http://localhost:11434", ollama.CONF_MODEL: "test_model:latest", } @@ -58,6 +63,74 @@ async def test_init_error( assert error in caplog.text +@pytest.mark.parametrize("has_token", [True]) +async def test_init_with_api_key( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test initialization with API key - Authorization header should be set.""" + # Create entry with API key in data (version 3.0 after migration) + mock_config_entry.add_to_hass(hass) + + with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client: + mock_client.return_value.list = AsyncMock(return_value={"models": []}) + + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert any( + call.kwargs["headers"] == {"Authorization": "Bearer test_token"} + for call in mock_client.call_args_list + ) + + +async def test_init_without_api_key( + hass: HomeAssistant, mock_config_entry: MockConfigEntry +) -> None: + """Test initialization without API key - Authorization header should not be set.""" + # Create entry without API key in data (version 3.0 after migration) + mock_config_entry.add_to_hass(hass) + + with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client: + mock_client.return_value.list = AsyncMock(return_value={"models": []}) + + assert await async_setup_component(hass, ollama.DOMAIN, {}) + await hass.async_block_till_done() + + assert all( + call.kwargs["headers"] is None for call in mock_client.call_args_list + ) + + +@pytest.mark.parametrize( + ("status_code", "entry_state"), + [ + (401, ConfigEntryState.SETUP_ERROR), + (403, ConfigEntryState.SETUP_ERROR), + (500, ConfigEntryState.SETUP_RETRY), + (429, ConfigEntryState.SETUP_RETRY), + (400, ConfigEntryState.SETUP_ERROR), + ], +) +async def test_async_setup_entry_auth_failed_on_response_error( + hass: HomeAssistant, + mock_config_entry: MockConfigEntry, + status_code: int, + entry_state: ConfigEntryState, +) -> None: + """Test async_setup_entry raises auth failed on 401/403 response.""" + mock_config_entry.add_to_hass(hass) + + with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client: + mock_client.return_value.list = AsyncMock( + side_effect=ResponseError(error="Unauthorized", status_code=status_code) + ) + + await hass.config_entries.async_setup(mock_config_entry.entry_id) + await hass.async_block_till_done() + + assert mock_config_entry.state is entry_state + + async def test_migration_from_v1( hass: HomeAssistant, device_registry: dr.DeviceRegistry, @@ -102,7 +175,7 @@ async def test_migration_from_v1( assert mock_config_entry.version == 3 assert mock_config_entry.minor_version == 3 # After migration, parent entry should only have URL - assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"} + assert mock_config_entry.data == {CONF_URL: "http://localhost:11434"} assert mock_config_entry.options == {} assert len(mock_config_entry.subentries) == 2 @@ -748,7 +821,7 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None: mock_config_entry = MockConfigEntry( domain=DOMAIN, data={ - ollama.CONF_URL: "http://localhost:11434", + CONF_URL: "http://localhost:11434", ollama.CONF_MODEL: "test_model:latest", # Model still in main data }, version=2, @@ -768,7 +841,7 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None: assert mock_config_entry.minor_version == 3 # Check that model was moved from main data to subentry - assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"} + assert mock_config_entry.data == {CONF_URL: "http://localhost:11434"} assert len(mock_config_entry.subentries) == 2 subentry = next(iter(mock_config_entry.subentries.values()))