1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-25 05:26:47 +00:00

Anthropic model selection from list (#156261)

This commit is contained in:
Denis Shulyaka
2025-11-15 05:16:52 +03:00
committed by GitHub
parent 275670a526
commit a06f4b6776
3 changed files with 201 additions and 2 deletions

View File

@@ -5,6 +5,7 @@ from __future__ import annotations
from functools import partial
import json
import logging
import re
from typing import Any
import anthropic
@@ -283,7 +284,11 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
vol.Optional(
CONF_CHAT_MODEL,
default=RECOMMENDED_CHAT_MODEL,
): str,
): SelectSelector(
SelectSelectorConfig(
options=await self._get_model_list(), custom_value=True
)
),
vol.Optional(
CONF_MAX_TOKENS,
default=RECOMMENDED_MAX_TOKENS,
@@ -394,6 +399,39 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
last_step=True,
)
async def _get_model_list(self) -> list[SelectOptionDict]:
"""Get list of available models."""
try:
client = await self.hass.async_add_executor_job(
partial(
anthropic.AsyncAnthropic,
api_key=self._get_entry().data[CONF_API_KEY],
)
)
models = (await client.models.list()).data
except anthropic.AnthropicError:
models = []
_LOGGER.debug("Available models: %s", models)
model_options: list[SelectOptionDict] = []
short_form = re.compile(r"[^\d]-\d$")
for model_info in models:
# Resolve alias from versioned model name:
model_alias = (
model_info.id[:-9]
if model_info.id
not in ("claude-3-haiku-20240307", "claude-3-opus-20240229")
else model_info.id
)
if short_form.search(model_alias):
model_alias += "-0"
model_options.append(
SelectOptionDict(
label=model_info.display_name,
value=model_alias,
)
)
return model_options
async def _get_location_data(self) -> dict[str, str]:
"""Get approximate location data of the user."""
location_data: dict[str, str] = {}

View File

@@ -1,11 +1,14 @@
"""Tests helpers."""
from collections.abc import AsyncGenerator, Generator, Iterable
import datetime
from unittest.mock import AsyncMock, patch
from anthropic.pagination import AsyncPage
from anthropic.types import (
Message,
MessageDeltaUsage,
ModelInfo,
RawContentBlockStartEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
@@ -123,7 +126,72 @@ async def mock_init_component(
hass: HomeAssistant, mock_config_entry: MockConfigEntry
) -> AsyncGenerator[None]:
"""Initialize integration."""
with patch("anthropic.resources.models.AsyncModels.retrieve"):
model_list = AsyncPage(
data=[
ModelInfo(
id="claude-haiku-4-5-20251001",
created_at=datetime.datetime(2025, 10, 15, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Haiku 4.5",
type="model",
),
ModelInfo(
id="claude-sonnet-4-5-20250929",
created_at=datetime.datetime(2025, 9, 29, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Sonnet 4.5",
type="model",
),
ModelInfo(
id="claude-opus-4-1-20250805",
created_at=datetime.datetime(2025, 8, 5, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Opus 4.1",
type="model",
),
ModelInfo(
id="claude-opus-4-20250514",
created_at=datetime.datetime(2025, 5, 22, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Opus 4",
type="model",
),
ModelInfo(
id="claude-sonnet-4-20250514",
created_at=datetime.datetime(2025, 5, 22, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Sonnet 4",
type="model",
),
ModelInfo(
id="claude-3-7-sonnet-20250219",
created_at=datetime.datetime(2025, 2, 24, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Sonnet 3.7",
type="model",
),
ModelInfo(
id="claude-3-5-haiku-20241022",
created_at=datetime.datetime(2024, 10, 22, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Haiku 3.5",
type="model",
),
ModelInfo(
id="claude-3-haiku-20240307",
created_at=datetime.datetime(2024, 3, 7, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Haiku 3",
type="model",
),
ModelInfo(
id="claude-3-opus-20240229",
created_at=datetime.datetime(2024, 2, 29, 0, 0, tzinfo=datetime.UTC),
display_name="Claude Opus 3",
type="model",
),
]
)
with (
patch("anthropic.resources.models.AsyncModels.retrieve"),
patch(
"anthropic.resources.models.AsyncModels.list",
new_callable=AsyncMock,
return_value=model_list,
),
):
assert await async_setup_component(hass, "anthropic", {})
await hass.async_block_till_done()
yield

View File

@@ -339,6 +339,99 @@ async def test_subentry_web_search_user_location(
}
async def test_model_list(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test fetching and processing the list of models."""
subentry = next(iter(mock_config_entry.subentries.values()))
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_id
)
# Configure initial step
options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"],
{
"prompt": "You are a helpful assistant",
"recommended": False,
},
)
assert options["type"] == FlowResultType.FORM
assert options["step_id"] == "advanced"
assert options["data_schema"].schema["chat_model"].config["options"] == [
{
"label": "Claude Haiku 4.5",
"value": "claude-haiku-4-5",
},
{
"label": "Claude Sonnet 4.5",
"value": "claude-sonnet-4-5",
},
{
"label": "Claude Opus 4.1",
"value": "claude-opus-4-1",
},
{
"label": "Claude Opus 4",
"value": "claude-opus-4-0",
},
{
"label": "Claude Sonnet 4",
"value": "claude-sonnet-4-0",
},
{
"label": "Claude Sonnet 3.7",
"value": "claude-3-7-sonnet",
},
{
"label": "Claude Haiku 3.5",
"value": "claude-3-5-haiku",
},
{
"label": "Claude Haiku 3",
"value": "claude-3-haiku-20240307",
},
{
"label": "Claude Opus 3",
"value": "claude-3-opus-20240229",
},
]
async def test_model_list_error(
hass: HomeAssistant, mock_config_entry, mock_init_component
) -> None:
"""Test exception handling during fetching the list of models."""
subentry = next(iter(mock_config_entry.subentries.values()))
options_flow = await mock_config_entry.start_subentry_reconfigure_flow(
hass, subentry.subentry_id
)
# Configure initial step
with patch(
"homeassistant.components.anthropic.config_flow.anthropic.resources.models.AsyncModels.list",
new_callable=AsyncMock,
side_effect=InternalServerError(
message=None,
response=Response(
status_code=500,
request=Request(method="POST", url=URL()),
),
body=None,
),
):
options = await hass.config_entries.subentries.async_configure(
options_flow["flow_id"],
{
"prompt": "You are a helpful assistant",
"recommended": False,
},
)
assert options["type"] == FlowResultType.FORM
assert options["step_id"] == "advanced"
assert options["data_schema"].schema["chat_model"].config["options"] == []
@pytest.mark.parametrize(
("current_options", "new_options", "expected_options"),
[