diff --git a/script/hassfest/conditions.py b/script/hassfest/conditions.py index ecb7ceca7f2..22449cfd636 100644 --- a/script/hassfest/conditions.py +++ b/script/hassfest/conditions.py @@ -26,21 +26,95 @@ def exists(value: Any) -> Any: return value +def validate_field_schema(condition_schema: dict[str, Any]) -> dict[str, Any]: + """Validate a field schema including context references.""" + + for field_name, field_schema in condition_schema.get("fields", {}).items(): + # Validate context if present + if "context" in field_schema: + if CONF_SELECTOR not in field_schema: + raise vol.Invalid( + f"Context defined without a selector in '{field_name}'" + ) + + context = field_schema["context"] + if not isinstance(context, dict): + raise vol.Invalid(f"Context must be a dictionary in '{field_name}'") + + # Determine which selector type is being used + selector_config = field_schema[CONF_SELECTOR] + selector_class = selector.selector(selector_config) + + for context_key, field_ref in context.items(): + # Check if context key is allowed for this selector type + allowed_keys = selector_class.allowed_context_keys + if context_key not in allowed_keys: + raise vol.Invalid( + f"Invalid context key '{context_key}' for selector type '{selector_class.selector_type}'. " + f"Allowed keys: {', '.join(sorted(allowed_keys)) if allowed_keys else 'none'}" + ) + + # Check if the referenced field exists in condition schema or target + if not isinstance(field_ref, str): + raise vol.Invalid( + f"Context value for '{context_key}' must be a string field reference" + ) + + # Check if field exists in condition schema fields or target + condition_fields = condition_schema["fields"] + field_exists = field_ref in condition_fields + if field_exists and "selector" in condition_fields[field_ref]: + # Check if the selector type is allowed for this context key + field_selector_config = condition_fields[field_ref][CONF_SELECTOR] + field_selector_class = selector.selector(field_selector_config) + if field_selector_class.selector_type not in allowed_keys.get( + context_key, set() + ): + raise vol.Invalid( + f"The context '{context_key}' for '{field_name}' references '{field_ref}', but '{context_key}' " + f"does not allow selectors of type '{field_selector_class.selector_type}'. Allowed selector types: {', '.join(allowed_keys.get(context_key, set()))}" + ) + if not field_exists and "target" in condition_schema: + # Target is a special field that always exists when defined + field_exists = field_ref == "target" + if field_exists and "target" not in allowed_keys.get( + context_key, set() + ): + raise vol.Invalid( + f"The context '{context_key}' for '{field_name}' references 'target', but '{context_key}' " + f"does not allow 'target'. Allowed selector types: {', '.join(allowed_keys.get(context_key, set()))}" + ) + + if not field_exists: + raise vol.Invalid( + f"Context reference '{field_ref}' for key '{context_key}' does not exist " + f"in condition schema fields or target" + ) + + return condition_schema + + FIELD_SCHEMA = vol.Schema( { vol.Optional("example"): exists, vol.Optional("default"): exists, vol.Optional("required"): bool, vol.Optional(CONF_SELECTOR): selector.validate_selector, + vol.Optional("context"): { + str: str # key is context key, value is field name in the schema which value should be used + }, # Will be validated in validate_field_schema } ) CONDITION_SCHEMA = vol.Any( - vol.Schema( - { - vol.Optional("target"): selector.TargetSelector.CONFIG_SCHEMA, - vol.Optional("fields"): vol.Schema({str: FIELD_SCHEMA}), - } + vol.All( + vol.Schema( + { + vol.Optional("target"): selector.TargetSelector.CONFIG_SCHEMA, + vol.Optional("fields"): vol.Schema({str: FIELD_SCHEMA}), + } + ), + validate_field_schema, ), None, )