mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 12:59:34 +00:00
Allow non strict response_format structures for Cloud LLM generation (#157822)
This commit is contained in:
committed by
Franck Nijhof
parent
062366966b
commit
0d26d22986
@@ -561,7 +561,7 @@ class BaseCloudLLMEntity(Entity):
|
||||
"schema": _format_structured_output(
|
||||
structure, chat_log.llm_api
|
||||
),
|
||||
"strict": True,
|
||||
"strict": False,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components.cloud.const import AI_TASK_ENTITY_UNIQUE_ID, DOMAIN
|
||||
from homeassistant.components.cloud.entity import (
|
||||
BaseCloudLLMEntity,
|
||||
_convert_content_to_param,
|
||||
@@ -18,7 +19,8 @@ from homeassistant.components.cloud.entity import (
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import llm, selector
|
||||
from homeassistant.helpers import entity_registry as er, llm, selector
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@@ -219,3 +221,66 @@ async def test_prepare_chat_for_generation_passes_messages_through(
|
||||
|
||||
assert response["messages"] == messages
|
||||
assert response["conversation_id"] == "conversation-id"
|
||||
|
||||
|
||||
async def test_async_handle_chat_log_service_sets_structured_output_non_strict(
|
||||
hass: HomeAssistant,
|
||||
cloud: MagicMock,
|
||||
entity_registry: er.EntityRegistry,
|
||||
mock_cloud_login: None,
|
||||
) -> None:
|
||||
"""Ensure structured output requests always disable strict validation via service."""
|
||||
assert await async_setup_component(hass, DOMAIN, {})
|
||||
await hass.async_block_till_done()
|
||||
|
||||
on_start_callback = cloud.register_on_start.call_args[0][0]
|
||||
await on_start_callback()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
entity_id = entity_registry.async_get_entity_id(
|
||||
"ai_task", DOMAIN, AI_TASK_ENTITY_UNIQUE_ID
|
||||
)
|
||||
assert entity_id is not None
|
||||
|
||||
async def _empty_stream():
|
||||
return
|
||||
|
||||
async def _fake_delta_stream(
|
||||
self: conversation.ChatLog,
|
||||
agent_id: str,
|
||||
stream,
|
||||
):
|
||||
content = conversation.AssistantContent(
|
||||
agent_id=agent_id, content='{"value": "ok"}'
|
||||
)
|
||||
self.async_add_assistant_content_without_tools(content)
|
||||
yield content
|
||||
|
||||
cloud.llm.async_generate_data = AsyncMock(return_value=_empty_stream())
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.conversation.chat_log.ChatLog.async_add_delta_content_stream",
|
||||
_fake_delta_stream,
|
||||
):
|
||||
await hass.services.async_call(
|
||||
"ai_task",
|
||||
"generate_data",
|
||||
{
|
||||
"entity_id": entity_id,
|
||||
"task_name": "Device Report",
|
||||
"instructions": "Provide value.",
|
||||
"structure": {
|
||||
"value": {
|
||||
"selector": {"text": None},
|
||||
"required": True,
|
||||
}
|
||||
},
|
||||
},
|
||||
blocking=True,
|
||||
return_response=True,
|
||||
)
|
||||
|
||||
cloud.llm.async_generate_data.assert_awaited_once()
|
||||
_, kwargs = cloud.llm.async_generate_data.call_args
|
||||
|
||||
assert kwargs["response_format"]["json_schema"]["strict"] is False
|
||||
|
||||
Reference in New Issue
Block a user