mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 08:26:41 +01:00
Add bearer token as optional setting to Ollama (#165325)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Joostlek <joostlek@outlook.com>
This commit is contained in:
@@ -10,9 +10,13 @@ import httpx
|
||||
import ollama
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigSubentry
|
||||
from homeassistant.const import CONF_URL, Platform
|
||||
from homeassistant.const import CONF_API_KEY, CONF_URL, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.exceptions import (
|
||||
ConfigEntryAuthFailed,
|
||||
ConfigEntryError,
|
||||
ConfigEntryNotReady,
|
||||
)
|
||||
from homeassistant.helpers import (
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
@@ -62,10 +66,28 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: OllamaConfigEntry) -> bool:
|
||||
"""Set up Ollama from a config entry."""
|
||||
settings = {**entry.data, **entry.options}
|
||||
client = ollama.AsyncClient(host=settings[CONF_URL], verify=get_default_context())
|
||||
api_key = settings.get(CONF_API_KEY)
|
||||
stripped_api_key = api_key.strip() if isinstance(api_key, str) else None
|
||||
client = ollama.AsyncClient(
|
||||
host=settings[CONF_URL],
|
||||
headers=(
|
||||
{"Authorization": f"Bearer {stripped_api_key}"}
|
||||
if stripped_api_key
|
||||
else None
|
||||
),
|
||||
verify=get_default_context(),
|
||||
)
|
||||
try:
|
||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||
await client.list()
|
||||
except ollama.ResponseError as err:
|
||||
if err.status_code in (401, 403):
|
||||
raise ConfigEntryAuthFailed from err
|
||||
if err.status_code >= 500 or err.status_code == 429:
|
||||
raise ConfigEntryNotReady(err) from err
|
||||
# If the response is a 4xx error other than 401 or 403, it likely means the URL is valid but not an Ollama instance,
|
||||
# so we raise ConfigEntryError to show an error in the UI, instead of ConfigEntryNotReady which would just keep retrying.
|
||||
raise ConfigEntryError(err) from err
|
||||
except (TimeoutError, httpx.ConnectError) as err:
|
||||
raise ConfigEntryNotReady(err) from err
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from homeassistant.config_entries import (
|
||||
ConfigSubentryFlow,
|
||||
SubentryFlowResult,
|
||||
)
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME, CONF_URL
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME, CONF_URL
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, llm
|
||||
from homeassistant.helpers.selector import (
|
||||
@@ -68,6 +68,17 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
vol.Required(CONF_URL): TextSelector(
|
||||
TextSelectorConfig(type=TextSelectorType.URL)
|
||||
),
|
||||
vol.Optional(CONF_API_KEY): TextSelector(
|
||||
TextSelectorConfig(type=TextSelectorType.PASSWORD)
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
STEP_REAUTH_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Optional(CONF_API_KEY): TextSelector(
|
||||
TextSelectorConfig(type=TextSelectorType.PASSWORD)
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -78,9 +89,40 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
VERSION = 3
|
||||
MINOR_VERSION = 3
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize config flow."""
|
||||
self.url: str | None = None
|
||||
async def _async_validate_connection(
|
||||
self, url: str, api_key: str | None
|
||||
) -> dict[str, str]:
|
||||
"""Validate connection and credentials against the Ollama server."""
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
try:
|
||||
client = ollama.AsyncClient(
|
||||
host=url,
|
||||
headers={"Authorization": f"Bearer {api_key}"} if api_key else None,
|
||||
verify=get_default_context(),
|
||||
)
|
||||
|
||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||
await client.list()
|
||||
|
||||
except ollama.ResponseError as err:
|
||||
if err.status_code in (401, 403):
|
||||
errors["base"] = "invalid_auth"
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Error response from Ollama server at %s: status %s, detail: %s",
|
||||
url,
|
||||
err.status_code,
|
||||
str(err),
|
||||
)
|
||||
errors["base"] = "unknown"
|
||||
except TimeoutError, httpx.ConnectError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return errors
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
@@ -92,9 +134,10 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
)
|
||||
|
||||
errors = {}
|
||||
url = user_input[CONF_URL]
|
||||
|
||||
self._async_abort_entries_match({CONF_URL: url})
|
||||
url = user_input[CONF_URL].strip()
|
||||
api_key = user_input.get(CONF_API_KEY)
|
||||
if api_key:
|
||||
api_key = api_key.strip()
|
||||
|
||||
try:
|
||||
url = cv.url(url)
|
||||
@@ -108,15 +151,8 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
try:
|
||||
client = ollama.AsyncClient(host=url, verify=get_default_context())
|
||||
async with asyncio.timeout(DEFAULT_TIMEOUT):
|
||||
await client.list()
|
||||
except TimeoutError, httpx.ConnectError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception:
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
errors["base"] = "unknown"
|
||||
self._async_abort_entries_match({CONF_URL: url})
|
||||
errors = await self._async_validate_connection(url, api_key)
|
||||
|
||||
if errors:
|
||||
return self.async_show_form(
|
||||
@@ -127,9 +163,65 @@ class OllamaConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
return self.async_create_entry(
|
||||
title=url,
|
||||
data={CONF_URL: url},
|
||||
entry_data: dict[str, str] = {CONF_URL: url}
|
||||
if api_key:
|
||||
entry_data[CONF_API_KEY] = api_key
|
||||
|
||||
return self.async_create_entry(title=url, data=entry_data)
|
||||
|
||||
async def async_step_reauth(
|
||||
self, entry_data: Mapping[str, Any]
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle reauthentication when existing credentials are invalid."""
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> ConfigFlowResult:
|
||||
"""Handle reauthentication confirmation."""
|
||||
reauth_entry = self._get_reauth_entry()
|
||||
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
data_schema=STEP_REAUTH_DATA_SCHEMA,
|
||||
)
|
||||
|
||||
api_key = user_input.get(CONF_API_KEY)
|
||||
if api_key:
|
||||
api_key = api_key.strip()
|
||||
|
||||
errors = await self._async_validate_connection(
|
||||
reauth_entry.data[CONF_URL], api_key
|
||||
)
|
||||
if errors:
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
data_schema=self.add_suggested_values_to_schema(
|
||||
STEP_REAUTH_DATA_SCHEMA, user_input
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
updated_data = {
|
||||
**reauth_entry.data,
|
||||
CONF_URL: reauth_entry.data[CONF_URL],
|
||||
}
|
||||
if api_key:
|
||||
updated_data[CONF_API_KEY] = api_key
|
||||
else:
|
||||
updated_data.pop(CONF_API_KEY, None)
|
||||
|
||||
updated_options = {
|
||||
key: value
|
||||
for key, value in reauth_entry.options.items()
|
||||
if key != CONF_API_KEY
|
||||
}
|
||||
|
||||
return self.async_update_reload_and_abort(
|
||||
reauth_entry,
|
||||
data=updated_data,
|
||||
options=updated_options,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
{
|
||||
"config": {
|
||||
"abort": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]"
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
|
||||
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
|
||||
"invalid_url": "[%key:common::config_flow::error::invalid_host%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
},
|
||||
"step": {
|
||||
"reauth_confirm": {
|
||||
"data": {
|
||||
"api_key": "[%key:common::config_flow::data::api_key%]"
|
||||
},
|
||||
"description": "The Ollama integration needs to re-authenticate with your Ollama API key.",
|
||||
"title": "[%key:common::config_flow::title::reauth%]"
|
||||
},
|
||||
"user": {
|
||||
"data": {
|
||||
"api_key": "[%key:common::config_flow::data::api_key%]",
|
||||
"url": "[%key:common::config_flow::data::url%]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
"""Tests Ollama integration."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import ollama
|
||||
from homeassistant.const import CONF_LLM_HASS_API
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.setup import async_setup_component
|
||||
@@ -22,14 +23,31 @@ def mock_config_entry_options() -> dict[str, Any]:
|
||||
return TEST_OPTIONS
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def has_token() -> bool:
|
||||
"""Fixture to indicate if the config entry has a token."""
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry_data(has_token: bool) -> dict[str, Any]:
|
||||
"""Fixture for configuration entry data."""
|
||||
res = deepcopy(TEST_USER_DATA)
|
||||
if has_token:
|
||||
res[CONF_API_KEY] = "test_token"
|
||||
return res
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(
|
||||
hass: HomeAssistant, mock_config_entry_options: dict[str, Any]
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry_options: dict[str, Any],
|
||||
mock_config_entry_data: dict[str, Any],
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain=ollama.DOMAIN,
|
||||
data=TEST_USER_DATA,
|
||||
data=mock_config_entry_data,
|
||||
version=3,
|
||||
minor_version=2,
|
||||
subentries_data=[
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
"""Test the Ollama config flow."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import ANY, AsyncMock, patch
|
||||
|
||||
from httpx import ConnectError
|
||||
from ollama import ResponseError
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import ollama
|
||||
from homeassistant.const import CONF_LLM_HASS_API, CONF_NAME
|
||||
from homeassistant.components.ollama.const import DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_USER
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME, CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
@@ -20,14 +23,14 @@ TEST_MODEL = "test_model:latest"
|
||||
async def test_form(hass: HomeAssistant) -> None:
|
||||
"""Test flow when configuring URL only."""
|
||||
# Pretend we already set up a config entry.
|
||||
hass.config.components.add(ollama.DOMAIN)
|
||||
hass.config.components.add(DOMAIN)
|
||||
MockConfigEntry(
|
||||
domain=ollama.DOMAIN,
|
||||
domain=DOMAIN,
|
||||
state=config_entries.ConfigEntryState.LOADED,
|
||||
).add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] is None
|
||||
@@ -48,18 +51,18 @@ async def test_form(hass: HomeAssistant) -> None:
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result2["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result2["data"] == {
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
}
|
||||
assert result2["data"] == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
|
||||
# No subentries created by default
|
||||
assert len(result2.get("subentries", [])) == 0
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
assert CONF_API_KEY not in result2["data"]
|
||||
|
||||
|
||||
async def test_duplicate_entry(hass: HomeAssistant) -> None:
|
||||
"""Test we abort on duplicate config entry."""
|
||||
MockConfigEntry(
|
||||
domain=ollama.DOMAIN,
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: "test_model",
|
||||
@@ -67,7 +70,7 @@ async def test_duplicate_entry(hass: HomeAssistant) -> None:
|
||||
).add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert not result["errors"]
|
||||
@@ -141,7 +144,7 @@ async def test_creating_new_conversation_subentry(
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM
|
||||
@@ -181,7 +184,7 @@ async def test_creating_conversation_subentry_not_loaded(
|
||||
await hass.config_entries.async_unload(mock_config_entry.entry_id)
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
@@ -209,7 +212,7 @@ async def test_subentry_need_download(
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM, new_flow
|
||||
@@ -271,7 +274,7 @@ async def test_subentry_download_error(
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM
|
||||
@@ -307,6 +310,130 @@ async def test_subentry_download_error(
|
||||
assert result["reason"] == "download_failed"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("init_data", "input_data", "expected_data"),
|
||||
[
|
||||
(
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "old-api-key",
|
||||
},
|
||||
{
|
||||
CONF_API_KEY: "new-api-key",
|
||||
},
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "new-api-key",
|
||||
},
|
||||
),
|
||||
(
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "old-api-key",
|
||||
},
|
||||
{
|
||||
# Reconfigure without api_key to test that it gets removed from data
|
||||
},
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_reauth_flow_success(
|
||||
hass: HomeAssistant, init_data, input_data, expected_data
|
||||
) -> None:
|
||||
"""Test successful reauthentication flow."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data=init_data,
|
||||
options={CONF_API_KEY: "stale-options-api-key"},
|
||||
version=3,
|
||||
minor_version=3,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await entry.start_reauth_flow(hass)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
):
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
input_data,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reauth_successful"
|
||||
|
||||
assert entry.data == expected_data
|
||||
assert entry.options == {}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "error"),
|
||||
[
|
||||
(ResponseError(error="Unauthorized", status_code=401), "invalid_auth"),
|
||||
(ConnectError(message="Connection failed"), "cannot_connect"),
|
||||
],
|
||||
)
|
||||
async def test_reauth_flow_errors(hass: HomeAssistant, side_effect, error) -> None:
|
||||
"""Test reauthentication flow when authentication fails."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "old-api-key",
|
||||
},
|
||||
version=3,
|
||||
minor_version=3,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
result = await entry.start_reauth_flow(hass)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
side_effect=side_effect,
|
||||
):
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_API_KEY: "other-api-key",
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
assert result["errors"] == {"base": error}
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
):
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_API_KEY: "new-api-key",
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reauth_successful"
|
||||
|
||||
assert entry.data == {
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "new-api-key",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "error"),
|
||||
[
|
||||
@@ -317,7 +444,7 @@ async def test_subentry_download_error(
|
||||
async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
|
||||
"""Test we handle errors."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
with patch(
|
||||
@@ -332,10 +459,50 @@ async def test_form_errors(hass: HomeAssistant, side_effect, error) -> None:
|
||||
assert result2["errors"] == {"base": error}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("side_effect", "error"),
|
||||
[
|
||||
(ConnectError(message=""), "cannot_connect"),
|
||||
(RuntimeError(), "unknown"),
|
||||
],
|
||||
)
|
||||
async def test_form_errors_recovery(hass: HomeAssistant, side_effect, error) -> None:
|
||||
"""Test that the user flow recovers after an error and completes successfully."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
# First attempt fails
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
side_effect=side_effect,
|
||||
):
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], {ollama.CONF_URL: "http://localhost:11434"}
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["errors"] == {"base": error}
|
||||
|
||||
# Second attempt succeeds
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient.list",
|
||||
return_value={"models": [{"model": TEST_MODEL}]},
|
||||
):
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{ollama.CONF_URL: "http://localhost:11434"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
|
||||
|
||||
async def test_form_invalid_url(hass: HomeAssistant) -> None:
|
||||
"""Test we handle invalid URL."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
ollama.DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
@@ -358,7 +525,7 @@ async def test_subentry_connection_error(
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.ABORT
|
||||
@@ -380,7 +547,7 @@ async def test_subentry_model_check_exception(
|
||||
):
|
||||
new_flow = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert new_flow["type"] is FlowResultType.FORM
|
||||
@@ -500,7 +667,7 @@ async def test_creating_ai_task_subentry(
|
||||
):
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result.get("type") is FlowResultType.FORM
|
||||
@@ -552,8 +719,167 @@ async def test_ai_task_subentry_not_loaded(
|
||||
# Don't call mock_init_component to simulate not loaded state
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "ai_task_data"),
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
|
||||
assert result.get("type") is FlowResultType.ABORT
|
||||
assert result.get("reason") == "entry_not_loaded"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("user_input", "expected_headers", "expected_data"),
|
||||
[
|
||||
(
|
||||
{CONF_URL: "http://localhost:11434", CONF_API_KEY: "my-secret-token"},
|
||||
{"Authorization": "Bearer my-secret-token"},
|
||||
{CONF_URL: "http://localhost:11434", CONF_API_KEY: "my-secret-token"},
|
||||
),
|
||||
(
|
||||
{CONF_URL: "http://localhost:11434", CONF_API_KEY: ""},
|
||||
None,
|
||||
{CONF_URL: "http://localhost:11434"},
|
||||
),
|
||||
(
|
||||
{CONF_URL: "http://localhost:11434", CONF_API_KEY: " "},
|
||||
None,
|
||||
{CONF_URL: "http://localhost:11434"},
|
||||
),
|
||||
(
|
||||
{CONF_URL: "http://localhost:11434"},
|
||||
None,
|
||||
{CONF_URL: "http://localhost:11434"},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_user_step_async_client_headers(
|
||||
hass: HomeAssistant,
|
||||
user_input: dict[str, str],
|
||||
expected_headers: dict[str, str] | None,
|
||||
expected_data: dict[str, str],
|
||||
) -> None:
|
||||
"""Test Authorization header passed to AsyncClient with/without api_key."""
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient",
|
||||
) as mock_async_client:
|
||||
mock_async_client.return_value.list = AsyncMock(return_value={"models": []})
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=user_input,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == expected_data
|
||||
mock_async_client.assert_called_with(
|
||||
host="http://localhost:11434",
|
||||
headers=expected_headers,
|
||||
verify=ANY,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "error", "error_message", "user_input"),
|
||||
[
|
||||
(
|
||||
400,
|
||||
"unknown",
|
||||
"Bad Request",
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "my-secret-token",
|
||||
},
|
||||
),
|
||||
(
|
||||
401,
|
||||
"invalid_auth",
|
||||
"Unauthorized",
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "my-secret-token",
|
||||
},
|
||||
),
|
||||
(
|
||||
403,
|
||||
"invalid_auth",
|
||||
"Unauthorized",
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
CONF_API_KEY: "my-secret-token",
|
||||
},
|
||||
),
|
||||
(
|
||||
403,
|
||||
"invalid_auth",
|
||||
"Forbidden",
|
||||
{
|
||||
CONF_URL: "http://localhost:11434",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_user_step_errors(
|
||||
hass: HomeAssistant,
|
||||
status_code: int,
|
||||
error: str,
|
||||
error_message: str,
|
||||
user_input: dict[str, str],
|
||||
) -> None:
|
||||
"""Test error handling when ollama returns HTTP 4xx."""
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient"
|
||||
) as mock_async_client:
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_async_client.return_value = mock_client_instance
|
||||
|
||||
mock_client_instance.list.side_effect = ResponseError(
|
||||
error=error_message, status_code=status_code
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input=user_input,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result.get("errors") == {"base": error}
|
||||
|
||||
|
||||
async def test_user_step_trim_url(hass: HomeAssistant) -> None:
|
||||
"""Test URL is trimmed before validation and persistence."""
|
||||
with patch(
|
||||
"homeassistant.components.ollama.config_flow.ollama.AsyncClient",
|
||||
) as mock_async_client:
|
||||
mock_async_client.return_value.list = AsyncMock(return_value={"models": []})
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={
|
||||
CONF_URL: " http://localhost:11434 ",
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["data"] == {CONF_URL: "http://localhost:11434"}
|
||||
mock_async_client.assert_called_with(
|
||||
host="http://localhost:11434",
|
||||
headers=None,
|
||||
verify=ANY,
|
||||
)
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
"""Tests for the Ollama integration."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from httpx import ConnectError
|
||||
from ollama import ResponseError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import ollama
|
||||
from homeassistant.components.ollama.const import DOMAIN
|
||||
from homeassistant.config_entries import ConfigEntryDisabler, ConfigSubentryData
|
||||
from homeassistant.config_entries import (
|
||||
ConfigEntryDisabler,
|
||||
ConfigEntryState,
|
||||
ConfigSubentryData,
|
||||
)
|
||||
from homeassistant.const import CONF_URL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr, entity_registry as er, llm
|
||||
@@ -21,7 +26,7 @@ from . import TEST_OPTIONS
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
V1_TEST_USER_DATA = {
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: "test_model:latest",
|
||||
}
|
||||
|
||||
@@ -58,6 +63,74 @@ async def test_init_error(
|
||||
assert error in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.parametrize("has_token", [True])
|
||||
async def test_init_with_api_key(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test initialization with API key - Authorization header should be set."""
|
||||
# Create entry with API key in data (version 3.0 after migration)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client:
|
||||
mock_client.return_value.list = AsyncMock(return_value={"models": []})
|
||||
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert any(
|
||||
call.kwargs["headers"] == {"Authorization": "Bearer test_token"}
|
||||
for call in mock_client.call_args_list
|
||||
)
|
||||
|
||||
|
||||
async def test_init_without_api_key(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> None:
|
||||
"""Test initialization without API key - Authorization header should not be set."""
|
||||
# Create entry without API key in data (version 3.0 after migration)
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client:
|
||||
mock_client.return_value.list = AsyncMock(return_value={"models": []})
|
||||
|
||||
assert await async_setup_component(hass, ollama.DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert all(
|
||||
call.kwargs["headers"] is None for call in mock_client.call_args_list
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("status_code", "entry_state"),
|
||||
[
|
||||
(401, ConfigEntryState.SETUP_ERROR),
|
||||
(403, ConfigEntryState.SETUP_ERROR),
|
||||
(500, ConfigEntryState.SETUP_RETRY),
|
||||
(429, ConfigEntryState.SETUP_RETRY),
|
||||
(400, ConfigEntryState.SETUP_ERROR),
|
||||
],
|
||||
)
|
||||
async def test_async_setup_entry_auth_failed_on_response_error(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
status_code: int,
|
||||
entry_state: ConfigEntryState,
|
||||
) -> None:
|
||||
"""Test async_setup_entry raises auth failed on 401/403 response."""
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
with patch("homeassistant.components.ollama.ollama.AsyncClient") as mock_client:
|
||||
mock_client.return_value.list = AsyncMock(
|
||||
side_effect=ResponseError(error="Unauthorized", status_code=status_code)
|
||||
)
|
||||
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_config_entry.state is entry_state
|
||||
|
||||
|
||||
async def test_migration_from_v1(
|
||||
hass: HomeAssistant,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
@@ -102,7 +175,7 @@ async def test_migration_from_v1(
|
||||
assert mock_config_entry.version == 3
|
||||
assert mock_config_entry.minor_version == 3
|
||||
# After migration, parent entry should only have URL
|
||||
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
assert mock_config_entry.data == {CONF_URL: "http://localhost:11434"}
|
||||
assert mock_config_entry.options == {}
|
||||
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
@@ -748,7 +821,7 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
|
||||
mock_config_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
ollama.CONF_URL: "http://localhost:11434",
|
||||
CONF_URL: "http://localhost:11434",
|
||||
ollama.CONF_MODEL: "test_model:latest", # Model still in main data
|
||||
},
|
||||
version=2,
|
||||
@@ -768,7 +841,7 @@ async def test_migration_from_v2_2(hass: HomeAssistant) -> None:
|
||||
assert mock_config_entry.minor_version == 3
|
||||
|
||||
# Check that model was moved from main data to subentry
|
||||
assert mock_config_entry.data == {ollama.CONF_URL: "http://localhost:11434"}
|
||||
assert mock_config_entry.data == {CONF_URL: "http://localhost:11434"}
|
||||
assert len(mock_config_entry.subentries) == 2
|
||||
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
|
||||
Reference in New Issue
Block a user