mirror of
https://github.com/home-assistant/core.git
synced 2026-04-02 00:20:30 +01:00
OpenRouter: Add WebSearch Support (#164293)
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Joostlek <joostlek@outlook.com>
This commit is contained in:
4
CODEOWNERS
generated
4
CODEOWNERS
generated
@@ -1232,8 +1232,8 @@ build.json @home-assistant/supervisor
|
||||
/tests/components/onvif/ @jterrace
|
||||
/homeassistant/components/open_meteo/ @frenck
|
||||
/tests/components/open_meteo/ @frenck
|
||||
/homeassistant/components/open_router/ @joostlek
|
||||
/tests/components/open_router/ @joostlek
|
||||
/homeassistant/components/open_router/ @joostlek @ab3lson
|
||||
/tests/components/open_router/ @joostlek @ab3lson
|
||||
/homeassistant/components/opendisplay/ @g4bri3lDev
|
||||
/tests/components/opendisplay/ @g4bri3lDev
|
||||
/homeassistant/components/openerz/ @misialq
|
||||
|
||||
@@ -10,7 +10,7 @@ from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady
|
||||
from homeassistant.helpers.httpx_client import get_async_client
|
||||
|
||||
from .const import LOGGER
|
||||
from .const import CONF_WEB_SEARCH, LOGGER
|
||||
|
||||
PLATFORMS = [Platform.AI_TASK, Platform.CONVERSATION]
|
||||
|
||||
@@ -56,3 +56,32 @@ async def _async_update_listener(
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: OpenRouterConfigEntry) -> bool:
|
||||
"""Unload OpenRouter."""
|
||||
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
||||
|
||||
async def async_migrate_entry(
|
||||
hass: HomeAssistant, entry: OpenRouterConfigEntry
|
||||
) -> bool:
|
||||
"""Migrate config entry."""
|
||||
LOGGER.debug("Migrating from version %s.%s", entry.version, entry.minor_version)
|
||||
|
||||
if entry.version > 1 or (entry.version == 1 and entry.minor_version > 2):
|
||||
return False
|
||||
|
||||
if entry.version == 1 and entry.minor_version < 2:
|
||||
for subentry in entry.subentries.values():
|
||||
if CONF_WEB_SEARCH in subentry.data:
|
||||
continue
|
||||
|
||||
updated_data = {**subentry.data, CONF_WEB_SEARCH: False}
|
||||
|
||||
hass.config_entries.async_update_subentry(
|
||||
entry, subentry, data=updated_data
|
||||
)
|
||||
|
||||
hass.config_entries.async_update_entry(entry, minor_version=2)
|
||||
|
||||
LOGGER.info(
|
||||
"Migration to version %s.%s successful", entry.version, entry.minor_version
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -27,6 +27,7 @@ from homeassistant.core import callback
|
||||
from homeassistant.helpers import llm
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.selector import (
|
||||
BooleanSelector,
|
||||
SelectOptionDict,
|
||||
SelectSelector,
|
||||
SelectSelectorConfig,
|
||||
@@ -34,7 +35,12 @@ from homeassistant.helpers.selector import (
|
||||
TemplateSelector,
|
||||
)
|
||||
|
||||
from .const import CONF_PROMPT, DOMAIN, RECOMMENDED_CONVERSATION_OPTIONS
|
||||
from .const import (
|
||||
CONF_PROMPT,
|
||||
CONF_WEB_SEARCH,
|
||||
DOMAIN,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,6 +49,7 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for OpenRouter."""
|
||||
|
||||
VERSION = 1
|
||||
MINOR_VERSION = 2
|
||||
|
||||
@classmethod
|
||||
@callback
|
||||
@@ -66,7 +73,7 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
user_input[CONF_API_KEY], async_get_clientsession(self.hass)
|
||||
)
|
||||
try:
|
||||
await client.get_key_data()
|
||||
key_data = await client.get_key_data()
|
||||
except OpenRouterError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception:
|
||||
@@ -74,7 +81,7 @@ class OpenRouterConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
errors["base"] = "unknown"
|
||||
else:
|
||||
return self.async_create_entry(
|
||||
title="OpenRouter",
|
||||
title=key_data.label,
|
||||
data=user_input,
|
||||
)
|
||||
return self.async_show_form(
|
||||
@@ -106,7 +113,7 @@ class OpenRouterSubentryFlowHandler(ConfigSubentryFlow):
|
||||
|
||||
|
||||
class ConversationFlowHandler(OpenRouterSubentryFlowHandler):
|
||||
"""Handle subentry flow."""
|
||||
"""Handle conversation subentry flow."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the subentry flow."""
|
||||
@@ -208,13 +215,20 @@ class ConversationFlowHandler(OpenRouterSubentryFlowHandler):
|
||||
): SelectSelector(
|
||||
SelectSelectorConfig(options=hass_apis, multiple=True)
|
||||
),
|
||||
vol.Optional(
|
||||
CONF_WEB_SEARCH,
|
||||
default=self.options.get(
|
||||
CONF_WEB_SEARCH,
|
||||
RECOMMENDED_CONVERSATION_OPTIONS[CONF_WEB_SEARCH],
|
||||
),
|
||||
): BooleanSelector(),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AITaskDataFlowHandler(OpenRouterSubentryFlowHandler):
|
||||
"""Handle subentry flow."""
|
||||
"""Handle AI task subentry flow."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the subentry flow."""
|
||||
|
||||
@@ -9,9 +9,13 @@ DOMAIN = "open_router"
|
||||
LOGGER = logging.getLogger(__package__)
|
||||
|
||||
CONF_RECOMMENDED = "recommended"
|
||||
CONF_WEB_SEARCH = "web_search"
|
||||
|
||||
RECOMMENDED_WEB_SEARCH = False
|
||||
|
||||
RECOMMENDED_CONVERSATION_OPTIONS = {
|
||||
CONF_RECOMMENDED: True,
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
CONF_PROMPT: llm.DEFAULT_INSTRUCTIONS_PROMPT,
|
||||
CONF_WEB_SEARCH: RECOMMENDED_WEB_SEARCH,
|
||||
}
|
||||
|
||||
@@ -37,9 +37,8 @@ from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.json import json_dumps
|
||||
|
||||
from . import OpenRouterConfigEntry
|
||||
from .const import DOMAIN, LOGGER
|
||||
from .const import CONF_WEB_SEARCH, DOMAIN, LOGGER
|
||||
|
||||
# Max number of back and forth with the LLM to generate a response
|
||||
MAX_TOOL_ITERATIONS = 10
|
||||
|
||||
|
||||
@@ -52,7 +51,6 @@ def _adjust_schema(schema: dict[str, Any]) -> None:
|
||||
if "required" not in schema:
|
||||
schema["required"] = []
|
||||
|
||||
# Ensure all properties are required
|
||||
for prop, prop_info in schema["properties"].items():
|
||||
_adjust_schema(prop_info)
|
||||
if prop not in schema["required"]:
|
||||
@@ -233,14 +231,20 @@ class OpenRouterEntity(Entity):
|
||||
) -> None:
|
||||
"""Generate an answer for the chat log."""
|
||||
|
||||
model = self.model
|
||||
if self.subentry.data.get(CONF_WEB_SEARCH):
|
||||
model = f"{model}:online"
|
||||
|
||||
extra_body: dict[str, Any] = {"require_parameters": True}
|
||||
|
||||
model_args = {
|
||||
"model": self.model,
|
||||
"model": model,
|
||||
"user": chat_log.conversation_id,
|
||||
"extra_headers": {
|
||||
"X-Title": "Home Assistant",
|
||||
"HTTP-Referer": "https://www.home-assistant.io/integrations/open_router",
|
||||
},
|
||||
"extra_body": {"require_parameters": True},
|
||||
"extra_body": extra_body,
|
||||
}
|
||||
|
||||
tools: list[ChatCompletionFunctionToolParam] | None = None
|
||||
@@ -296,6 +300,10 @@ class OpenRouterEntity(Entity):
|
||||
LOGGER.error("Error talking to API: %s", err)
|
||||
raise HomeAssistantError("Error talking to API") from err
|
||||
|
||||
if not result.choices:
|
||||
LOGGER.error("API returned empty choices")
|
||||
raise HomeAssistantError("API returned empty response")
|
||||
|
||||
result_message = result.choices[0].message
|
||||
|
||||
model_args["messages"].extend(
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
"domain": "open_router",
|
||||
"name": "OpenRouter",
|
||||
"after_dependencies": ["assist_pipeline", "intent"],
|
||||
"codeowners": ["@joostlek"],
|
||||
"codeowners": ["@joostlek", "@ab3lson"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["conversation"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/open_router",
|
||||
|
||||
@@ -23,19 +23,18 @@
|
||||
"abort": {
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"entry_not_loaded": "The main integration entry is not loaded. Please ensure the integration is loaded before reconfiguring.",
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
},
|
||||
"entry_type": "AI task",
|
||||
"initiate_flow": {
|
||||
"reconfigure": "Reconfigure AI task",
|
||||
"user": "Add AI task"
|
||||
},
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"model": "[%key:component::open_router::config_subentries::conversation::step::init::data::model%]"
|
||||
},
|
||||
"data_description": {
|
||||
"model": "The model to use for the AI task"
|
||||
"model": "[%key:common::generic::model%]"
|
||||
},
|
||||
"description": "Configure the AI task"
|
||||
}
|
||||
@@ -45,22 +44,27 @@
|
||||
"abort": {
|
||||
"cannot_connect": "[%key:common::config_flow::error::cannot_connect%]",
|
||||
"entry_not_loaded": "[%key:component::open_router::config_subentries::ai_task_data::abort::entry_not_loaded%]",
|
||||
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]",
|
||||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
},
|
||||
"entry_type": "Conversation agent",
|
||||
"initiate_flow": {
|
||||
"reconfigure": "Reconfigure conversation agent",
|
||||
"user": "Add conversation agent"
|
||||
},
|
||||
"step": {
|
||||
"init": {
|
||||
"data": {
|
||||
"llm_hass_api": "[%key:common::config_flow::data::llm_hass_api%]",
|
||||
"model": "Model",
|
||||
"prompt": "[%key:common::config_flow::data::prompt%]"
|
||||
"model": "[%key:common::generic::model%]",
|
||||
"prompt": "[%key:common::config_flow::data::prompt%]",
|
||||
"web_search": "Enable web search"
|
||||
},
|
||||
"data_description": {
|
||||
"llm_hass_api": "Select which tools the model can use to interact with your devices and entities.",
|
||||
"model": "The model to use for the conversation agent",
|
||||
"prompt": "Instruct how the LLM should respond. This can be a template."
|
||||
"prompt": "Instruct how the LLM should respond. This can be a template.",
|
||||
"web_search": "Allow the model to search the web for answers"
|
||||
},
|
||||
"description": "Configure the conversation agent"
|
||||
}
|
||||
|
||||
@@ -9,9 +9,13 @@ from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionMessage
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
import pytest
|
||||
from python_open_router import ModelsDataWrapper
|
||||
from python_open_router import KeyData, ModelsDataWrapper
|
||||
|
||||
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
||||
from homeassistant.components.open_router.const import (
|
||||
CONF_PROMPT,
|
||||
CONF_WEB_SEARCH,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigSubentryData
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -38,11 +42,18 @@ def enable_assist() -> bool:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_subentry_data(enable_assist: bool) -> dict[str, Any]:
|
||||
def web_search() -> bool:
|
||||
"""Mock web search setting."""
|
||||
return False
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def conversation_subentry_data(enable_assist: bool, web_search: bool) -> dict[str, Any]:
|
||||
"""Mock conversation subentry data."""
|
||||
res: dict[str, Any] = {
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "You are a helpful assistant.",
|
||||
CONF_WEB_SEARCH: web_search,
|
||||
}
|
||||
if enable_assist:
|
||||
res[CONF_LLM_HASS_API] = [llm.LLM_API_ASSIST]
|
||||
@@ -137,6 +148,13 @@ async def mock_open_router_client(hass: HomeAssistant) -> AsyncGenerator[AsyncMo
|
||||
autospec=True,
|
||||
) as mock_client:
|
||||
client = mock_client.return_value
|
||||
client.get_key_data.return_value = KeyData(
|
||||
label="Test account",
|
||||
usage=0,
|
||||
is_provisioning_key=False,
|
||||
limit_remaining=None,
|
||||
is_free_tier=True,
|
||||
)
|
||||
models = await async_load_fixture(hass, "models.json", DOMAIN)
|
||||
client.get_models.return_value = ModelsDataWrapper.from_json(models).data
|
||||
yield client
|
||||
|
||||
@@ -211,6 +211,35 @@ async def test_generate_invalid_structured_data(
|
||||
)
|
||||
|
||||
|
||||
async def test_generate_data_empty_response(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_openai_client: AsyncMock,
|
||||
) -> None:
|
||||
"""Test AI Task raises HomeAssistantError when API returns empty choices."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(
|
||||
return_value=ChatCompletion(
|
||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||
choices=[],
|
||||
created=1700000000,
|
||||
model="x-ai/grok-3",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(completion_tokens=0, prompt_tokens=8, total_tokens=8),
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HomeAssistantError, match="API returned empty response"):
|
||||
await ai_task.async_generate_data(
|
||||
hass,
|
||||
task_name="Test Task",
|
||||
entity_id="ai_task.gemini_1_5_pro",
|
||||
instructions="Generate test data",
|
||||
)
|
||||
|
||||
|
||||
async def test_generate_data_with_attachments(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
|
||||
@@ -5,7 +5,11 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from python_open_router import OpenRouterError
|
||||
|
||||
from homeassistant.components.open_router.const import CONF_PROMPT, DOMAIN
|
||||
from homeassistant.components.open_router.const import (
|
||||
CONF_PROMPT,
|
||||
CONF_WEB_SEARCH,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.config_entries import SOURCE_USER
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -35,9 +39,33 @@ async def test_full_flow(
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "Test account"
|
||||
assert result["data"] == {CONF_API_KEY: "bla"}
|
||||
|
||||
|
||||
async def test_second_account(
|
||||
hass: HomeAssistant,
|
||||
mock_open_router_client: AsyncMock,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test that a second account with a different API key can be added."""
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": SOURCE_USER}
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{CONF_API_KEY: "different_key"},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["title"] == "Test account"
|
||||
assert result["data"] == {CONF_API_KEY: "different_key"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception", "error"),
|
||||
[
|
||||
@@ -131,6 +159,7 @@ async def test_create_conversation_agent(
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_LLM_HASS_API: ["assist"],
|
||||
CONF_WEB_SEARCH: False,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -139,6 +168,7 @@ async def test_create_conversation_agent(
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_LLM_HASS_API: ["assist"],
|
||||
CONF_WEB_SEARCH: False,
|
||||
}
|
||||
|
||||
|
||||
@@ -170,6 +200,7 @@ async def test_create_conversation_agent_no_control(
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_LLM_HASS_API: [],
|
||||
CONF_WEB_SEARCH: False,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -177,6 +208,7 @@ async def test_create_conversation_agent_no_control(
|
||||
assert result["data"] == {
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_WEB_SEARCH: False,
|
||||
}
|
||||
|
||||
|
||||
@@ -263,12 +295,19 @@ async def test_reconfigure_conversation_agent(
|
||||
CONF_MODEL: "openai/gpt-4",
|
||||
CONF_PROMPT: "updated prompt",
|
||||
CONF_LLM_HASS_API: ["assist"],
|
||||
CONF_WEB_SEARCH: True,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.ABORT
|
||||
assert result["reason"] == "reconfigure_successful"
|
||||
|
||||
subentry = mock_config_entry.subentries[subentry_id]
|
||||
assert subentry.data[CONF_MODEL] == "openai/gpt-4"
|
||||
assert subentry.data[CONF_PROMPT] == "updated prompt"
|
||||
assert subentry.data[CONF_LLM_HASS_API] == ["assist"]
|
||||
assert subentry.data[CONF_WEB_SEARCH] is True
|
||||
|
||||
|
||||
async def test_reconfigure_ai_task(
|
||||
hass: HomeAssistant,
|
||||
@@ -367,6 +406,83 @@ async def test_reconfigure_ai_task_abort(
|
||||
assert result["reason"] == reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("web_search", "expected_web_search"),
|
||||
[(True, True), (False, False)],
|
||||
indirect=["web_search"],
|
||||
)
|
||||
async def test_create_conversation_agent_web_search(
|
||||
hass: HomeAssistant,
|
||||
mock_open_router_client: AsyncMock,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
web_search: bool,
|
||||
expected_web_search: bool,
|
||||
) -> None:
|
||||
"""Test creating a conversation agent with web search enabled/disabled."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
result = await hass.config_entries.subentries.async_init(
|
||||
(mock_config_entry.entry_id, "conversation"),
|
||||
context={"source": SOURCE_USER},
|
||||
)
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
# Verify web_search field is present in schema with correct default
|
||||
schema = result["data_schema"].schema
|
||||
key = next(k for k in schema if k == CONF_WEB_SEARCH)
|
||||
assert key.default() is False
|
||||
|
||||
result = await hass.config_entries.subentries.async_configure(
|
||||
result["flow_id"],
|
||||
{
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "you are an assistant",
|
||||
CONF_LLM_HASS_API: ["assist"],
|
||||
CONF_WEB_SEARCH: expected_web_search,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.CREATE_ENTRY
|
||||
assert result["data"][CONF_WEB_SEARCH] is expected_web_search
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("current_web_search", "expected_default"),
|
||||
[(True, True), (False, False)],
|
||||
)
|
||||
async def test_reconfigure_conversation_subentry_web_search_default(
|
||||
hass: HomeAssistant,
|
||||
mock_open_router_client: AsyncMock,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
current_web_search: bool,
|
||||
expected_default: bool,
|
||||
) -> None:
|
||||
"""Test web_search field default reflects existing value when reconfiguring."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
subentry = next(iter(mock_config_entry.subentries.values()))
|
||||
hass.config_entries.async_update_subentry(
|
||||
mock_config_entry,
|
||||
subentry,
|
||||
data={**subentry.data, CONF_WEB_SEARCH: current_web_search},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
result = await mock_config_entry.start_subentry_reconfigure_flow(
|
||||
hass, subentry.subentry_id
|
||||
)
|
||||
|
||||
assert result["type"] is FlowResultType.FORM
|
||||
assert result["step_id"] == "init"
|
||||
|
||||
schema = result["data_schema"].schema
|
||||
key = next(k for k in schema if k == CONF_WEB_SEARCH)
|
||||
assert key.default() is expected_default
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("current_llm_apis", "suggested_llm_apis", "expected_options"),
|
||||
[
|
||||
|
||||
@@ -79,6 +79,66 @@ async def test_default_prompt(
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("web_search", "expected_model_suffix"),
|
||||
[(True, ":online"), (False, "")],
|
||||
ids=["web_search_enabled", "web_search_disabled"],
|
||||
)
|
||||
async def test_web_search(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
web_search: bool,
|
||||
expected_model_suffix: str,
|
||||
) -> None:
|
||||
"""Test that web search adds :online suffix to model."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
mock_chat_log.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.gpt_3_5_turbo",
|
||||
)
|
||||
|
||||
call = mock_openai_client.chat.completions.create.call_args_list[0][1]
|
||||
expected_model = f"openai/gpt-3.5-turbo{expected_model_suffix}"
|
||||
assert call["model"] == expected_model
|
||||
|
||||
|
||||
async def test_empty_api_response(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_openai_client: AsyncMock,
|
||||
mock_chat_log: MockChatLog, # noqa: F811
|
||||
) -> None:
|
||||
"""Test that an empty choices response raises HomeAssistantError."""
|
||||
await setup_integration(hass, mock_config_entry)
|
||||
|
||||
mock_openai_client.chat.completions.create = AsyncMock(
|
||||
return_value=ChatCompletion(
|
||||
id="chatcmpl-1234567890ABCDEFGHIJKLMNOPQRS",
|
||||
choices=[],
|
||||
created=1700000000,
|
||||
model="gpt-3.5-turbo-0613",
|
||||
object="chat.completion",
|
||||
system_fingerprint=None,
|
||||
usage=CompletionUsage(completion_tokens=0, prompt_tokens=8, total_tokens=8),
|
||||
)
|
||||
)
|
||||
|
||||
result = await conversation.async_converse(
|
||||
hass,
|
||||
"hello",
|
||||
mock_chat_log.conversation_id,
|
||||
Context(),
|
||||
agent_id="conversation.gpt_3_5_turbo",
|
||||
)
|
||||
|
||||
assert result.response.response_type == intent.IntentResponseType.ERROR
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_assist", [True])
|
||||
async def test_function_call(
|
||||
hass: HomeAssistant,
|
||||
|
||||
136
tests/components/open_router/test_init.py
Normal file
136
tests/components/open_router/test_init.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Tests for the OpenRouter integration."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components.open_router.const import (
|
||||
CONF_PROMPT,
|
||||
CONF_WEB_SEARCH,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntryState, ConfigSubentryData
|
||||
from homeassistant.const import CONF_API_KEY, CONF_LLM_HASS_API, CONF_MODEL
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import llm
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_migrate_entry_from_v1_1_to_v1_2(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test migration from version 1.1 to 1.2."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_API_KEY: "bla",
|
||||
},
|
||||
version=1,
|
||||
minor_version=1,
|
||||
subentries_data=[
|
||||
ConfigSubentryData(
|
||||
data={
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "You are a helpful assistant.",
|
||||
CONF_LLM_HASS_API: [llm.LLM_API_ASSIST],
|
||||
},
|
||||
subentry_id="conversation_subentry",
|
||||
subentry_type="conversation",
|
||||
title="GPT-3.5 Turbo",
|
||||
unique_id=None,
|
||||
),
|
||||
ConfigSubentryData(
|
||||
data={
|
||||
CONF_MODEL: "openai/gpt-4",
|
||||
},
|
||||
subentry_id="ai_task_subentry",
|
||||
subentry_type="ai_task_data",
|
||||
title="GPT-4",
|
||||
unique_id=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.open_router.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.version == 1
|
||||
assert entry.minor_version == 2
|
||||
|
||||
conversation_subentry = entry.subentries["conversation_subentry"]
|
||||
assert conversation_subentry.data[CONF_MODEL] == "openai/gpt-3.5-turbo"
|
||||
assert conversation_subentry.data[CONF_PROMPT] == "You are a helpful assistant."
|
||||
assert conversation_subentry.data[CONF_LLM_HASS_API] == [llm.LLM_API_ASSIST]
|
||||
assert conversation_subentry.data[CONF_WEB_SEARCH] is False
|
||||
|
||||
ai_task_subentry = entry.subentries["ai_task_subentry"]
|
||||
assert ai_task_subentry.data[CONF_MODEL] == "openai/gpt-4"
|
||||
assert ai_task_subentry.data[CONF_WEB_SEARCH] is False
|
||||
|
||||
|
||||
async def test_migrate_entry_already_migrated(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test migration is skipped when already on version 1.2."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_API_KEY: "bla",
|
||||
},
|
||||
version=1,
|
||||
minor_version=1,
|
||||
subentries_data=[
|
||||
ConfigSubentryData(
|
||||
data={
|
||||
CONF_MODEL: "openai/gpt-3.5-turbo",
|
||||
CONF_PROMPT: "You are a helpful assistant.",
|
||||
CONF_WEB_SEARCH: True,
|
||||
},
|
||||
subentry_id="conversation_subentry",
|
||||
subentry_type="conversation",
|
||||
title="GPT-3.5 Turbo",
|
||||
unique_id=None,
|
||||
),
|
||||
],
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.open_router.async_setup_entry",
|
||||
return_value=True,
|
||||
):
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.version == 1
|
||||
assert entry.minor_version == 2
|
||||
|
||||
conversation_subentry = entry.subentries["conversation_subentry"]
|
||||
assert conversation_subentry.data[CONF_MODEL] == "openai/gpt-3.5-turbo"
|
||||
assert conversation_subentry.data[CONF_WEB_SEARCH] is True
|
||||
|
||||
|
||||
async def test_migrate_entry_from_future_version_fails(
|
||||
hass: HomeAssistant,
|
||||
) -> None:
|
||||
"""Test migration fails for future versions."""
|
||||
entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
data={
|
||||
CONF_API_KEY: "bla",
|
||||
},
|
||||
version=100,
|
||||
minor_version=99,
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.version == 100
|
||||
assert entry.minor_version == 99
|
||||
assert entry.state is ConfigEntryState.MIGRATION_ERROR
|
||||
Reference in New Issue
Block a user