diff --git a/homeassistant/components/ai_task/services.yaml b/homeassistant/components/ai_task/services.yaml index 8a37990a5d7..1b018fe6a65 100644 --- a/homeassistant/components/ai_task/services.yaml +++ b/homeassistant/components/ai_task/services.yaml @@ -30,6 +30,7 @@ generate_data: media: accept: - "*" + multiple: true generate_image: fields: task_name: @@ -57,3 +58,4 @@ generate_image: media: accept: - "*" + multiple: true diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index 900abb0b9c7..ab1a3dfa54c 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -782,9 +782,19 @@ def selector_serializer(schema: Any) -> Any: # noqa: C901 return {"type": "string", "enum": schema.config["languages"]} return {"type": "string", "format": "RFC 5646"} - if isinstance(schema, (selector.LocationSelector, selector.MediaSelector)): + if isinstance(schema, selector.LocationSelector): return convert(schema.DATA_SCHEMA) + if isinstance(schema, selector.MediaSelector): + item_schema = convert(schema.DATA_SCHEMA) + # Media selector allows multiple when configured + if schema.config.get("multiple"): + return { + "type": "array", + "items": item_schema, + } + return item_schema + if isinstance(schema, selector.NumberSelector): result = {"type": "number"} if "min" in schema.config: diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 215a15b3512..a3510726175 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -1030,6 +1030,7 @@ class MediaSelectorConfig(BaseSelectorConfig, total=False): """Class to represent a media selector config.""" accept: list[str] + multiple: bool @SELECTORS.register("media") @@ -1041,6 +1042,7 @@ class MediaSelector(Selector[MediaSelectorConfig]): CONFIG_SCHEMA = make_selector_config_schema( { vol.Optional("accept"): [str], + vol.Optional("multiple", default=False): cv.boolean, } ) DATA_SCHEMA = vol.Schema( @@ -1059,9 +1061,9 @@ class MediaSelector(Selector[MediaSelectorConfig]): """Instantiate a selector.""" super().__init__(config) - def __call__(self, data: Any) -> dict[str, str]: + def __call__(self, data: Any) -> dict[str, str] | list[dict[str, str]]: """Validate the passed selection.""" - schema = { + item_schema_dict = { key: value for key, value in self.DATA_SCHEMA.schema.items() if key != "entity_id" @@ -1069,10 +1071,19 @@ class MediaSelector(Selector[MediaSelectorConfig]): if "accept" not in self.config: # If accept is not set, the entity_id field is required - schema[vol.Required("entity_id")] = cv.entity_id_or_uuid + item_schema_dict[vol.Required("entity_id")] = cv.entity_id_or_uuid - media: dict[str, str] = vol.Schema(schema)(data) - return media + item_schema = vol.Schema(item_schema_dict) + + if not self.config["multiple"]: + media: dict[str, str] = item_schema(data) + return media + + # Backwards compatibility for places that now accept multiple items + if not isinstance(data, list): + data = [data] + + return [item_schema(item) for item in data] class NumberSelectorConfig(BaseSelectorConfig, total=False): diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 7011ba42b72..7e237cc495e 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -1132,6 +1132,19 @@ async def test_selector_serializer( }, "required": ["media_content_id", "media_content_type"], } + assert selector_serializer(selector.MediaSelector({"multiple": True})) == { + "type": "array", + "items": { + "type": "object", + "properties": { + "entity_id": {"type": "string"}, + "media_content_id": {"type": "string"}, + "media_content_type": {"type": "string"}, + "metadata": {"type": "object", "additionalProperties": True}, + }, + "required": ["media_content_id", "media_content_type"], + }, + } assert selector_serializer(selector.NumberSelector({"mode": "box"})) == { "type": "number" } diff --git a/tests/helpers/test_selector.py b/tests/helpers/test_selector.py index 73db8af126e..416701a1746 100644 --- a/tests/helpers/test_selector.py +++ b/tests/helpers/test_selector.py @@ -1164,6 +1164,8 @@ def test_media_selector_schema(schema, valid_selections, invalid_selections) -> def drop_metadata(data): """Drop metadata key from the input.""" + if isinstance(data, list): + return [drop_metadata(item) for item in data] data.pop("metadata", None) return data @@ -1176,6 +1178,96 @@ def test_media_selector_schema(schema, valid_selections, invalid_selections) -> ) +@pytest.mark.parametrize( + ("schema", "valid_selections", "invalid_selections"), + [ + ( + {"multiple": True}, + ( + [ + { + "entity_id": "sensor.abc", + "media_content_id": "abc", + "media_content_type": "def", + }, + { + "entity_id": "sensor.def", + "media_content_id": "ghi", + "media_content_type": "jkl", + }, + ], + # Not a list is automatically converted to a list + { + "entity_id": "sensor.abc", + "media_content_id": "abc", + "media_content_type": "def", + }, + ), + ( + None, + # Missing required key in one item + [ + { + "entity_id": "sensor.abc", + "media_content_id": "abc", + "media_content_type": "def", + }, + { + "entity_id": "sensor.def", + "media_content_id": "ghi", + }, + ], + ), + ), + ( + {"multiple": True, "accept": ["image/*"]}, + ( + [ + { + "media_content_id": "abc", + "media_content_type": "def", + }, + { + "media_content_id": "ghi", + "media_content_type": "jkl", + }, + ], + ), + ( + None, + # entity_id not allowed when accept is set + [ + { + "entity_id": "sensor.abc", + "media_content_id": "abc", + "media_content_type": "def", + } + ], + ), + ), + ], +) +def test_media_selector_schema_multiple( + schema, valid_selections, invalid_selections +) -> None: + """Test media selector with multiple selections.""" + + def drop_metadata(data, root=True): + if isinstance(data, list): + return [drop_metadata(item, False) for item in data] + data.pop("metadata", None) + # Multiple=true wraps single values in list. + return [data] if root and schema.get("multiple") else data + + _test_selector( + "media", + schema, + valid_selections, + invalid_selections, + drop_metadata, + ) + + @pytest.mark.parametrize( ("schema", "valid_selections", "invalid_selections"), [