Mirror of https://github.com/home-assistant/core.git, synced 2025-12-24 21:06:19 +00:00
Add home assistant cloud conversation (#157090)
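This change wires the existing Home Assistant Cloud LLM pipeline into the conversation platform: Platform.CONVERSATION is registered for the cloud integration, a new CloudConversationEntity (built on BaseCloudLLMEntity) handles conversation turns via the Assist LLM API, the entity helpers are reworked so chat-log content is converted once by _convert_content_to_param and the resulting messages buffer is reused and extended across streamed responses, and tests cover the new entity and the updated helpers.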
homeassistant/components/cloud/const.py
@@ -80,6 +80,7 @@ DEFAULT_MODE = MODE_PROD
 PLATFORMS = [
     Platform.AI_TASK,
     Platform.BINARY_SENSOR,
+    Platform.CONVERSATION,
     Platform.STT,
     Platform.TTS,
 ]
@@ -92,6 +92,7 @@ DISPATCHER_REMOTE_UPDATE: SignalType[Any] = SignalType("cloud_remote_update")
 STT_ENTITY_UNIQUE_ID = "cloud-speech-to-text"
 TTS_ENTITY_UNIQUE_ID = "cloud-text-to-speech"
 AI_TASK_ENTITY_UNIQUE_ID = "cloud-ai-task"
+CONVERSATION_ENTITY_UNIQUE_ID = "cloud-conversation-agent"

 LOGIN_MFA_TIMEOUT = 60
homeassistant/components/cloud/conversation.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Conversation support for Home Assistant Cloud."""

from __future__ import annotations

from typing import Literal

from hass_nabucasa.llm import LLMError

from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import MATCH_ALL
from homeassistant.core import HomeAssistant
from homeassistant.helpers import llm
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback

from .const import CONVERSATION_ENTITY_UNIQUE_ID, DATA_CLOUD, DOMAIN
from .entity import BaseCloudLLMEntity


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up the Home Assistant Cloud conversation entity."""
    cloud = hass.data[DATA_CLOUD]
    try:
        await cloud.llm.async_ensure_token()
    except LLMError:
        return

    async_add_entities([CloudConversationEntity(cloud, config_entry)])


class CloudConversationEntity(
    conversation.ConversationEntity,
    BaseCloudLLMEntity,
):
    """Home Assistant Cloud conversation agent."""

    _attr_has_entity_name = True
    _attr_name = "Home Assistant Cloud"
    _attr_translation_key = "cloud_conversation"
    _attr_unique_id = CONVERSATION_ENTITY_UNIQUE_ID
    _attr_supported_features = conversation.ConversationEntityFeature.CONTROL

    @property
    def available(self) -> bool:
        """Return if the entity is available."""
        return self._cloud.is_logged_in and self._cloud.valid_subscription

    @property
    def supported_languages(self) -> list[str] | Literal["*"]:
        """Return a list of supported languages."""
        return MATCH_ALL

    async def _async_handle_message(
        self,
        user_input: conversation.ConversationInput,
        chat_log: conversation.ChatLog,
    ) -> conversation.ConversationResult:
        """Process a user input."""
        try:
            await chat_log.async_provide_llm_data(
                user_input.as_llm_context(DOMAIN),
                llm.LLM_API_ASSIST,
                None,
                user_input.extra_system_prompt,
            )
        except conversation.ConverseError as err:
            return err.as_conversation_result()

        await self._async_handle_chat_log("conversation", chat_log)

        return conversation.async_get_result_from_chat_log(user_input, chat_log)
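For context, a minimal sketch (not part of the PR) of exercising the new agent through the conversation component's public async_converse helper; the agent_id assumes the default entity_id the entity receives:

    from homeassistant.components import conversation
    from homeassistant.core import Context, HomeAssistant


    async def ask_cloud_agent(hass: HomeAssistant, text: str) -> str:
        """Send one utterance to the cloud agent and return its spoken reply."""
        result = await conversation.async_converse(
            hass,
            text,
            None,  # conversation_id: None starts a fresh conversation
            Context(),
            agent_id="conversation.home_assistant_cloud",  # assumed default entity_id
        )
        return result.response.speech["plain"]["speech"]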
homeassistant/components/cloud/entity.py
@@ -1,7 +1,7 @@
 """Helpers for cloud LLM chat handling."""

 import base64
-from collections.abc import AsyncGenerator, Callable
+from collections.abc import AsyncGenerator, Callable, Iterable
 from enum import Enum
 import json
 import logging
@@ -16,13 +16,22 @@ from hass_nabucasa.llm import (
     LLMResponseError,
     LLMServiceError,
 )
-from litellm import ResponseFunctionToolCall, ResponsesAPIStreamEvents
+from litellm import (
+    ResponseFunctionToolCall,
+    ResponseInputParam,
+    ResponsesAPIStreamEvents,
+)
 from openai.types.responses import (
     FunctionToolParam,
     ResponseInputItemParam,
+    ResponseReasoningItem,
     ToolParam,
     WebSearchToolParam,
 )
+from openai.types.responses.response_input_param import (
+    ImageGenerationCall as ImageGenerationCallParam,
+)
+from openai.types.responses.response_output_item import ImageGenerationCall
 import voluptuous as vol
 from voluptuous_openapi import convert

@@ -50,34 +59,97 @@ class ResponseItemType(str, Enum):
     IMAGE = "image"


-def _convert_content_to_chat_message(
-    content: conversation.Content,
-) -> dict[str, Any] | None:
-    """Convert ChatLog content to a responses message."""
-    if content.role not in ("user", "system", "tool_result", "assistant"):
-        return None
+def _convert_content_to_param(
+    chat_content: Iterable[conversation.Content],
+) -> ResponseInputParam:
+    """Convert any native chat message for this agent to the native format."""
+    messages: ResponseInputParam = []
+    reasoning_summary: list[str] = []
+    web_search_calls: dict[str, dict[str, Any]] = {}

-    text_content = cast(
-        conversation.SystemContent
-        | conversation.UserContent
-        | conversation.AssistantContent,
-        content,
-    )
+    for content in chat_content:
+        if isinstance(content, conversation.ToolResultContent):
+            if (
+                content.tool_name == "web_search_call"
+                and content.tool_call_id in web_search_calls
+            ):
+                web_search_call = web_search_calls.pop(content.tool_call_id)
+                web_search_call["status"] = content.tool_result.get(
+                    "status", "completed"
+                )
+                messages.append(cast("ResponseInputItemParam", web_search_call))
+            else:
+                messages.append(
+                    {
+                        "type": "function_call_output",
+                        "call_id": content.tool_call_id,
+                        "output": json.dumps(content.tool_result),
+                    }
+                )
+            continue

-    if not text_content.content:
-        return None
+        if content.content:
+            role: Literal["user", "assistant", "system", "developer"] = content.role
+            if role == "system":
+                role = "developer"
+            messages.append(
+                {"type": "message", "role": role, "content": content.content}
+            )

-    content_type = "output_text" if text_content.role == "assistant" else "input_text"
+        if isinstance(content, conversation.AssistantContent):
+            if content.tool_calls:
+                for tool_call in content.tool_calls:
+                    if (
+                        tool_call.external
+                        and tool_call.tool_name == "web_search_call"
+                        and "action" in tool_call.tool_args
+                    ):
+                        web_search_calls[tool_call.id] = {
+                            "type": "web_search_call",
+                            "id": tool_call.id,
+                            "action": tool_call.tool_args["action"],
+                            "status": "completed",
+                        }
+                    else:
+                        messages.append(
+                            {
+                                "type": "function_call",
+                                "name": tool_call.tool_name,
+                                "arguments": json.dumps(tool_call.tool_args),
+                                "call_id": tool_call.id,
+                            }
+                        )

-    return {
-        "role": text_content.role,
-        "content": [
-            {
-                "type": content_type,
-                "text": text_content.content,
-            }
-        ],
-    }
+            if content.thinking_content:
+                reasoning_summary.append(content.thinking_content)
+
+            if isinstance(content.native, ResponseReasoningItem):
+                messages.append(
+                    {
+                        "type": "reasoning",
+                        "id": content.native.id,
+                        "summary": (
+                            [
+                                {
+                                    "type": "summary_text",
+                                    "text": summary,
+                                }
+                                for summary in reasoning_summary
+                            ]
+                            if content.thinking_content
+                            else []
+                        ),
+                        "encrypted_content": content.native.encrypted_content,
+                    }
+                )
+                reasoning_summary = []
+
+            elif isinstance(content.native, ImageGenerationCall):
+                messages.append(
+                    cast(ImageGenerationCallParam, content.native.to_dict())
+                )
+
+    return messages


 def _format_tool(
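To make the conversion concrete, a sketch (illustrative, not from the PR) of the shape _convert_content_to_param produces for a short exchange; it assumes the ToolInput fields the code above reads (id, tool_name, tool_args, external) on homeassistant.helpers.llm.ToolInput:

    chat_content = [
        conversation.SystemContent(content="You are a helpful assistant."),
        conversation.UserContent(content="Turn on the kitchen light."),
        conversation.AssistantContent(
            agent_id="conversation.home_assistant_cloud",
            content="On it.",
            tool_calls=[
                llm.ToolInput(
                    id="call-1", tool_name="HassTurnOn", tool_args={"name": "kitchen light"}
                )
            ],
        ),
    ]
    messages = _convert_content_to_param(chat_content)
    # Expected shape: "system" is mapped to the "developer" role, and the
    # non-external tool call becomes a function_call item:
    # [
    #     {"type": "message", "role": "developer", "content": "You are a helpful assistant."},
    #     {"type": "message", "role": "user", "content": "Turn on the kitchen light."},
    #     {"type": "message", "role": "assistant", "content": "On it."},
    #     {"type": "function_call", "name": "HassTurnOn",
    #      "arguments": '{"name": "kitchen light"}', "call_id": "call-1"},
    # ]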
@@ -381,25 +453,16 @@ class BaseCloudLLMEntity(Entity):
     async def _prepare_chat_for_generation(
         self,
         chat_log: conversation.ChatLog,
+        messages: ResponseInputParam,
         response_format: dict[str, Any] | None = None,
     ) -> dict[str, Any]:
         """Prepare kwargs for Cloud LLM from the chat log."""
-
-        messages = [
-            message
-            for content in chat_log.content
-            if (message := _convert_content_to_chat_message(content))
-        ]
-
-        if not messages or messages[-1]["role"] != "user":
-            raise HomeAssistantError("No user prompt found")
-
-        last_content = chat_log.content[-1]
+        last_content: Any = chat_log.content[-1]
         if last_content.role == "user" and last_content.attachments:
             files = await self._async_prepare_files_for_prompt(last_content.attachments)
-            user_message = messages[-1]
-            current_content = user_message.get("content", [])
-            user_message["content"] = [*(current_content or []), *files]
+            current_content = last_content.content
+            last_content = [*(current_content or []), *files]

         tools: list[ToolParam] = []
         tool_choice: str | None = None
@@ -503,8 +566,11 @@ class BaseCloudLLMEntity(Entity):
             },
         }

+        messages = _convert_content_to_param(chat_log.content)
+
         response_kwargs = await self._prepare_chat_for_generation(
             chat_log,
+            messages,
             response_format,
         )
@@ -518,15 +584,21 @@ class BaseCloudLLMEntity(Entity):
                 **response_kwargs,
             )

-            async for _ in chat_log.async_add_delta_content_stream(
-                agent_id=self.entity_id,
-                stream=_transform_stream(
-                    chat_log,
-                    raw_stream,
-                    True,
-                ),
-            ):
-                pass
+            messages.extend(
+                _convert_content_to_param(
+                    [
+                        content
+                        async for content in chat_log.async_add_delta_content_stream(
+                            self.entity_id,
+                            _transform_stream(
+                                chat_log,
+                                raw_stream,
+                                True,
+                            ),
+                        )
+                    ]
+                )
+            )

         except LLMAuthenticationError as err:
             raise ConfigEntryAuthFailed("Cloud LLM authentication failed") from err
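The rewritten block drains the delta stream into a list, converts it, and appends it to messages, so the buffer that seeded the request also accumulates the model's output. A generic sketch of the round-trip this enables (illustrative only; the hunk above shows a single iteration, and client.create_response plus the loop condition are assumptions modeled on the chat-log API):

    while True:
        # Hypothetical client call standing in for the cloud LLM request above.
        raw_stream = await client.create_response(input=messages, **response_kwargs)
        new_content = [
            content
            async for content in chat_log.async_add_delta_content_stream(
                entity_id, _transform_stream(chat_log, raw_stream, True)
            )
        ]
        # Converted output (including tool calls) becomes valid input for the next round.
        messages.extend(_convert_content_to_param(new_content))
        if not chat_log.unresponded_tool_results:  # stop once no tool calls await answers
            break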
tests/components/cloud/test_conversation.py (new file, 167 lines)
@@ -0,0 +1,167 @@
"""Tests for the Home Assistant Cloud conversation entity."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

from hass_nabucasa.llm import LLMError
import pytest

from homeassistant.components import conversation
from homeassistant.components.cloud.const import DATA_CLOUD, DOMAIN
from homeassistant.components.cloud.conversation import (
    CloudConversationEntity,
    async_setup_entry,
)
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import intent, llm

from tests.common import MockConfigEntry


@pytest.fixture
def cloud_conversation_entity(hass: HomeAssistant) -> CloudConversationEntity:
    """Return a CloudConversationEntity attached to hass."""
    cloud = MagicMock()
    cloud.llm = MagicMock()
    cloud.is_logged_in = True
    cloud.valid_subscription = True
    entry = MockConfigEntry(domain=DOMAIN)
    entry.add_to_hass(hass)
    entity = CloudConversationEntity(cloud, entry)
    entity.entity_id = "conversation.home_assistant_cloud"
    entity.hass = hass
    return entity


def test_entity_availability(
    cloud_conversation_entity: CloudConversationEntity,
) -> None:
    """Test that availability mirrors the cloud login/subscription state."""
    cloud_conversation_entity._cloud.is_logged_in = True
    cloud_conversation_entity._cloud.valid_subscription = True
    assert cloud_conversation_entity.available

    cloud_conversation_entity._cloud.is_logged_in = False
    assert not cloud_conversation_entity.available

    cloud_conversation_entity._cloud.is_logged_in = True
    cloud_conversation_entity._cloud.valid_subscription = False
    assert not cloud_conversation_entity.available


async def test_async_handle_message(
    hass: HomeAssistant, cloud_conversation_entity: CloudConversationEntity
) -> None:
    """Test that messages are processed through the chat log helper."""
    user_input = conversation.ConversationInput(
        text="apaga test",
        context=Context(),
        conversation_id="conversation-id",
        device_id="device-id",
        satellite_id=None,
        language="es",
        agent_id=cloud_conversation_entity.entity_id or "",
        extra_system_prompt="hazlo",
    )
    chat_log = conversation.ChatLog(hass, user_input.conversation_id)
    chat_log.async_add_user_content(conversation.UserContent(content=user_input.text))
    chat_log.async_provide_llm_data = AsyncMock()

    async def fake_handle(chat_type, log):
        """Inject assistant output so the result can be built."""
        assert chat_type == "conversation"
        assert log is chat_log
        log.async_add_assistant_content_without_tools(
            conversation.AssistantContent(
                agent_id=cloud_conversation_entity.entity_id or "",
                content="hecho",
            )
        )

    handle_chat_log = AsyncMock(side_effect=fake_handle)

    with patch.object(
        cloud_conversation_entity, "_async_handle_chat_log", handle_chat_log
    ):
        result = await cloud_conversation_entity._async_handle_message(
            user_input, chat_log
        )

    chat_log.async_provide_llm_data.assert_awaited_once_with(
        user_input.as_llm_context(DOMAIN),
        llm.LLM_API_ASSIST,
        None,
        user_input.extra_system_prompt,
    )
    handle_chat_log.assert_awaited_once_with("conversation", chat_log)
    assert result.conversation_id == "conversation-id"
    assert result.response.speech["plain"]["speech"] == "hecho"


async def test_async_handle_message_converse_error(
    hass: HomeAssistant, cloud_conversation_entity: CloudConversationEntity
) -> None:
    """Test that ConverseError short-circuits message handling."""
    user_input = conversation.ConversationInput(
        text="hola",
        context=Context(),
        conversation_id="conversation-id",
        device_id=None,
        satellite_id=None,
        language="es",
        agent_id=cloud_conversation_entity.entity_id or "",
    )
    chat_log = conversation.ChatLog(hass, user_input.conversation_id)

    error_response = intent.IntentResponse(language="es")
    converse_error = conversation.ConverseError(
        "failed", user_input.conversation_id or "", error_response
    )
    chat_log.async_provide_llm_data = AsyncMock(side_effect=converse_error)

    with patch.object(
        cloud_conversation_entity, "_async_handle_chat_log", AsyncMock()
    ) as handle_chat_log:
        result = await cloud_conversation_entity._async_handle_message(
            user_input, chat_log
        )

    handle_chat_log.assert_not_called()
    assert result.response is error_response
    assert result.conversation_id == user_input.conversation_id


async def test_async_setup_entry_adds_entity(hass: HomeAssistant) -> None:
    """Test the platform setup adds the conversation entity."""
    cloud = MagicMock()
    cloud.llm = MagicMock(async_ensure_token=AsyncMock())
    cloud.is_logged_in = True
    cloud.valid_subscription = True
    hass.data[DATA_CLOUD] = cloud
    entry = MockConfigEntry(domain=DOMAIN)
    entry.add_to_hass(hass)
    add_entities = MagicMock()

    await async_setup_entry(hass, entry, add_entities)

    cloud.llm.async_ensure_token.assert_awaited_once()
    assert add_entities.call_count == 1
    assert isinstance(add_entities.call_args[0][0][0], CloudConversationEntity)


async def test_async_setup_entry_llm_error(hass: HomeAssistant) -> None:
    """Test entity setup is aborted when ensuring the token fails."""
    cloud = MagicMock()
    cloud.llm = MagicMock(async_ensure_token=AsyncMock(side_effect=LLMError("fail")))
    cloud.is_logged_in = True
    cloud.valid_subscription = True
    hass.data[DATA_CLOUD] = cloud
    entry = MockConfigEntry(domain=DOMAIN)
    entry.add_to_hass(hass)
    add_entities = MagicMock()

    await async_setup_entry(hass, entry, add_entities)

    cloud.llm.async_ensure_token.assert_awaited_once()
    add_entities.assert_not_called()
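These tests follow the repo's standard layout and can be run in isolation with pytest tests/components/cloud/test_conversation.py.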
tests/components/cloud/test_entity.py
@@ -4,7 +4,6 @@ from __future__ import annotations

 import base64
 from pathlib import Path
-from types import SimpleNamespace
 from unittest.mock import AsyncMock, MagicMock, patch

 from PIL import Image
@@ -14,6 +13,7 @@ import voluptuous as vol
 from homeassistant.components import conversation
 from homeassistant.components.cloud.entity import (
     BaseCloudLLMEntity,
+    _convert_content_to_param,
     _format_structured_output,
 )
 from homeassistant.core import HomeAssistant
@@ -38,6 +38,19 @@ def cloud_entity(hass: HomeAssistant) -> BaseCloudLLMEntity:
     return entity


+@pytest.fixture
+def mock_prepare_files_for_prompt(
+    cloud_entity: BaseCloudLLMEntity,
+) -> AsyncMock:
+    """Patch file preparation helper on the entity."""
+    with patch.object(
+        cloud_entity,
+        "_async_prepare_files_for_prompt",
+        AsyncMock(),
+    ) as mock:
+        yield mock
+
+
 class DummyTool(llm.Tool):
     """Simple tool used for schema conversion tests."""

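Note that mock_prepare_files_for_prompt is added here near the top of the module while the final hunk below removes the identical definition from its old position, so this is a move, not a duplicate.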
@@ -162,12 +175,12 @@ async def test_prepare_chat_for_generation_appends_attachments(
     attachment = conversation.Attachment(
         media_content_id="media-source://media/doorbell.jpg",
         mime_type="image/jpeg",
-        path=hass.config.path("doorbell.jpg"),
+        path=Path(hass.config.path("doorbell.jpg")),
     )
     chat_log.async_add_user_content(
         conversation.UserContent(content="Describe the door", attachments=[attachment])
     )
-    chat_log.llm_api = SimpleNamespace(
+    chat_log.llm_api = MagicMock(
         tools=[DummyTool()],
         custom_serializer=None,
     )
@@ -175,8 +188,11 @@ async def test_prepare_chat_for_generation_appends_attachments(
     files = [{"type": "input_image", "image_url": "data://img", "detail": "auto"}]

     mock_prepare_files_for_prompt.return_value = files
+    messages = _convert_content_to_param(chat_log.content)
     response = await cloud_entity._prepare_chat_for_generation(
-        chat_log, response_format={"type": "json"}
+        chat_log,
+        messages,
+        response_format={"type": "json"},
     )

     assert response["conversation_id"] == "conversation-id"
@@ -185,35 +201,21 @@ async def test_prepare_chat_for_generation_appends_attachments(
     assert len(response["tools"]) == 2
     assert response["tools"][0]["name"] == "do_something"
     assert response["tools"][1]["type"] == "web_search"
-    user_message = response["messages"][-1]
-    assert user_message["content"][0] == {
-        "type": "input_text",
-        "text": "Describe the door",
-    }
-    assert user_message["content"][1:] == files
+    assert response["messages"] is messages
     mock_prepare_files_for_prompt.assert_awaited_once_with([attachment])


-async def test_prepare_chat_for_generation_requires_user_prompt(
+async def test_prepare_chat_for_generation_passes_messages_through(
     hass: HomeAssistant, cloud_entity: BaseCloudLLMEntity
 ) -> None:
-    """Test that we fail fast when there is no user input to process."""
+    """Test that prepared messages are forwarded unchanged."""
     chat_log = conversation.ChatLog(hass, "conversation-id")
     chat_log.async_add_assistant_content_without_tools(
         conversation.AssistantContent(agent_id="agent", content="Ready")
     )
+    messages = _convert_content_to_param(chat_log.content)

-    with pytest.raises(HomeAssistantError, match="No user prompt found"):
-        await cloud_entity._prepare_chat_for_generation(chat_log)
+    response = await cloud_entity._prepare_chat_for_generation(chat_log, messages)


-@pytest.fixture
-def mock_prepare_files_for_prompt(
-    cloud_entity: BaseCloudLLMEntity,
-) -> AsyncMock:
-    """Patch file preparation helper on the entity."""
-    with patch.object(
-        cloud_entity,
-        "_async_prepare_files_for_prompt",
-        AsyncMock(),
-    ) as mock:
-        yield mock
+    assert response["messages"] == messages
+    assert response["conversation_id"] == "conversation-id"