From 91e13d447a0aca0cdea90d8a824cd5b4537acac8 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 24 Sep 2025 23:09:54 -0400 Subject: [PATCH] Prevent common control calling async methods from thread (#152931) Co-authored-by: J. Nick Koston Co-authored-by: J. Nick Koston --- .../usage_prediction/common_control.py | 112 +++++++++--------- .../usage_prediction/test_common_control.py | 61 +++++++--- 2 files changed, 101 insertions(+), 72 deletions(-) diff --git a/homeassistant/components/usage_prediction/common_control.py b/homeassistant/components/usage_prediction/common_control.py index 9d86b5f2766..69f2164fc76 100644 --- a/homeassistant/components/usage_prediction/common_control.py +++ b/homeassistant/components/usage_prediction/common_control.py @@ -3,13 +3,14 @@ from __future__ import annotations from collections import Counter -from collections.abc import Callable +from collections.abc import Callable, Sequence from datetime import datetime, timedelta from functools import cache import logging from typing import Any, Literal, cast from sqlalchemy import select +from sqlalchemy.engine.row import Row from sqlalchemy.orm import Session from homeassistant.components.recorder import get_instance @@ -90,61 +91,32 @@ async def async_predict_common_control( Args: hass: Home Assistant instance user_id: User ID to filter events by. - - Returns: - Dictionary with time categories as keys and lists of most common entity IDs as values """ # Get the recorder instance to ensure it's ready recorder = get_instance(hass) ent_reg = er.async_get(hass) # Execute the database operation in the recorder's executor - return await recorder.async_add_executor_job( + data = await recorder.async_add_executor_job( _fetch_with_session, hass, _fetch_and_process_data, ent_reg, user_id ) - - -def _fetch_and_process_data( - session: Session, ent_reg: er.EntityRegistry, user_id: str -) -> EntityUsagePredictions: - """Fetch and process service call events from the database.""" # Prepare a dictionary to track results results: dict[str, Counter[str]] = { time_cat: Counter() for time_cat in TIME_CATEGORIES } + allowed_entities = set(hass.states.async_entity_ids(ALLOWED_DOMAINS)) + hidden_entities: set[str] = set() + # Keep track of contexts that we processed so that we will only process # the first service call in a context, and not subsequent calls. context_processed: set[bytes] = set() - thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp() - user_id_bytes = uuid_hex_to_bytes_or_none(user_id) - if not user_id_bytes: - raise ValueError("Invalid user_id format") - - # Build the main query for events with their data - query = ( - select( - Events.context_id_bin, - Events.time_fired_ts, - EventData.shared_data, - ) - .select_from(Events) - .outerjoin(EventData, Events.data_id == EventData.data_id) - .outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id) - .where(Events.time_fired_ts >= thirty_days_ago_ts) - .where(Events.context_user_id_bin == user_id_bytes) - .where(EventTypes.event_type == "call_service") - .order_by(Events.time_fired_ts) - ) - # Execute the query context_id: bytes time_fired_ts: float shared_data: str | None local_time_zone = dt_util.get_default_time_zone() - for context_id, time_fired_ts, shared_data in ( - session.connection().execute(query).all() - ): + for context_id, time_fired_ts, shared_data in data: # Skip if we have already processed an event that was part of this context if context_id in context_processed: continue @@ -153,7 +125,7 @@ def _fetch_and_process_data( context_processed.add(context_id) # Parse the event data - if not shared_data: + if not time_fired_ts or not shared_data: continue try: @@ -187,27 +159,26 @@ def _fetch_and_process_data( if not isinstance(entity_ids, list): entity_ids = [entity_ids] - # Filter out entity IDs that are not in allowed domains - entity_ids = [ - entity_id - for entity_id in entity_ids - if entity_id.split(".")[0] in ALLOWED_DOMAINS - and ((entry := ent_reg.async_get(entity_id)) is None or not entry.hidden) - ] + # Convert to local time for time category determination + period = time_category( + datetime.fromtimestamp(time_fired_ts, local_time_zone).hour + ) + period_results = results[period] - if not entity_ids: - continue + # Count entity usage + for entity_id in entity_ids: + if entity_id not in allowed_entities or entity_id in hidden_entities: + continue - # Convert timestamp to datetime and determine time category - if time_fired_ts: - # Convert to local time for time category determination - period = time_category( - datetime.fromtimestamp(time_fired_ts, local_time_zone).hour - ) + if ( + entity_id not in period_results + and (entry := ent_reg.async_get(entity_id)) + and entry.hidden + ): + hidden_entities.add(entity_id) + continue - # Count entity usage - for entity_id in entity_ids: - results[period][entity_id] += 1 + period_results[entity_id] += 1 return EntityUsagePredictions( morning=[ @@ -226,11 +197,40 @@ def _fetch_and_process_data( ) +def _fetch_and_process_data( + session: Session, ent_reg: er.EntityRegistry, user_id: str +) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]: + """Fetch and process service call events from the database.""" + thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp() + user_id_bytes = uuid_hex_to_bytes_or_none(user_id) + if not user_id_bytes: + raise ValueError("Invalid user_id format") + + # Build the main query for events with their data + query = ( + select( + Events.context_id_bin, + Events.time_fired_ts, + EventData.shared_data, + ) + .select_from(Events) + .outerjoin(EventData, Events.data_id == EventData.data_id) + .outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id) + .where(Events.time_fired_ts >= thirty_days_ago_ts) + .where(Events.context_user_id_bin == user_id_bytes) + .where(EventTypes.event_type == "call_service") + .order_by(Events.time_fired_ts) + ) + return session.connection().execute(query).all() + + def _fetch_with_session( hass: HomeAssistant, - fetch_func: Callable[[Session], EntityUsagePredictions], + fetch_func: Callable[ + [Session], Sequence[Row[tuple[bytes | None, float | None, str | None]]] + ], *args: object, -) -> EntityUsagePredictions: +) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]: """Execute a fetch function with a database session.""" with session_scope(hass=hass, read_only=True) as session: return fetch_func(session, *args) diff --git a/tests/components/usage_prediction/test_common_control.py b/tests/components/usage_prediction/test_common_control.py index de6db025472..090d9ddf7ff 100644 --- a/tests/components/usage_prediction/test_common_control.py +++ b/tests/components/usage_prediction/test_common_control.py @@ -62,9 +62,15 @@ async def test_with_service_calls(hass: HomeAssistant) -> None: """Test function with actual service call events in database.""" user_id = str(uuid.uuid4()) + hass.states.async_set("light.living_room", "off") + hass.states.async_set("light.kitchen", "off") + hass.states.async_set("climate.thermostat", "off") + hass.states.async_set("light.bedroom", "off") + hass.states.async_set("lock.front_door", "locked") + # Create service call events at different times of day # Morning events - use separate service calls to get around context deduplication - with freeze_time("2023-07-01 07:00:00+00:00"): # Morning + with freeze_time("2023-07-01 07:00:00"): # Morning hass.bus.async_fire( EVENT_CALL_SERVICE, { @@ -77,7 +83,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None: await hass.async_block_till_done() # Afternoon events - with freeze_time("2023-07-01 14:00:00+00:00"): # Afternoon + with freeze_time("2023-07-01 14:00:00"): # Afternoon hass.bus.async_fire( EVENT_CALL_SERVICE, { @@ -90,7 +96,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None: await hass.async_block_till_done() # Evening events - with freeze_time("2023-07-01 19:00:00+00:00"): # Evening + with freeze_time("2023-07-01 19:00:00"): # Evening hass.bus.async_fire( EVENT_CALL_SERVICE, { @@ -103,7 +109,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None: await hass.async_block_till_done() # Night events - with freeze_time("2023-07-01 23:00:00+00:00"): # Night + with freeze_time("2023-07-01 23:00:00"): # Night hass.bus.async_fire( EVENT_CALL_SERVICE, { @@ -119,7 +125,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None: await async_wait_recording_done(hass) # Get predictions - make sure we're still in a reasonable timeframe - with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent results = await async_predict_common_control(hass, user_id) # Verify results contain the expected entities in the correct time periods @@ -151,7 +157,12 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None: suggested_object_id="kitchen", ) - with freeze_time("2023-07-01 10:00:00+00:00"): # Morning + hass.states.async_set("light.living_room", "off") + hass.states.async_set("light.kitchen", "off") + hass.states.async_set("light.hallway", "off") + hass.states.async_set("not_allowed.domain", "off") + + with freeze_time("2023-07-01 10:00:00"): # Morning hass.bus.async_fire( EVENT_CALL_SERVICE, { @@ -163,6 +174,7 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None: "light.kitchen", "light.hallway", "not_allowed.domain", + "light.not_in_state_machine", ] }, }, @@ -172,7 +184,7 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None: await async_wait_recording_done(hass) - with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent results = await async_predict_common_control(hass, user_id) # Two lights should be counted (10:00 UTC = 02:00 local = night) @@ -189,7 +201,10 @@ async def test_context_deduplication(hass: HomeAssistant) -> None: user_id = str(uuid.uuid4()) context = Context(user_id=user_id) - with freeze_time("2023-07-01 10:00:00+00:00"): # Morning + hass.states.async_set("light.living_room", "off") + hass.states.async_set("switch.coffee_maker", "off") + + with freeze_time("2023-07-01 10:00:00"): # Morning # Fire multiple events with the same context hass.bus.async_fire( EVENT_CALL_SERVICE, @@ -215,7 +230,7 @@ async def test_context_deduplication(hass: HomeAssistant) -> None: await async_wait_recording_done(hass) - with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent results = await async_predict_common_control(hass, user_id) # Only the first event should be processed (10:00 UTC = 02:00 local = night) @@ -232,8 +247,11 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None: """Test that events older than 30 days are excluded.""" user_id = str(uuid.uuid4()) + hass.states.async_set("light.old_event", "off") + hass.states.async_set("light.recent_event", "off") + # Create an old event (35 days ago) - with freeze_time("2023-05-27 10:00:00+00:00"): # 35 days before July 1st + with freeze_time("2023-05-27 10:00:00"): # 35 days before July 1st hass.bus.async_fire( EVENT_CALL_SERVICE, { @@ -246,7 +264,7 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None: await hass.async_block_till_done() # Create a recent event (5 days ago) - with freeze_time("2023-06-26 10:00:00+00:00"): # 5 days before July 1st + with freeze_time("2023-06-26 10:00:00"): # 5 days before July 1st hass.bus.async_fire( EVENT_CALL_SERVICE, { @@ -261,7 +279,7 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None: await async_wait_recording_done(hass) # Query with current time - with freeze_time("2023-07-01 10:00:00+00:00"): + with freeze_time("2023-07-01 10:00:00"): results = await async_predict_common_control(hass, user_id) # Only recent event should be included (10:00 UTC = 02:00 local = night) @@ -278,8 +296,16 @@ async def test_entities_limit(hass: HomeAssistant) -> None: """Test that only top entities are returned per time category.""" user_id = str(uuid.uuid4()) + hass.states.async_set("light.most_used", "off") + hass.states.async_set("light.second", "off") + hass.states.async_set("light.third", "off") + hass.states.async_set("light.fourth", "off") + hass.states.async_set("light.fifth", "off") + hass.states.async_set("light.sixth", "off") + hass.states.async_set("light.seventh", "off") + # Create more than 5 different entities in morning - with freeze_time("2023-07-01 08:00:00+00:00"): + with freeze_time("2023-07-01 08:00:00"): # Create entities with different frequencies entities_with_counts = [ ("light.most_used", 10), @@ -308,7 +334,7 @@ async def test_entities_limit(hass: HomeAssistant) -> None: await async_wait_recording_done(hass) with ( - freeze_time("2023-07-02 10:00:00+00:00"), + freeze_time("2023-07-02 10:00:00"), patch( "homeassistant.components.usage_prediction.common_control.RESULTS_TO_INCLUDE", 5, @@ -335,7 +361,10 @@ async def test_different_users_separated(hass: HomeAssistant) -> None: user_id_1 = str(uuid.uuid4()) user_id_2 = str(uuid.uuid4()) - with freeze_time("2023-07-01 10:00:00+00:00"): + hass.states.async_set("light.user1_light", "off") + hass.states.async_set("light.user2_light", "off") + + with freeze_time("2023-07-01 10:00:00"): # User 1 events hass.bus.async_fire( EVENT_CALL_SERVICE, @@ -363,7 +392,7 @@ async def test_different_users_separated(hass: HomeAssistant) -> None: await async_wait_recording_done(hass) # Get results for each user - with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent + with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent results_user1 = await async_predict_common_control(hass, user_id_1) results_user2 = await async_predict_common_control(hass, user_id_2)