1
0
mirror of https://github.com/home-assistant/core.git synced 2026-04-02 08:26:41 +01:00

Add bearer token as optional setting to Ollama (#165325)

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Joostlek <joostlek@outlook.com>
This commit is contained in:
Cyril MARIN
2026-03-16 22:14:33 +01:00
committed by GitHub
parent 2042f2e2bd
commit a963eed3a7
6 changed files with 593 additions and 52 deletions

View File

@@ -10,9 +10,13 @@ import httpx
import ollama
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
from homeassistant.const import CONF_URL, Platform
from homeassistant.const import CONF_API_KEY, CONF_URL, Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.exceptions import (
ConfigEntryAuthFailed,
ConfigEntryError,
ConfigEntryNotReady,
)
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
@@ -62,10 +66,28 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool:
"""Set up Ollama from a config entry."""
settings = {**entry.data, **entry.options}
client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context())
api_key = settings.get(CONF_API_KEY)
stripped_api_key = api_key.strip() if isinstance(api_key, str) else None
client = ollama.AsyncClient(
host=settings[CONF_URL],
headers=(
{"Authorization": f"Bearer {stripped_api_key}"}
if stripped_api_key
else None
),
verify=get_default_context(),
)
try:
async with asyncio.timeout(DEFAULT_TIMEOUT):
await client.list()
except ollama.ResponseError as err:
if err.status_code in (401, 403):
raise ConfigEntryAuthFailed from err
if err.status_code >= 500 or err.status_code == 429:
raise ConfigEntryNotReady(err) from err
# If the response is a 4xx error other than 401 or 403, it likely means the URL is valid but not an Ollama instance,
# so we raise ConfigEntryError to show an error in the UI, instead of ConfigEntryNotReady which would just keep retrying.
raise ConfigEntryError(err) from err
except (TimeoutError, httpx.ConnectError) as err:
raise ConfigEntryNotReady(err) from err

View File

@@ -20,7 +20,7 @@ from homeassistant.config_entries import (
ConfigSubentryFlow,
SubentryFlowResult,
)
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_validation as cv, llm
from homeassistant.helpers.selector import (
@@ -68,6 +68,17 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
vol.Required(CONF_URL): TextSelector(
TextSelectorConfig(type=TextSelectorType.URL)
),
vol.Optional(CONF_API_KEY): TextSelector(
TextSelectorConfig(type=TextSelectorType.PASSWORD)
),
},
)
STEP_REAUTH_DATA_SCHEMA = vol.Schema(
{
vol.Optional(CONF_API_KEY): TextSelector(
TextSelectorConfig(type=TextSelectorType.PASSWORD)
),
}
)
@@ -78,9 +89,40 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 3
MINOR_VERSION = 3
def __init__(self) -> None:
"""Initialize config flow."""
self.url: str | None = None
async def _async_validate_connection(
self, url: str, api_key: str | None
) -> dict[str, str]:
"""Validate connection and credentials against the Ollama server."""
errors: dict[str, str] = {}
try:
client = ollama.AsyncClient(
host=url,
headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
verify=get_default_context(),
)
async with asyncio.timeout(DEFAULT_TIMEOUT):
await client.list()
except ollama.ResponseError as err:
if err.status_code in (401, 403):
errors["base"] = "invalid_auth"
else:
_LOGGER.warning(
"Error response from Ollama server at %s: status %s, detail: %s",
url,
err.status_code,
str(err),
)
errors["base"] = "unknown"
except TimeoutError, httpx.ConnectError:
errors["base"] = "cannot_connect"
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
return errors
async def async_step_user(
self, user_input: dict[str, Any] | None = None
@@ -92,9 +134,10 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
)
errors = {}
url = user_input[CONF_URL]
self._async_abort_entries_match({CONF_URL: url})
url = user_input[CONF_URL].strip()
api_key = user_input.get(CONF_API_KEY)
if api_key:
api_key = api_key.strip()
try:
url = cv.url(url)
@@ -108,15 +151,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
errors=errors,
)
try:
client = ollama.AsyncClient(host=url, verify=get_default_context())
async with asyncio.timeout(DEFAULT_TIMEOUT):
await client.list()
except TimeoutError, httpx.ConnectError:
errors["base"] = "cannot_connect"
except Exception:
_LOGGER.exception("Unexpected exception")
errors["base"] = "unknown"
self._async_abort_entries_match({CONF_URL: url})
errors = await self._async_validate_connection(url, api_key)
if errors:
return self.async_show_form(
@@ -127,9 +163,65 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
errors=errors,
)
return self.async_create_entry(
title=url,
data={CONF_URL: url},
entry_data: dict[str, str] = {CONF_URL: url}
if api_key:
entry_data[CONF_API_KEY] = api_key
return self.async_create_entry(title=url, data=entry_data)
async def async_step_reauth(
self, entry_data: Mapping[str, Any]
) -> ConfigFlowResult:
"""Handle reauthentication when existing credentials are invalid."""
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle reauthentication confirmation."""
reauth_entry = self._get_reauth_entry()
if user_input is None:
return self.async_show_form(
step_id="reauth_confirm",
data_schema=STEP_REAUTH_DATA_SCHEMA,
)
api_key = user_input.get(CONF_API_KEY)
if api_key:
api_key = api_key.strip()
errors = await self._async_validate_connection(
reauth_entry.data[CONF_URL], api_key
)
if errors:
return self.async_show_form(
step_id="reauth_confirm",
data_schema=self.add_suggested_values_to_schema(
STEP_REAUTH_DATA_SCHEMA, user_input
),
errors=errors,
)
updated_data = {
**reauth_entry.data,
CONF_URL: reauth_entry.data[CONF_URL],
}
if api_key:
updated_data[CONF_API_KEY] = api_key
else:
updated_data.pop(CONF_API_KEY, None)
updated_options = {
key: value
for key, value in reauth_entry.options.items()
if key != CONF_API_KEY
}
return self.async_update_reload_and_abort(
reauth_entry,
data=updated_data,
options=updated_options,
)
@classmethod

View File

@@ -1,16 +1,26 @@
{
"config": {
"abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
},
"error": {
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"invalid_url": "[%key:common::config_flow::error::invalid_host%]",
"unknown": "[%key:common::config_flow::error::unknown%]"
},
"step": {
"reauth_confirm": {
"data": {
"api_key": "[%key:common::config_flow::data::api_key%]"
},
"description": "The Ollama integration needs to re-authenticate with your Ollama API key.",
"title": "[%key:common::config_flow::title::reauth%]"
},
"user": {
"data": {
"api_key": "[%key:common::config_flow::data::api_key%]",
"url": "[%key:common::config_flow::data::url%]"
}
}

View File

@@ -1,12 +1,13 @@
"""Tests Ollama integration."""
from copy import deepcopy
from typing import Any
from unittest.mock import patch
import pytest
from homeassistant.components import ollama
from homeassistant.const import CONF_LLM_HASS_API
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.setup import async_setup_component
@@ -22,14 +23,31 @@ def mock_config_entry_options() -> dict[str, Any]:
return TEST_OPTIONS
@pytest.fixture
def has_token() -> bool:
"""Fixture to indicate if the config entry has a token."""
return False
@pytest.fixture
def mock_config_entry_data(has_token: bool) -> dict[str, Any]:
"""Fixture for configuration entry data."""
res = deepcopy(TEST_USER_DATA)
if has_token:
res[CONF_API_KEY] = "test_token"
return res
@pytest.fixture
def mock_config_entry(
hass: HomeAssistant, mock_config_entry_options: dict[str, Any]
hass: HomeAssistant,
mock_config_entry_options: dict[str, Any],
mock_config_entry_data: dict[str, Any],
) -> MockConfigEntry:
"""Mock a config entry."""
entry = MockConfigEntry(
domain=ollama.DOMAIN,
data=TEST_USER_DATA,
data=mock_config_entry_data,
version=3,
minor_version=2,
subentries_data=[

View File

@@ -1,14 +1,17 @@
"""Test the Ollama config flow."""
import asyncio
from unittest.mock import patch
from unittest.mock import ANY, AsyncMock, patch
from httpx import ConnectError
from ollama import ResponseError
import pytest
from homeassistant import config_entries
from homeassistant.components import ollama
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME
from homeassistant.components.ollama.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
@@ -20,14 +23,14 @@ TEST_MODEL = "test_model:latest"
async def test_form(hass: HomeAssistant) -> None:
"""Test flow when configuring URL only."""
# Pretend we already set up a config entry.
hass.config.components.add(ollama.DOMAIN)
hass.config.components.add(DOMAIN)
MockConfigEntry(
domain=ollama.DOMAIN,
domain=DOMAIN,
state=config_entries.ConfigEntryState.LOADED,
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] is None
@@ -48,18 +51,18 @@ async def test_form(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["data"] == {
ollama.CONF_URL: "http://localhost:11434",
}
assert result2["data"] == {ollama.CONF_URL: "http://localhost:11434"}
# No subentries created by default
assert len(result2.get("subentries", [])) == 0
assert len(mock_setup_entry.mock_calls) == 1
assert CONF_API_KEY not in result2["data"]
async def test_duplicate_entry(hass: HomeAssistant) -> None:
"""Test we abort on duplicate config entry."""
MockConfigEntry(
domain=ollama.DOMAIN,
domain=DOMAIN,
data={
ollama.CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model",
@@ -67,7 +70,7 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
).add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert not result["errors"]
@@ -141,7 +144,7 @@ async def test_creating_new_conversation_subentry(
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
@@ -181,7 +184,7 @@ async def test_creating_conversation_subentry_not_loaded(
await hass.config_entries.async_unload(mock_config_entry.entry_id)
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert result["type"] is FlowResultType.ABORT
@@ -209,7 +212,7 @@ async def test_subentry_need_download(
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM, new_flow
@@ -271,7 +274,7 @@ async def test_subentry_download_error(
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
@@ -307,6 +310,130 @@ async def test_subentry_download_error(
assert result["reason"] == "download_failed"
@pytest.mark.parametrize(
("init_data", "input_data", "expected_data"),
[
(
{
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "old-api-key",
},
{
CONF_API_KEY: "new-api-key",
},
{
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "new-api-key",
},
),
(
{
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "old-api-key",
},
{
# Reconfigure without api_key to test that it gets removed from data
},
{
CONF_URL: "http://localhost:11434",
},
),
],
)
async def test_reauth_flow_success(
hass: HomeAssistant, init_data, input_data, expected_data
) -> None:
"""Test successful reauthentication flow."""
entry = MockConfigEntry(
domain=DOMAIN,
data=init_data,
options={CONF_API_KEY: "stale-options-api-key"},
version=3,
minor_version=3,
)
entry.add_to_hass(hass)
result = await entry.start_reauth_flow(hass)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
input_data,
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert entry.data == expected_data
assert entry.options == {}
@pytest.mark.parametrize(
("side_effect", "error"),
[
(ResponseError(error="Unauthorized", status_code=401), "invalid_auth"),
(ConnectError(message="Connection failed"), "cannot_connect"),
],
)
async def test_reauth_flow_errors(hass: HomeAssistant, side_effect, error) -> None:
"""Test reauthentication flow when authentication fails."""
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "old-api-key",
},
version=3,
minor_version=3,
)
entry.add_to_hass(hass)
result = await entry.start_reauth_flow(hass)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
side_effect=side_effect,
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_API_KEY: "other-api-key",
},
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
assert result["errors"] == {"base": error}
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_API_KEY: "new-api-key",
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert entry.data == {
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "new-api-key",
}
@pytest.mark.parametrize(
("side_effect", "error"),
[
@@ -317,7 +444,7 @@ async def test_subentry_download_error(
async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
"""Test we handle errors."""
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
DOMAIN, context={"source": SOURCE_USER}
)
with patch(
@@ -332,10 +459,50 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
assert result2["errors"] == {"base": error}
@pytest.mark.parametrize(
("side_effect", "error"),
[
(ConnectError(message=""), "cannot_connect"),
(RuntimeError(), "unknown"),
],
)
async def test_form_errors_recovery(hass: HomeAssistant, side_effect, error) -> None:
"""Test that the user flow recovers after an error and completes successfully."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
# First attempt fails
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
side_effect=side_effect,
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": error}
# Second attempt succeeds
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
return_value={"models": [{"model": TEST_MODEL}]},
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{ollama.CONF_URL: "http://localhost:11434"},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == {ollama.CONF_URL: "http://localhost:11434"}
async def test_form_invalid_url(hass: HomeAssistant) -> None:
"""Test we handle invalid URL."""
result = await hass.config_entries.flow.async_init(
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
DOMAIN, context={"source": SOURCE_USER}
)
result2 = await hass.config_entries.flow.async_configure(
@@ -358,7 +525,7 @@ async def test_subentry_connection_error(
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.ABORT
@@ -380,7 +547,7 @@ async def test_subentry_model_check_exception(
):
new_flow = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "conversation"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert new_flow["type"] is FlowResultType.FORM
@@ -500,7 +667,7 @@ async def test_creating_ai_task_subentry(
):
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert result.get("type") is FlowResultType.FORM
@@ -552,8 +719,167 @@ async def test_ai_task_subentry_not_loaded(
# Don't call mock_init_component to simulate not loaded state
result = await hass.config_entries.subentries.async_init(
(mock_config_entry.entry_id, "ai_task_data"),
context={"source": config_entries.SOURCE_USER},
context={"source": SOURCE_USER},
)
assert result.get("type") is FlowResultType.ABORT
assert result.get("reason") == "entry_not_loaded"
@pytest.mark.parametrize(
("user_input", "expected_headers", "expected_data"),
[
(
{CONF_URL: "http://localhost:11434", CONF_API_KEY: "my-secret-token"},
{"Authorization": "Bearer my-secret-token"},
{CONF_URL: "http://localhost:11434", CONF_API_KEY: "my-secret-token"},
),
(
{CONF_URL: "http://localhost:11434", CONF_API_KEY: ""},
None,
{CONF_URL: "http://localhost:11434"},
),
(
{CONF_URL: "http://localhost:11434", CONF_API_KEY: " "},
None,
{CONF_URL: "http://localhost:11434"},
),
(
{CONF_URL: "http://localhost:11434"},
None,
{CONF_URL: "http://localhost:11434"},
),
],
)
async def test_user_step_async_client_headers(
hass: HomeAssistant,
user_input: dict[str, str],
expected_headers: dict[str, str] | None,
expected_data: dict[str, str],
) -> None:
"""Test Authorization header passed to AsyncClient with/without api_key."""
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient",
) as mock_async_client:
mock_async_client.return_value.list = AsyncMock(return_value={"models": []})
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=user_input,
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == expected_data
mock_async_client.assert_called_with(
host="http://localhost:11434",
headers=expected_headers,
verify=ANY,
)
@pytest.mark.parametrize(
("status_code", "error", "error_message", "user_input"),
[
(
400,
"unknown",
"Bad Request",
{
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "my-secret-token",
},
),
(
401,
"invalid_auth",
"Unauthorized",
{
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "my-secret-token",
},
),
(
403,
"invalid_auth",
"Unauthorized",
{
CONF_URL: "http://localhost:11434",
CONF_API_KEY: "my-secret-token",
},
),
(
403,
"invalid_auth",
"Forbidden",
{
CONF_URL: "http://localhost:11434",
},
),
],
)
async def test_user_step_errors(
hass: HomeAssistant,
status_code: int,
error: str,
error_message: str,
user_input: dict[str, str],
) -> None:
"""Test error handling when ollama returns HTTP 4xx."""
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient"
) as mock_async_client:
mock_client_instance = AsyncMock()
mock_async_client.return_value = mock_client_instance
mock_client_instance.list.side_effect = ResponseError(
error=error_message, status_code=status_code
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input=user_input,
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
assert result.get("errors") == {"base": error}
async def test_user_step_trim_url(hass: HomeAssistant) -> None:
"""Test URL is trimmed before validation and persistence."""
with patch(
"homeassistant.components.ollama.config_flow.ollama.AsyncClient",
) as mock_async_client:
mock_async_client.return_value.list = AsyncMock(return_value={"models": []})
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
user_input={
CONF_URL: " http://localhost:11434 ",
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == {CONF_URL: "http://localhost:11434"}
mock_async_client.assert_called_with(
host="http://localhost:11434",
headers=None,
verify=ANY,
)

View File

@@ -1,14 +1,19 @@
"""Tests for the Ollama integration."""
from typing import Any
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
from httpx import ConnectError
from ollama import ResponseError
import pytest
from homeassistant.components import ollama
from homeassistant.components.ollama.const import DOMAIN
from homeassistant.config_entries import ConfigEntryDisabler, ConfigSubentryData
from homeassistant.config_entries import (
ConfigEntryDisabler,
ConfigEntryState,
ConfigSubentryData,
)
from homeassistant.const import CONF_URL
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, entity_registry as er, llm
@@ -21,7 +26,7 @@ from . import TEST_OPTIONS
from tests.common import MockConfigEntry
V1_TEST_USER_DATA = {
ollama.CONF_URL: "http://localhost:11434",
CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model:latest",
}
@@ -58,6 +63,74 @@ async def test_init_error(
assert error in caplog.text
@pytest.mark.parametrize("has_token", [True])
async def test_init_with_api_key(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test initialization with API key - Authorization header should be set."""
# Create entry with API key in data (version 3.0 after migration)
mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client:
mock_client.return_value.list = AsyncMock(return_value={"models": []})
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert any(
call.kwargs["headers"] == {"Authorization": "Bearer test_token"}
for call in mock_client.call_args_list
)
async def test_init_without_api_key(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> None:
"""Test initialization without API key - Authorization header should not be set."""
# Create entry without API key in data (version 3.0 after migration)
mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client:
mock_client.return_value.list = AsyncMock(return_value={"models": []})
assert await async_setup_component(hass, ollama.DOMAIN, {})
await hass.async_block_till_done()
assert all(
call.kwargs["headers"] is None for call in mock_client.call_args_list
)
@pytest.mark.parametrize(
("status_code", "entry_state"),
[
(401, ConfigEntryState.SETUP_ERROR),
(403, ConfigEntryState.SETUP_ERROR),
(500, ConfigEntryState.SETUP_RETRY),
(429, ConfigEntryState.SETUP_RETRY),
(400, ConfigEntryState.SETUP_ERROR),
],
)
async def test_async_setup_entry_auth_failed_on_response_error(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
status_code: int,
entry_state: ConfigEntryState,
) -> None:
"""Test async_setup_entry raises auth failed on 401/403 response."""
mock_config_entry.add_to_hass(hass)
with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client:
mock_client.return_value.list = AsyncMock(
side_effect=ResponseError(error="Unauthorized", status_code=status_code)
)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert mock_config_entry.state is entry_state
async def test_migration_from_v1(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
@@ -102,7 +175,7 @@ async def test_migration_from_v1(
assert mock_config_entry.version == 3
assert mock_config_entry.minor_version == 3
# After migration, parent entry should only have URL
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
assert mock_config_entry.data == {CONF_URL: "http://localhost:11434"}
assert mock_config_entry.options == {}
assert len(mock_config_entry.subentries) == 2
@@ -748,7 +821,7 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
mock_config_entry = MockConfigEntry(
domain=DOMAIN,
data={
ollama.CONF_URL: "http://localhost:11434",
CONF_URL: "http://localhost:11434",
ollama.CONF_MODEL: "test_model:latest", # Model still in main data
},
version=2,
@@ -768,7 +841,7 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
assert mock_config_entry.minor_version == 3
# Check that model was moved from main data to subentry
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
assert mock_config_entry.data == {CONF_URL: "http://localhost:11434"}
assert len(mock_config_entry.subentries) == 2
subentry = next(iter(mock_config_entry.subentries.values()))