1
0
mirror of https://github.com/home-assistant/core.git synced 2026-02-15 07:36:16 +00:00

Prevent common control calling async methods from thread (#152931)

Co-authored-by: J. Nick Koston <nick@home-assistant.io>
Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Paulus Schoutsen
2025-09-24 23:09:54 -04:00
committed by GitHub
parent 7c8ad9d535
commit 91e13d447a
2 changed files with 101 additions and 72 deletions

View File

@@ -3,13 +3,14 @@
from __future__ import annotations
from collections import Counter
from collections.abc import Callable
from collections.abc import Callable, Sequence
from datetime import datetime, timedelta
from functools import cache
import logging
from typing import Any, Literal, cast
from sqlalchemy import select
from sqlalchemy.engine.row import Row
from sqlalchemy.orm import Session
from homeassistant.components.recorder import get_instance
@@ -90,61 +91,32 @@ async def async_predict_common_control(
Args:
hass: Home Assistant instance
user_id: User ID to filter events by.
Returns:
Dictionary with time categories as keys and lists of most common entity IDs as values
"""
# Get the recorder instance to ensure it's ready
recorder = get_instance(hass)
ent_reg = er.async_get(hass)
# Execute the database operation in the recorder's executor
return await recorder.async_add_executor_job(
data = await recorder.async_add_executor_job(
_fetch_with_session, hass, _fetch_and_process_data, ent_reg, user_id
)
def _fetch_and_process_data(
session: Session, ent_reg: er.EntityRegistry, user_id: str
) -> EntityUsagePredictions:
"""Fetch and process service call events from the database."""
# Prepare a dictionary to track results
results: dict[str, Counter[str]] = {
time_cat: Counter() for time_cat in TIME_CATEGORIES
}
allowed_entities = set(hass.states.async_entity_ids(ALLOWED_DOMAINS))
hidden_entities: set[str] = set()
# Keep track of contexts that we processed so that we will only process
# the first service call in a context, and not subsequent calls.
context_processed: set[bytes] = set()
thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()
user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
if not user_id_bytes:
raise ValueError("Invalid user_id format")
# Build the main query for events with their data
query = (
select(
Events.context_id_bin,
Events.time_fired_ts,
EventData.shared_data,
)
.select_from(Events)
.outerjoin(EventData, Events.data_id == EventData.data_id)
.outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id)
.where(Events.time_fired_ts >= thirty_days_ago_ts)
.where(Events.context_user_id_bin == user_id_bytes)
.where(EventTypes.event_type == "call_service")
.order_by(Events.time_fired_ts)
)
# Execute the query
context_id: bytes
time_fired_ts: float
shared_data: str | None
local_time_zone = dt_util.get_default_time_zone()
for context_id, time_fired_ts, shared_data in (
session.connection().execute(query).all()
):
for context_id, time_fired_ts, shared_data in data:
# Skip if we have already processed an event that was part of this context
if context_id in context_processed:
continue
@@ -153,7 +125,7 @@ def _fetch_and_process_data(
context_processed.add(context_id)
# Parse the event data
if not shared_data:
if not time_fired_ts or not shared_data:
continue
try:
@@ -187,27 +159,26 @@ def _fetch_and_process_data(
if not isinstance(entity_ids, list):
entity_ids = [entity_ids]
# Filter out entity IDs that are not in allowed domains
entity_ids = [
entity_id
for entity_id in entity_ids
if entity_id.split(".")[0] in ALLOWED_DOMAINS
and ((entry := ent_reg.async_get(entity_id)) is None or not entry.hidden)
]
# Convert to local time for time category determination
period = time_category(
datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
)
period_results = results[period]
if not entity_ids:
continue
# Count entity usage
for entity_id in entity_ids:
if entity_id not in allowed_entities or entity_id in hidden_entities:
continue
# Convert timestamp to datetime and determine time category
if time_fired_ts:
# Convert to local time for time category determination
period = time_category(
datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
)
if (
entity_id not in period_results
and (entry := ent_reg.async_get(entity_id))
and entry.hidden
):
hidden_entities.add(entity_id)
continue
# Count entity usage
for entity_id in entity_ids:
results[period][entity_id] += 1
period_results[entity_id] += 1
return EntityUsagePredictions(
morning=[
@@ -226,11 +197,40 @@ def _fetch_and_process_data(
)
def _fetch_and_process_data(
    session: Session, ent_reg: er.EntityRegistry, user_id: str
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
    """Fetch a user's recent service call event rows from the database.

    Selects, for every ``call_service`` event fired within the last 30 days
    whose context user matches ``user_id``, the binary context ID, the fire
    timestamp and the shared event data, ordered by fire timestamp.

    Args:
        session: Database session the query is executed on.
        ent_reg: Entity registry; accepted but not used by this function.
        user_id: User ID to filter events by.

    Returns:
        All matching rows as ``(context_id_bin, time_fired_ts, shared_data)``.

    Raises:
        ValueError: If ``user_id`` does not convert to a non-empty bytes value.
    """
    window_start_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()

    context_user_bin = uuid_hex_to_bytes_or_none(user_id)
    if not context_user_bin:
        raise ValueError("Invalid user_id format")

    # Events joined to their data and type rows; the data join is an outer
    # join, so shared_data may come back as None.
    events_with_data = (
        select(
            Events.context_id_bin,
            Events.time_fired_ts,
            EventData.shared_data,
        )
        .select_from(Events)
        .outerjoin(EventData, Events.data_id == EventData.data_id)
        .outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id)
    )

    # Restrict to this user's service calls inside the time window.
    user_service_calls = (
        events_with_data.where(Events.time_fired_ts >= window_start_ts)
        .where(Events.context_user_id_bin == context_user_bin)
        .where(EventTypes.event_type == "call_service")
    )
    stmt = user_service_calls.order_by(Events.time_fired_ts)

    rows = session.connection().execute(stmt).all()
    return rows
def _fetch_with_session(
hass: HomeAssistant,
fetch_func: Callable[[Session], EntityUsagePredictions],
fetch_func: Callable[
[Session], Sequence[Row[tuple[bytes | None, float | None, str | None]]]
],
*args: object,
) -> EntityUsagePredictions:
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
"""Execute a fetch function with a database session.

Opens a read-only session via ``session_scope`` and returns whatever
``fetch_func(session, *args)`` returns.
"""
with session_scope(hass=hass, read_only=True) as session:
# Call fetch_func while the session context is still open.
return fetch_func(session, *args)

View File

@@ -62,9 +62,15 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
"""Test function with actual service call events in database."""
user_id = str(uuid.uuid4())
hass.states.async_set("light.living_room", "off")
hass.states.async_set("light.kitchen", "off")
hass.states.async_set("climate.thermostat", "off")
hass.states.async_set("light.bedroom", "off")
hass.states.async_set("lock.front_door", "locked")
# Create service call events at different times of day
# Morning events - use separate service calls to get around context deduplication
with freeze_time("2023-07-01 07:00:00+00:00"): # Morning
with freeze_time("2023-07-01 07:00:00"): # Morning
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
@@ -77,7 +83,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
# Afternoon events
with freeze_time("2023-07-01 14:00:00+00:00"): # Afternoon
with freeze_time("2023-07-01 14:00:00"): # Afternoon
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
@@ -90,7 +96,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
# Evening events
with freeze_time("2023-07-01 19:00:00+00:00"): # Evening
with freeze_time("2023-07-01 19:00:00"): # Evening
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
@@ -103,7 +109,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
# Night events
with freeze_time("2023-07-01 23:00:00+00:00"): # Night
with freeze_time("2023-07-01 23:00:00"): # Night
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
@@ -119,7 +125,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
await async_wait_recording_done(hass)
# Get predictions - make sure we're still in a reasonable timeframe
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
results = await async_predict_common_control(hass, user_id)
# Verify results contain the expected entities in the correct time periods
@@ -151,7 +157,12 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
suggested_object_id="kitchen",
)
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
hass.states.async_set("light.living_room", "off")
hass.states.async_set("light.kitchen", "off")
hass.states.async_set("light.hallway", "off")
hass.states.async_set("not_allowed.domain", "off")
with freeze_time("2023-07-01 10:00:00"): # Morning
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
@@ -163,6 +174,7 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
"light.kitchen",
"light.hallway",
"not_allowed.domain",
"light.not_in_state_machine",
]
},
},
@@ -172,7 +184,7 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
await async_wait_recording_done(hass)
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
results = await async_predict_common_control(hass, user_id)
# Two lights should be counted (10:00 UTC = 02:00 local = night)
@@ -189,7 +201,10 @@ async def test_context_deduplication(hass: HomeAssistant) -> None:
user_id = str(uuid.uuid4())
context = Context(user_id=user_id)
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
hass.states.async_set("light.living_room", "off")
hass.states.async_set("switch.coffee_maker", "off")
with freeze_time("2023-07-01 10:00:00"): # Morning
# Fire multiple events with the same context
hass.bus.async_fire(
EVENT_CALL_SERVICE,
@@ -215,7 +230,7 @@ async def test_context_deduplication(hass: HomeAssistant) -> None:
await async_wait_recording_done(hass)
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
results = await async_predict_common_control(hass, user_id)
# Only the first event should be processed (10:00 UTC = 02:00 local = night)
@@ -232,8 +247,11 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None:
"""Test that events older than 30 days are excluded."""
user_id = str(uuid.uuid4())
hass.states.async_set("light.old_event", "off")
hass.states.async_set("light.recent_event", "off")
# Create an old event (35 days ago)
with freeze_time("2023-05-27 10:00:00+00:00"): # 35 days before July 1st
with freeze_time("2023-05-27 10:00:00"): # 35 days before July 1st
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
@@ -246,7 +264,7 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None:
await hass.async_block_till_done()
# Create a recent event (5 days ago)
with freeze_time("2023-06-26 10:00:00+00:00"): # 5 days before July 1st
with freeze_time("2023-06-26 10:00:00"): # 5 days before July 1st
hass.bus.async_fire(
EVENT_CALL_SERVICE,
{
@@ -261,7 +279,7 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None:
await async_wait_recording_done(hass)
# Query with current time
with freeze_time("2023-07-01 10:00:00+00:00"):
with freeze_time("2023-07-01 10:00:00"):
results = await async_predict_common_control(hass, user_id)
# Only recent event should be included (10:00 UTC = 02:00 local = night)
@@ -278,8 +296,16 @@ async def test_entities_limit(hass: HomeAssistant) -> None:
"""Test that only top entities are returned per time category."""
user_id = str(uuid.uuid4())
hass.states.async_set("light.most_used", "off")
hass.states.async_set("light.second", "off")
hass.states.async_set("light.third", "off")
hass.states.async_set("light.fourth", "off")
hass.states.async_set("light.fifth", "off")
hass.states.async_set("light.sixth", "off")
hass.states.async_set("light.seventh", "off")
# Create more than 5 different entities in morning
with freeze_time("2023-07-01 08:00:00+00:00"):
with freeze_time("2023-07-01 08:00:00"):
# Create entities with different frequencies
entities_with_counts = [
("light.most_used", 10),
@@ -308,7 +334,7 @@ async def test_entities_limit(hass: HomeAssistant) -> None:
await async_wait_recording_done(hass)
with (
freeze_time("2023-07-02 10:00:00+00:00"),
freeze_time("2023-07-02 10:00:00"),
patch(
"homeassistant.components.usage_prediction.common_control.RESULTS_TO_INCLUDE",
5,
@@ -335,7 +361,10 @@ async def test_different_users_separated(hass: HomeAssistant) -> None:
user_id_1 = str(uuid.uuid4())
user_id_2 = str(uuid.uuid4())
with freeze_time("2023-07-01 10:00:00+00:00"):
hass.states.async_set("light.user1_light", "off")
hass.states.async_set("light.user2_light", "off")
with freeze_time("2023-07-01 10:00:00"):
# User 1 events
hass.bus.async_fire(
EVENT_CALL_SERVICE,
@@ -363,7 +392,7 @@ async def test_different_users_separated(hass: HomeAssistant) -> None:
await async_wait_recording_done(hass)
# Get results for each user
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
results_user1 = await async_predict_common_control(hass, user_id_1)
results_user2 = await async_predict_common_control(hass, user_id_2)