1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-20 10:59:24 +00:00

Anthropic: consolidate recommended values in a dict (#156787)

This commit is contained in:
Denis Shulyaka
2025-11-25 19:08:55 +03:00
committed by GitHub
parent d7ad0cba94
commit 252dbb706f
5 changed files with 60 additions and 68 deletions

View File

@@ -17,13 +17,7 @@ from homeassistant.helpers import (
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import ( from .const import CONF_CHAT_MODEL, DEFAULT, DEFAULT_CONVERSATION_NAME, DOMAIN, LOGGER
CONF_CHAT_MODEL,
DEFAULT_CONVERSATION_NAME,
DOMAIN,
LOGGER,
RECOMMENDED_CHAT_MODEL,
)
PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION) PLATFORMS = (Platform.AI_TASK, Platform.CONVERSATION)
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
@@ -46,9 +40,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: AnthropicConfigEntry) ->
# Use model from first conversation subentry for validation # Use model from first conversation subentry for validation
subentries = list(entry.subentries.values()) subentries = list(entry.subentries.values())
if subentries: if subentries:
model_id = subentries[0].data.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) model_id = subentries[0].data.get(CONF_CHAT_MODEL, DEFAULT[CONF_CHAT_MODEL])
else: else:
model_id = RECOMMENDED_CHAT_MODEL model_id = DEFAULT[CONF_CHAT_MODEL]
model = await client.models.retrieve(model_id=model_id, timeout=10.0) model = await client.models.retrieve(model_id=model_id, timeout=10.0)
LOGGER.debug("Anthropic model: %s", model.display_name) LOGGER.debug("Anthropic model: %s", model.display_name)
except anthropic.AuthenticationError as err: except anthropic.AuthenticationError as err:

View File

@@ -6,7 +6,7 @@ from functools import partial
import json import json
import logging import logging
import re import re
from typing import Any from typing import Any, cast
import anthropic import anthropic
import voluptuous as vol import voluptuous as vol
@@ -54,17 +54,11 @@ from .const import (
CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION, CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT,
DEFAULT_AI_TASK_NAME, DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
NON_THINKING_MODELS, NON_THINKING_MODELS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_THINKING_BUDGET,
RECOMMENDED_WEB_SEARCH,
RECOMMENDED_WEB_SEARCH_MAX_USES,
RECOMMENDED_WEB_SEARCH_USER_LOCATION,
WEB_SEARCH_UNSUPPORTED_MODELS, WEB_SEARCH_UNSUPPORTED_MODELS,
) )
@@ -76,13 +70,13 @@ STEP_USER_DATA_SCHEMA = vol.Schema(
} }
) )
RECOMMENDED_CONVERSATION_OPTIONS = { DEFAULT_CONVERSATION_OPTIONS = {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST], CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT, CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
} }
RECOMMENDED_AI_TASK_OPTIONS = { DEFAULT_AI_TASK_OPTIONS = {
CONF_RECOMMENDED: True, CONF_RECOMMENDED: True,
} }
@@ -136,13 +130,13 @@ class AnthropicConfigFlow(ConfigFlow, domain=DOMAIN):
subentries=[ subentries=[
{ {
"subentry_type": "conversation", "subentry_type": "conversation",
"data": RECOMMENDED_CONVERSATION_OPTIONS, "data": DEFAULT_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME, "title": DEFAULT_CONVERSATION_NAME,
"unique_id": None, "unique_id": None,
}, },
{ {
"subentry_type": "ai_task_data", "subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS, "data": DEFAULT_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME, "title": DEFAULT_AI_TASK_NAME,
"unique_id": None, "unique_id": None,
}, },
@@ -180,9 +174,9 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
) -> SubentryFlowResult: ) -> SubentryFlowResult:
"""Add a subentry.""" """Add a subentry."""
if self._subentry_type == "ai_task_data": if self._subentry_type == "ai_task_data":
self.options = RECOMMENDED_AI_TASK_OPTIONS.copy() self.options = DEFAULT_AI_TASK_OPTIONS.copy()
else: else:
self.options = RECOMMENDED_CONVERSATION_OPTIONS.copy() self.options = DEFAULT_CONVERSATION_OPTIONS.copy()
return await self.async_step_init() return await self.async_step_init()
async def async_step_reconfigure( async def async_step_reconfigure(
@@ -283,7 +277,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
step_schema: VolDictType = { step_schema: VolDictType = {
vol.Optional( vol.Optional(
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
default=RECOMMENDED_CHAT_MODEL, default=DEFAULT[CONF_CHAT_MODEL],
): SelectSelector( ): SelectSelector(
SelectSelectorConfig( SelectSelectorConfig(
options=await self._get_model_list(), custom_value=True options=await self._get_model_list(), custom_value=True
@@ -291,11 +285,11 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
), ),
vol.Optional( vol.Optional(
CONF_MAX_TOKENS, CONF_MAX_TOKENS,
default=RECOMMENDED_MAX_TOKENS, default=DEFAULT[CONF_MAX_TOKENS],
): int, ): int,
vol.Optional( vol.Optional(
CONF_TEMPERATURE, CONF_TEMPERATURE,
default=RECOMMENDED_TEMPERATURE, default=DEFAULT[CONF_TEMPERATURE],
): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)), ): NumberSelector(NumberSelectorConfig(min=0, max=1, step=0.05)),
} }
@@ -325,12 +319,14 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
if not model.startswith(tuple(NON_THINKING_MODELS)): if not model.startswith(tuple(NON_THINKING_MODELS)):
step_schema[ step_schema[
vol.Optional(CONF_THINKING_BUDGET, default=RECOMMENDED_THINKING_BUDGET) vol.Optional(
CONF_THINKING_BUDGET, default=DEFAULT[CONF_THINKING_BUDGET]
)
] = vol.All( ] = vol.All(
NumberSelector( NumberSelector(
NumberSelectorConfig( NumberSelectorConfig(
min=0, min=0,
max=self.options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), max=self.options.get(CONF_MAX_TOKENS, DEFAULT[CONF_MAX_TOKENS]),
) )
), ),
vol.Coerce(int), vol.Coerce(int),
@@ -343,15 +339,15 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
{ {
vol.Optional( vol.Optional(
CONF_WEB_SEARCH, CONF_WEB_SEARCH,
default=RECOMMENDED_WEB_SEARCH, default=DEFAULT[CONF_WEB_SEARCH],
): bool, ): bool,
vol.Optional( vol.Optional(
CONF_WEB_SEARCH_MAX_USES, CONF_WEB_SEARCH_MAX_USES,
default=RECOMMENDED_WEB_SEARCH_MAX_USES, default=DEFAULT[CONF_WEB_SEARCH_MAX_USES],
): int, ): int,
vol.Optional( vol.Optional(
CONF_WEB_SEARCH_USER_LOCATION, CONF_WEB_SEARCH_USER_LOCATION,
default=RECOMMENDED_WEB_SEARCH_USER_LOCATION, default=DEFAULT[CONF_WEB_SEARCH_USER_LOCATION],
): bool, ): bool,
} }
) )
@@ -369,9 +365,10 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
user_input = {} user_input = {}
if user_input is not None: if user_input is not None:
if user_input.get(CONF_WEB_SEARCH, RECOMMENDED_WEB_SEARCH) and not errors: if user_input.get(CONF_WEB_SEARCH, DEFAULT[CONF_WEB_SEARCH]) and not errors:
if user_input.get( if user_input.get(
CONF_WEB_SEARCH_USER_LOCATION, RECOMMENDED_WEB_SEARCH_USER_LOCATION CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT[CONF_WEB_SEARCH_USER_LOCATION],
): ):
user_input.update(await self._get_location_data()) user_input.update(await self._get_location_data())
@@ -456,7 +453,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
} }
) )
response = await client.messages.create( response = await client.messages.create(
model=RECOMMENDED_CHAT_MODEL, model=cast(str, DEFAULT[CONF_CHAT_MODEL]),
messages=[ messages=[
{ {
"role": "user", "role": "user",
@@ -471,7 +468,7 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
"content": "{", # hints the model to skip any preamble "content": "{", # hints the model to skip any preamble
}, },
], ],
max_tokens=RECOMMENDED_MAX_TOKENS, max_tokens=cast(int, DEFAULT[CONF_MAX_TOKENS]),
) )
_LOGGER.debug("Model response: %s", response.content) _LOGGER.debug("Model response: %s", response.content)
location_data = location_schema( location_data = location_schema(

View File

@@ -11,25 +11,29 @@ DEFAULT_AI_TASK_NAME = "Claude AI Task"
CONF_RECOMMENDED = "recommended" CONF_RECOMMENDED = "recommended"
CONF_PROMPT = "prompt" CONF_PROMPT = "prompt"
CONF_CHAT_MODEL = "chat_model" CONF_CHAT_MODEL = "chat_model"
RECOMMENDED_CHAT_MODEL = "claude-3-5-haiku-latest"
CONF_MAX_TOKENS = "max_tokens" CONF_MAX_TOKENS = "max_tokens"
RECOMMENDED_MAX_TOKENS = 3000
CONF_TEMPERATURE = "temperature" CONF_TEMPERATURE = "temperature"
RECOMMENDED_TEMPERATURE = 1.0
CONF_THINKING_BUDGET = "thinking_budget" CONF_THINKING_BUDGET = "thinking_budget"
RECOMMENDED_THINKING_BUDGET = 0
MIN_THINKING_BUDGET = 1024
CONF_WEB_SEARCH = "web_search" CONF_WEB_SEARCH = "web_search"
RECOMMENDED_WEB_SEARCH = False
CONF_WEB_SEARCH_USER_LOCATION = "user_location" CONF_WEB_SEARCH_USER_LOCATION = "user_location"
RECOMMENDED_WEB_SEARCH_USER_LOCATION = False
CONF_WEB_SEARCH_MAX_USES = "web_search_max_uses" CONF_WEB_SEARCH_MAX_USES = "web_search_max_uses"
RECOMMENDED_WEB_SEARCH_MAX_USES = 5
CONF_WEB_SEARCH_CITY = "city" CONF_WEB_SEARCH_CITY = "city"
CONF_WEB_SEARCH_REGION = "region" CONF_WEB_SEARCH_REGION = "region"
CONF_WEB_SEARCH_COUNTRY = "country" CONF_WEB_SEARCH_COUNTRY = "country"
CONF_WEB_SEARCH_TIMEZONE = "timezone" CONF_WEB_SEARCH_TIMEZONE = "timezone"
DEFAULT = {
CONF_CHAT_MODEL: "claude-3-5-haiku-latest",
CONF_MAX_TOKENS: 3000,
CONF_TEMPERATURE: 1.0,
CONF_THINKING_BUDGET: 0,
CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_USER_LOCATION: False,
CONF_WEB_SEARCH_MAX_USES: 5,
}
MIN_THINKING_BUDGET = 1024
NON_THINKING_MODELS = [ NON_THINKING_MODELS = [
"claude-3-5", # Both sonnet and haiku "claude-3-5", # Both sonnet and haiku
"claude-3-opus", "claude-3-opus",

View File

@@ -84,14 +84,11 @@ from .const import (
CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION, CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT,
DOMAIN, DOMAIN,
LOGGER, LOGGER,
MIN_THINKING_BUDGET, MIN_THINKING_BUDGET,
NON_THINKING_MODELS, NON_THINKING_MODELS,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_TEMPERATURE,
RECOMMENDED_THINKING_BUDGET,
) )
# Max number of back and forth with the LLM to generate a response # Max number of back and forth with the LLM to generate a response
@@ -604,17 +601,19 @@ class AnthropicBaseLLMEntity(Entity):
raise TypeError("First message must be a system message") raise TypeError("First message must be a system message")
messages = _convert_content(chat_log.content[1:]) messages = _convert_content(chat_log.content[1:])
model = options.get(CONF_CHAT_MODEL, RECOMMENDED_CHAT_MODEL) model = options.get(CONF_CHAT_MODEL, DEFAULT[CONF_CHAT_MODEL])
model_args = MessageCreateParamsStreaming( model_args = MessageCreateParamsStreaming(
model=model, model=model,
messages=messages, messages=messages,
max_tokens=options.get(CONF_MAX_TOKENS, RECOMMENDED_MAX_TOKENS), max_tokens=options.get(CONF_MAX_TOKENS, DEFAULT[CONF_MAX_TOKENS]),
system=system.content, system=system.content,
stream=True, stream=True,
) )
thinking_budget = options.get(CONF_THINKING_BUDGET, RECOMMENDED_THINKING_BUDGET) thinking_budget = options.get(
CONF_THINKING_BUDGET, DEFAULT[CONF_THINKING_BUDGET]
)
if ( if (
not model.startswith(tuple(NON_THINKING_MODELS)) not model.startswith(tuple(NON_THINKING_MODELS))
and thinking_budget >= MIN_THINKING_BUDGET and thinking_budget >= MIN_THINKING_BUDGET
@@ -625,7 +624,7 @@ class AnthropicBaseLLMEntity(Entity):
else: else:
model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled") model_args["thinking"] = ThinkingConfigDisabledParam(type="disabled")
model_args["temperature"] = options.get( model_args["temperature"] = options.get(
CONF_TEMPERATURE, RECOMMENDED_TEMPERATURE CONF_TEMPERATURE, DEFAULT[CONF_TEMPERATURE]
) )
tools: list[ToolUnionParam] = [] tools: list[ToolUnionParam] = []

View File

@@ -16,8 +16,8 @@ import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.anthropic.config_flow import ( from homeassistant.components.anthropic.config_flow import (
RECOMMENDED_AI_TASK_OPTIONS, DEFAULT_AI_TASK_OPTIONS,
RECOMMENDED_CONVERSATION_OPTIONS, DEFAULT_CONVERSATION_OPTIONS,
) )
from homeassistant.components.anthropic.const import ( from homeassistant.components.anthropic.const import (
CONF_CHAT_MODEL, CONF_CHAT_MODEL,
@@ -33,12 +33,10 @@ from homeassistant.components.anthropic.const import (
CONF_WEB_SEARCH_REGION, CONF_WEB_SEARCH_REGION,
CONF_WEB_SEARCH_TIMEZONE, CONF_WEB_SEARCH_TIMEZONE,
CONF_WEB_SEARCH_USER_LOCATION, CONF_WEB_SEARCH_USER_LOCATION,
DEFAULT,
DEFAULT_AI_TASK_NAME, DEFAULT_AI_TASK_NAME,
DEFAULT_CONVERSATION_NAME, DEFAULT_CONVERSATION_NAME,
DOMAIN, DOMAIN,
RECOMMENDED_CHAT_MODEL,
RECOMMENDED_MAX_TOKENS,
RECOMMENDED_THINKING_BUDGET,
) )
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_NAME
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -87,13 +85,13 @@ async def test_form(hass: HomeAssistant) -> None:
assert result2["subentries"] == [ assert result2["subentries"] == [
{ {
"subentry_type": "conversation", "subentry_type": "conversation",
"data": RECOMMENDED_CONVERSATION_OPTIONS, "data": DEFAULT_CONVERSATION_OPTIONS,
"title": DEFAULT_CONVERSATION_NAME, "title": DEFAULT_CONVERSATION_NAME,
"unique_id": None, "unique_id": None,
}, },
{ {
"subentry_type": "ai_task_data", "subentry_type": "ai_task_data",
"data": RECOMMENDED_AI_TASK_OPTIONS, "data": DEFAULT_AI_TASK_OPTIONS,
"title": DEFAULT_AI_TASK_NAME, "title": DEFAULT_AI_TASK_NAME,
"unique_id": None, "unique_id": None,
}, },
@@ -144,13 +142,13 @@ async def test_creating_conversation_subentry(
result2 = await hass.config_entries.subentries.async_configure( result2 = await hass.config_entries.subentries.async_configure(
result["flow_id"], result["flow_id"],
{CONF_NAME: "Mock name", **RECOMMENDED_CONVERSATION_OPTIONS}, {CONF_NAME: "Mock name", **DEFAULT_CONVERSATION_OPTIONS},
) )
assert result2["type"] is FlowResultType.CREATE_ENTRY assert result2["type"] is FlowResultType.CREATE_ENTRY
assert result2["title"] == "Mock name" assert result2["title"] == "Mock name"
processed_options = RECOMMENDED_CONVERSATION_OPTIONS.copy() processed_options = DEFAULT_CONVERSATION_OPTIONS.copy()
processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip() processed_options[CONF_PROMPT] = processed_options[CONF_PROMPT].strip()
assert result2["data"] == processed_options assert result2["data"] == processed_options
@@ -475,7 +473,7 @@ async def test_model_list_error(
CONF_PROMPT: "Speak like a pirate", CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 1.0, CONF_TEMPERATURE: 1.0,
CONF_CHAT_MODEL: "claude-3-opus", CONF_CHAT_MODEL: "claude-3-opus",
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, CONF_MAX_TOKENS: DEFAULT[CONF_MAX_TOKENS],
}, },
), ),
( # Model with web search options ( # Model with web search options
@@ -512,7 +510,7 @@ async def test_model_list_error(
CONF_PROMPT: "Speak like a pirate", CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 1.0, CONF_TEMPERATURE: 1.0,
CONF_CHAT_MODEL: "claude-3-5-haiku-latest", CONF_CHAT_MODEL: "claude-3-5-haiku-latest",
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, CONF_MAX_TOKENS: DEFAULT[CONF_MAX_TOKENS],
CONF_WEB_SEARCH: False, CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_MAX_USES: 10, CONF_WEB_SEARCH_MAX_USES: 10,
CONF_WEB_SEARCH_USER_LOCATION: False, CONF_WEB_SEARCH_USER_LOCATION: False,
@@ -550,7 +548,7 @@ async def test_model_list_error(
CONF_PROMPT: "Speak like a pirate", CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 1.0, CONF_TEMPERATURE: 1.0,
CONF_CHAT_MODEL: "claude-sonnet-4-5", CONF_CHAT_MODEL: "claude-sonnet-4-5",
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, CONF_MAX_TOKENS: DEFAULT[CONF_MAX_TOKENS],
CONF_THINKING_BUDGET: 2048, CONF_THINKING_BUDGET: 2048,
CONF_WEB_SEARCH: False, CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_MAX_USES: 10, CONF_WEB_SEARCH_MAX_USES: 10,
@@ -577,8 +575,8 @@ async def test_model_list_error(
CONF_RECOMMENDED: False, CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate", CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3, CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, CONF_CHAT_MODEL: DEFAULT[CONF_CHAT_MODEL],
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, CONF_MAX_TOKENS: DEFAULT[CONF_MAX_TOKENS],
CONF_WEB_SEARCH: False, CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_MAX_USES: 5, CONF_WEB_SEARCH_MAX_USES: 5,
CONF_WEB_SEARCH_USER_LOCATION: False, CONF_WEB_SEARCH_USER_LOCATION: False,
@@ -589,9 +587,9 @@ async def test_model_list_error(
CONF_RECOMMENDED: False, CONF_RECOMMENDED: False,
CONF_PROMPT: "Speak like a pirate", CONF_PROMPT: "Speak like a pirate",
CONF_TEMPERATURE: 0.3, CONF_TEMPERATURE: 0.3,
CONF_CHAT_MODEL: RECOMMENDED_CHAT_MODEL, CONF_CHAT_MODEL: DEFAULT[CONF_CHAT_MODEL],
CONF_MAX_TOKENS: RECOMMENDED_MAX_TOKENS, CONF_MAX_TOKENS: DEFAULT[CONF_MAX_TOKENS],
CONF_THINKING_BUDGET: RECOMMENDED_THINKING_BUDGET, CONF_THINKING_BUDGET: DEFAULT[CONF_THINKING_BUDGET],
CONF_WEB_SEARCH: False, CONF_WEB_SEARCH: False,
CONF_WEB_SEARCH_MAX_USES: 5, CONF_WEB_SEARCH_MAX_USES: 5,
CONF_WEB_SEARCH_USER_LOCATION: False, CONF_WEB_SEARCH_USER_LOCATION: False,