diff --git a/homeassistant/helpers/selector.py b/homeassistant/helpers/selector.py index 176c2930277..a8a29f4e490 100644 --- a/homeassistant/helpers/selector.py +++ b/homeassistant/helpers/selector.py @@ -418,6 +418,88 @@ class BooleanSelector(Selector[BooleanSelectorConfig]): return value +def reject_nested_choose_selector(config: dict[str, Any]) -> dict[str, Any]: + """Reject nested choose selectors.""" + for choice in config.get("choices", {}).values(): + if isinstance(choice["selector"], dict): + selector_type, _ = _get_selector_type_and_class(choice["selector"]) + if selector_type == "choose": + raise vol.Invalid("Nested choose selectors are not allowed") + return config + + +class ChooseSelectorChoiceConfig(TypedDict, total=False): + """Class to represent a choose selector choice config.""" + + selector: Required[Selector | dict[str, Any]] + + +class ChooseSelectorConfig(BaseSelectorConfig): + """Class to represent a choose selector config.""" + + choices: Required[dict[str, ChooseSelectorChoiceConfig]] + translation_key: str + + +@SELECTORS.register("choose") +class ChooseSelector(Selector[ChooseSelectorConfig]): + """Selector allowing to choose one of several selectors.""" + + selector_type = "choose" + + CONFIG_SCHEMA = vol.All( + make_selector_config_schema( + { + vol.Required("choices"): { + str: { + vol.Required("selector"): vol.Any(Selector, validate_selector), + } + }, + vol.Optional("translation_key"): cv.string, + }, + ), + reject_nested_choose_selector, + ) + + def __init__(self, config: ChooseSelectorConfig | None = None) -> None: + """Instantiate a selector.""" + super().__init__(config) + + def serialize(self) -> dict[str, dict[str, ChooseSelectorConfig]]: + """Serialize ChooseSelectorConfig for voluptuous_serialize.""" + _config = deepcopy(self.config) + if "choices" in _config: + for choice in _config["choices"].values(): + if isinstance(choice["selector"], Selector): + choice["selector"] = choice["selector"].serialize()["selector"] + return {"selector": {self.selector_type: _config}} + + def __call__(self, data: Any) -> Any: + """Validate the passed selection.""" + if not isinstance(data, dict): + for choice in self.config["choices"].values(): + try: + validated = selector(choice["selector"])(data) # type: ignore[operator] + except (vol.Invalid, vol.MultipleInvalid): + continue + else: + return validated + + raise vol.Invalid("Value does not match any choice selector") + + if "active_choice" not in data: + raise vol.Invalid("Missing active_choice key") + if data["active_choice"] not in data: + raise vol.Invalid("Missing value for active choice") + + choices = self.config.get("choices", {}) + if data["active_choice"] not in choices: + raise vol.Invalid("Invalid active_choice key") + return selector(choices[data["active_choice"]]["selector"])( # type: ignore[operator] + data[data["active_choice"]] + ) + + class ColorRGBSelectorConfig(BaseSelectorConfig): """Class to represent a color RGB selector config.""" @@ -1223,14 +1305,10 @@ class ObjectSelector(Selector[ObjectSelectorConfig]): _config = deepcopy(self.config) if "fields" in _config: for field_items in _config["fields"].values(): - if isinstance(field_items["selector"], ObjectSelector): - field_items["selector"] = field_items["selector"].serialize() - elif isinstance(field_items["selector"], Selector): - field_items["selector"] = { - field_items["selector"].selector_type: field_items[ - "selector" - ].config - } + if isinstance(field_items["selector"], Selector): + field_items["selector"] = field_items["selector"].serialize()[ + "selector" + ] return {"selector": {self.selector_type: _config}} def __call__(self, data: Any) -> Any: diff --git a/script/hassfest/translations.py b/script/hassfest/translations.py index a308f92c270..ce83e1bab8a 100644 --- a/script/hassfest/translations.py +++ b/script/hassfest/translations.py @@ -323,6 +323,10 @@ def gen_strings_schema(config: Config, integration: Integration) -> vol.Schema: ), slug_validator=vol.Any("_", cv.slug), ), + vol.Optional("choices"): cv.schema_with_slug_keys( + translation_value_validator, + slug_validator=translation_key_validator, + ), vol.Optional("options"): gen_data_entry_schema( config=config, integration=integration, diff --git a/tests/helpers/snapshots/test_selector.ambr b/tests/helpers/snapshots/test_selector.ambr index e3af84e6621..ff66f606f38 100644 --- a/tests/helpers/snapshots/test_selector.ambr +++ b/tests/helpers/snapshots/test_selector.ambr @@ -1,4 +1,163 @@ # serializer version: 1 +# name: test_choose_selector_serialize + dict({ + 'selector': dict({ + 'choose': dict({ + 'choices': dict({ + 'entity_choice': dict({ + 'selector': dict({ + 'entity': dict({ + 'domain': list([ + 'light', + ]), + 'multiple': False, + 'reorder': False, + }), + }), + }), + 'number_choice': dict({ + 'selector': dict({ + 'number': dict({ + 'max': 100.0, + 'min': 0.0, + 'mode': 'slider', + 'step': 1.0, + }), + }), + }), + 'object_choice': dict({ + 'selector': dict({ + 'object': dict({ + 'description_field': 'percentage', + 'fields': dict({ + 'name': dict({ + 'required': True, + 'selector': dict({ + 'text': dict({ + 'multiline': False, + 'multiple': False, + }), + }), + }), + 'percentage': dict({ + 'selector': dict({ + 'number': dict({ + 'mode': 'box', + 'step': 1.0, + }), + }), + }), + }), + 'label_field': 'name', + 'multiple': False, + }), + }), + }), + 'text_choice': dict({ + 'selector': dict({ + 'text': dict({ + 'multiline': True, + 'multiple': False, + }), + }), + }), + }), + }), + }), + }) +# --- +# name: test_choose_selector_serialize.1 + dict({ + 'selector': dict({ + 'choose': dict({ + 'choices': dict({ + 'number_choice': dict({ + 'selector': dict({ + 'number': dict({ + 'max': 100.0, + 'min': 0.0, + 'mode': 'slider', + 'step': 1.0, + }), + }), + }), + 'object_choice': dict({ + 'selector': dict({ + 'object': dict({ + 'description_field': 'percentage', + 'fields': dict({ + 'name': dict({ + 'required': True, + 'selector': dict({ + 'text': dict({ + 'multiline': False, + 'multiple': False, + }), + }), + }), + 'object': dict({ + 'selector': dict({ + 'object': dict({ + 'fields': dict({ + 'choose': dict({ + 'required': True, + 'selector': dict({ + 'choose': dict({ + 'choices': dict({ + 'number_choice': dict({ + 'selector': dict({ + 'number': dict({ + 'max': 100.0, + 'min': 0.0, + 'mode': 'slider', + 'step': 1.0, + }), + }), + }), + 'text_choice': dict({ + 'selector': dict({ + 'text': dict({ + 'multiline': True, + 'multiple': False, + }), + }), + }), + }), + }), + }), + }), + }), + 'multiple': False, + }), + }), + }), + 'percentage': dict({ + 'selector': dict({ + 'number': dict({ + 'mode': 'box', + 'step': 1.0, + }), + }), + }), + }), + 'label_field': 'name', + 'multiple': False, + }), + }), + }), + 'text_choice': dict({ + 'selector': dict({ + 'text': dict({ + 'multiline': True, + 'multiple': False, + }), + }), + }), + }), + }), + }), + }) +# --- # name: test_nested_object_selectors dict({ 'selector': dict({ @@ -16,64 +175,60 @@ }), 'object': dict({ 'selector': dict({ - 'selector': dict({ - 'object': dict({ - 'description_field': 'other_name', - 'fields': dict({ - 'new_object': dict({ - 'required': True, - 'selector': dict({ - 'selector': dict({ - 'object': dict({ - 'description_field': 'description', - 'fields': dict({ - 'description': dict({ - 'required': True, - 'selector': dict({ - 'text': dict({ - 'multiline': False, - 'multiple': False, - }), - }), - }), - 'title': dict({ - 'required': True, - 'selector': dict({ - 'text': dict({ - 'multiline': False, - 'multiple': False, - }), - }), + 'object': dict({ + 'description_field': 'other_name', + 'fields': dict({ + 'new_object': dict({ + 'required': True, + 'selector': dict({ + 'object': dict({ + 'description_field': 'description', + 'fields': dict({ + 'description': dict({ + 'required': True, + 'selector': dict({ + 'text': dict({ + 'multiline': False, + 'multiple': False, + }), + }), + }), + 'title': dict({ + 'required': True, + 'selector': dict({ + 'text': dict({ + 'multiline': False, + 'multiple': False, }), }), - 'label_field': 'title', - 'multiple': False, }), }), - }), - }), - 'no_name': dict({ - 'required': True, - 'selector': dict({ - 'text': dict({ - 'multiline': False, - 'multiple': False, - }), - }), - }), - 'other_name': dict({ - 'required': True, - 'selector': dict({ - 'text': dict({ - 'multiline': False, - 'multiple': False, - }), + 'label_field': 'title', + 'multiple': False, + }), + }), + }), + 'no_name': dict({ + 'required': True, + 'selector': dict({ + 'text': dict({ + 'multiline': False, + 'multiple': False, + }), + }), + }), + 'other_name': dict({ + 'required': True, + 'selector': dict({ + 'text': dict({ + 'multiline': False, + 'multiple': False, }), }), }), - 'label_field': 'no_name', - 'multiple': False, }), + 'label_field': 'no_name', + 'multiple': False, }), }), }), diff --git a/tests/helpers/test_selector.py b/tests/helpers/test_selector.py index 416701a1746..dc7e2637bd9 100644 --- a/tests/helpers/test_selector.py +++ b/tests/helpers/test_selector.py @@ -513,6 +513,259 @@ def test_boolean_selector_schema(schema, valid_selections, invalid_selections) - ) +@pytest.mark.parametrize( + ("schema", "valid_selections", "invalid_selections"), + [ + ( + { + "choices": { + "text_choice": {"selector": {"text": {}}}, + "number_choice": {"selector": {"number": {"min": 0, "max": 100}}}, + } + }, + ( + # Direct value matching text selector + "some text", + # Direct value matching number selector + 42, + # Explicit choice with active_choice key + {"active_choice": "text_choice", "text_choice": "hello world"}, + {"active_choice": "number_choice", "number_choice": 50}, + ), + ( + # None doesn't match any selector + None, + # Missing active_choice key + {"text_choice": "hello"}, + # Invalid active_choice key + {"active_choice": "invalid", "invalid": "value"}, + # Missing value for active choice + {"active_choice": "text_choice"}, + # Wrong value type for number selector + {"active_choice": "number_choice", "number_choice": "not a number"}, + ), + ), + ( + { + "choices": { + "entity": {"selector": {"entity": {}}}, + "device": {"selector": {"device": {}}}, + "text": {"selector": {"text": {}}}, + } + }, + ( + # Direct value matching entity selector + "sensor.abc123", + FAKE_UUID, + # Explicit choice + {"active_choice": "entity", "entity": "light.bedroom"}, + {"active_choice": "device", "device": "device123"}, + {"active_choice": "text", "text": "some text"}, + ), + ( + None, + # List doesn't match any selector + ["sensor.abc", "light.def"], + # Missing active_choice key + {"entity": "sensor.abc"}, + # Invalid active_choice + {"active_choice": "area", "area": "area123"}, + ), + ), + ], +) +def test_choose_selector_schema(schema, valid_selections, invalid_selections) -> None: + """Test choose selector.""" + + def get_selected_value(data): + """Get the selected value from the input.""" + if isinstance(data, dict) and "active_choice" in data: + return data[data["active_choice"]] + return data + + _test_selector( + "choose", schema, valid_selections, invalid_selections, get_selected_value + ) + + +@pytest.mark.parametrize( + ("schema", "raises"), + [ + # Valid schemas + ( + { + "choices": { + "text": {"selector": {"text": {}}}, + "number": {"selector": {"number": {}}}, + } + }, + does_not_raise(), + ), + ( + { + "choices": { + "text": {"selector": selector.TextSelector()}, + "number": {"selector": selector.NumberSelector()}, + } + }, + does_not_raise(), + ), + # Invalid schemas + ( + {}, # Missing required 'choices' key + pytest.raises(vol.Invalid), + ), + ( + { + "choices": {} # Empty choices dict + }, + does_not_raise(), # Empty dict is technically valid + ), + ( + { + "choices": { + "text": {} # Missing required 'selector' key in choice + } + }, + pytest.raises(vol.Invalid), + ), + ( + { + "choices": { + "invalid": {"selector": {"not_exist": {}}} # Invalid selector type + } + }, + pytest.raises(vol.Invalid), + ), + ( + { + "choices": "not a dict" # choices should be a dict + }, + pytest.raises(vol.Invalid), + ), + ( + { + "choices": { + "invalid": { + "selector": { + "choose": { + "choices": { + "text": {"selector": {"text": {}}}, + "number": {"selector": {"number": {}}}, + } + } + } + } # Nested choose is not allowed + } + }, + pytest.raises(vol.Invalid), + ), + ], +) +def test_choose_selector_validate_schema( + schema: dict, raises: AbstractContextManager +) -> None: + """Test choose selector schema validation.""" + with raises: + selector.validate_selector({"choose": schema}) + + +def test_choose_selector_serialize(snapshot: SnapshotAssertion) -> None: + """Test choose selector serialization.""" + # Test with dict-based selectors + choose_selector = selector.ChooseSelector( + { + "choices": { + "text_choice": {"selector": {"text": {"multiline": True}}}, + "number_choice": {"selector": {"number": {"min": 0, "max": 100}}}, + "entity_choice": {"selector": {"entity": {"domain": "light"}}}, + "object_choice": { + "selector": { + "object": { + "fields": { + "name": { + "required": True, + "selector": {"text": {}}, + }, + "percentage": { + "selector": {"number": {}}, + }, + }, + "multiple": False, + "label_field": "name", + "description_field": "percentage", + } + } + }, + } + } + ) + assert choose_selector.serialize() == snapshot + + # Test with Selector object instances + choose_selector_objects = selector.ChooseSelector( + { + "choices": { + "text_choice": {"selector": selector.TextSelector({"multiline": True})}, + "number_choice": { + "selector": selector.NumberSelector({"min": 0, "max": 100}) + }, + "object_choice": { + "selector": selector.ObjectSelector( + { + "fields": { + "name": { + "required": True, + "selector": selector.TextSelector({}), + }, + "percentage": { + "selector": selector.NumberSelector({}), + }, + "object": { + "selector": selector.ObjectSelector( + { + "fields": { + "choose": { + "required": True, + "selector": selector.ChooseSelector( + { + "choices": { + "text_choice": { + "selector": selector.TextSelector( + { + "multiline": True + } + ) + }, + "number_choice": { + "selector": selector.NumberSelector( + { + "min": 0, + "max": 100, + } + ) + }, + } + } + ), + }, + } + } + ), + }, + }, + "multiple": False, + "label_field": "name", + "description_field": "percentage", + } + ) + }, + } + } + ) + assert choose_selector_objects.serialize() == snapshot + + @pytest.mark.parametrize( ("schema", "valid_selections", "invalid_selections"), [