mirror of
https://github.com/home-assistant/core.git
synced 2026-02-15 07:36:16 +00:00
Anthropic Structured Outputs support (#162515)
This commit is contained in:
@@ -491,22 +491,24 @@ class ConversationSubentryFlowHandler(ConfigSubentryFlow):
|
||||
"role": "user",
|
||||
"content": "Where are the following coordinates located: "
|
||||
f"({zone_home.attributes[ATTR_LATITUDE]},"
|
||||
f" {zone_home.attributes[ATTR_LONGITUDE]})? Please respond "
|
||||
"only with a JSON object using the following schema:\n"
|
||||
f"{convert(location_schema)}",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "{", # hints the model to skip any preamble
|
||||
},
|
||||
f" {zone_home.attributes[ATTR_LONGITUDE]})?",
|
||||
}
|
||||
],
|
||||
max_tokens=cast(int, DEFAULT[CONF_MAX_TOKENS]),
|
||||
output_config={
|
||||
"format": {
|
||||
"type": "json_schema",
|
||||
"schema": {
|
||||
**convert(location_schema),
|
||||
"additionalProperties": False,
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
_LOGGER.debug("Model response: %s", response.content)
|
||||
location_data = location_schema(
|
||||
json.loads(
|
||||
"{"
|
||||
+ "".join(
|
||||
"".join(
|
||||
block.text
|
||||
for block in response.content
|
||||
if isinstance(block, anthropic.types.TextBlock)
|
||||
|
||||
@@ -56,6 +56,15 @@ NON_ADAPTIVE_THINKING_MODELS = [
|
||||
"claude-3",
|
||||
]
|
||||
|
||||
UNSUPPORTED_STRUCTURED_OUTPUT_MODELS = [
|
||||
"claude-opus-4-1",
|
||||
"claude-opus-4-0",
|
||||
"claude-opus-4-20250514",
|
||||
"claude-sonnet-4-0",
|
||||
"claude-sonnet-4-20250514",
|
||||
"claude-3",
|
||||
]
|
||||
|
||||
WEB_SEARCH_UNSUPPORTED_MODELS = [
|
||||
"claude-3-haiku",
|
||||
"claude-3-opus",
|
||||
|
||||
@@ -20,6 +20,7 @@ from anthropic.types import (
|
||||
DocumentBlockParam,
|
||||
ImageBlockParam,
|
||||
InputJSONDelta,
|
||||
JSONOutputFormatParam,
|
||||
MessageDeltaUsage,
|
||||
MessageParam,
|
||||
MessageStreamEvent,
|
||||
@@ -94,6 +95,7 @@ from .const import (
|
||||
MIN_THINKING_BUDGET,
|
||||
NON_ADAPTIVE_THINKING_MODELS,
|
||||
NON_THINKING_MODELS,
|
||||
UNSUPPORTED_STRUCTURED_OUTPUT_MODELS,
|
||||
)
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
@@ -697,8 +699,25 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
)
|
||||
|
||||
if structure and structure_name:
|
||||
structure_name = slugify(structure_name)
|
||||
if model_args["thinking"]["type"] == "disabled":
|
||||
if not model.startswith(tuple(UNSUPPORTED_STRUCTURED_OUTPUT_MODELS)):
|
||||
# Native structured output for those models who support it.
|
||||
structure_name = None
|
||||
model_args.setdefault("output_config", OutputConfigParam())[
|
||||
"format"
|
||||
] = JSONOutputFormatParam(
|
||||
type="json_schema",
|
||||
schema={
|
||||
**convert(
|
||||
structure,
|
||||
custom_serializer=chat_log.llm_api.custom_serializer
|
||||
if chat_log.llm_api
|
||||
else llm.selector_serializer,
|
||||
),
|
||||
"additionalProperties": False,
|
||||
},
|
||||
)
|
||||
elif model_args["thinking"]["type"] == "disabled":
|
||||
structure_name = slugify(structure_name)
|
||||
if not tools:
|
||||
# Simplest case: no tools and no extended thinking
|
||||
# Add a tool and force its use
|
||||
@@ -718,6 +737,7 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
# force tool use or disable text responses, so we add a hint to the
|
||||
# system prompt instead. With extended thinking, the model should be
|
||||
# smart enough to use the tool.
|
||||
structure_name = slugify(structure_name)
|
||||
model_args["tool_choice"] = ToolChoiceAutoParam(
|
||||
type="auto",
|
||||
)
|
||||
@@ -725,22 +745,24 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
model_args["system"].append( # type: ignore[union-attr]
|
||||
TextBlockParam(
|
||||
type="text",
|
||||
text=f"Claude MUST use the '{structure_name}' tool to provide the final answer instead of plain text.",
|
||||
text=f"Claude MUST use the '{structure_name}' tool to provide "
|
||||
"the final answer instead of plain text.",
|
||||
)
|
||||
)
|
||||
|
||||
tools.append(
|
||||
ToolParam(
|
||||
name=structure_name,
|
||||
description="Use this tool to reply to the user",
|
||||
input_schema=convert(
|
||||
structure,
|
||||
custom_serializer=chat_log.llm_api.custom_serializer
|
||||
if chat_log.llm_api
|
||||
else llm.selector_serializer,
|
||||
),
|
||||
if structure_name:
|
||||
tools.append(
|
||||
ToolParam(
|
||||
name=structure_name,
|
||||
description="Use this tool to reply to the user",
|
||||
input_schema=convert(
|
||||
structure,
|
||||
custom_serializer=chat_log.llm_api.custom_serializer
|
||||
if chat_log.llm_api
|
||||
else llm.selector_serializer,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if tools:
|
||||
model_args["tools"] = tools
|
||||
@@ -761,7 +783,7 @@ class AnthropicBaseLLMEntity(Entity):
|
||||
_transform_stream(
|
||||
chat_log,
|
||||
stream,
|
||||
output_tool=structure_name if structure else None,
|
||||
output_tool=structure_name or None,
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
@@ -141,6 +141,22 @@ def mock_config_entry_with_web_search(
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry_with_no_structured_output(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
) -> MockConfigEntry:
|
||||
"""Mock a config entry with a model without structured outputs support."""
|
||||
for subentry in mock_config_entry.subentries.values():
|
||||
hass.config_entries.async_update_subentry(
|
||||
mock_config_entry,
|
||||
subentry,
|
||||
data={
|
||||
CONF_CHAT_MODEL: "claude-sonnet-4-0",
|
||||
},
|
||||
)
|
||||
return mock_config_entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_init_component(
|
||||
hass: HomeAssistant, mock_config_entry: MockConfigEntry
|
||||
|
||||
@@ -51,13 +51,13 @@ async def test_generate_data(
|
||||
assert result.data == "The test data"
|
||||
|
||||
|
||||
async def test_generate_structured_data(
|
||||
async def test_generate_structured_data_legacy(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_config_entry_with_no_structured_output: MockConfigEntry,
|
||||
mock_init_component,
|
||||
mock_create_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test AI Task structured data generation."""
|
||||
"""Test AI Task structured data generation with legacy method."""
|
||||
mock_create_stream.return_value = [
|
||||
create_tool_use_block(
|
||||
1,
|
||||
@@ -88,13 +88,13 @@ async def test_generate_structured_data(
|
||||
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||
|
||||
|
||||
async def test_generate_invalid_structured_data(
|
||||
async def test_generate_invalid_structured_data_legacy(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_config_entry_with_no_structured_output: MockConfigEntry,
|
||||
mock_init_component,
|
||||
mock_create_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test AI Task with invalid JSON response."""
|
||||
"""Test AI Task with invalid JSON response with legacy method."""
|
||||
mock_create_stream.return_value = [
|
||||
create_tool_use_block(
|
||||
1,
|
||||
@@ -126,6 +126,38 @@ async def test_generate_invalid_structured_data(
|
||||
)
|
||||
|
||||
|
||||
async def test_generate_structured_data(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_init_component,
|
||||
mock_create_stream: AsyncMock,
|
||||
) -> None:
|
||||
"""Test AI Task structured data generation."""
|
||||
mock_create_stream.return_value = [
|
||||
create_content_block(0, ['{"charac', 'ters": ["Mario', '", "Luigi"]}'])
|
||||
]
|
||||
|
||||
result = await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.claude_ai_task",
|
||||
instructions="Generate test data",
|
||||
structure=vol.Schema(
|
||||
{
|
||||
vol.Required("characters"): selector.selector(
|
||||
{
|
||||
"text": {
|
||||
"multiple": True,
|
||||
}
|
||||
}
|
||||
)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
assert result.data == {"characters": ["Mario", "Luigi"]}
|
||||
|
||||
|
||||
async def test_generate_data_with_attachments(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
|
||||
@@ -296,7 +296,8 @@ async def test_subentry_web_search_user_location(
|
||||
usage=types.Usage(input_tokens=100, output_tokens=100),
|
||||
content=[
|
||||
types.TextBlock(
|
||||
type="text", text='"city": "San Francisco", "region": "California"}'
|
||||
type="text",
|
||||
text='{"city": "San Francisco", "region": "California"}',
|
||||
)
|
||||
],
|
||||
),
|
||||
@@ -313,12 +314,7 @@ async def test_subentry_web_search_user_location(
|
||||
|
||||
assert (
|
||||
mock_create.call_args.kwargs["messages"][0]["content"] == "Where are the "
|
||||
"following coordinates located: (37.7749, -122.4194)? Please respond only "
|
||||
"with a JSON object using the following schema:\n"
|
||||
"{'type': 'object', 'properties': {'city': {'type': 'string', 'description': "
|
||||
"'Free text input for the city, e.g. `San Francisco`'}, 'region': {'type': "
|
||||
"'string', 'description': 'Free text input for the region, e.g. `California`'"
|
||||
"}}, 'required': []}"
|
||||
"following coordinates located: (37.7749, -122.4194)?"
|
||||
)
|
||||
assert options["type"] is FlowResultType.ABORT
|
||||
assert options["reason"] == "reconfigure_successful"
|
||||
|
||||
Reference in New Issue
Block a user