mirror of
https://github.com/home-assistant/core.git
synced 2026-02-15 07:36:16 +00:00
Prevent common control calling async methods from thread (#152931)
Co-authored-by: J. Nick Koston <nick@home-assistant.io> Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
@@ -3,13 +3,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import Counter
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from functools import cache
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.engine.row import Row
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from homeassistant.components.recorder import get_instance
|
||||
@@ -90,61 +91,32 @@ async def async_predict_common_control(
|
||||
Args:
|
||||
hass: Home Assistant instance
|
||||
user_id: User ID to filter events by.
|
||||
|
||||
Returns:
|
||||
Dictionary with time categories as keys and lists of most common entity IDs as values
|
||||
"""
|
||||
# Get the recorder instance to ensure it's ready
|
||||
recorder = get_instance(hass)
|
||||
ent_reg = er.async_get(hass)
|
||||
|
||||
# Execute the database operation in the recorder's executor
|
||||
return await recorder.async_add_executor_job(
|
||||
data = await recorder.async_add_executor_job(
|
||||
_fetch_with_session, hass, _fetch_and_process_data, ent_reg, user_id
|
||||
)
|
||||
|
||||
|
||||
def _fetch_and_process_data(
|
||||
session: Session, ent_reg: er.EntityRegistry, user_id: str
|
||||
) -> EntityUsagePredictions:
|
||||
"""Fetch and process service call events from the database."""
|
||||
# Prepare a dictionary to track results
|
||||
results: dict[str, Counter[str]] = {
|
||||
time_cat: Counter() for time_cat in TIME_CATEGORIES
|
||||
}
|
||||
|
||||
allowed_entities = set(hass.states.async_entity_ids(ALLOWED_DOMAINS))
|
||||
hidden_entities: set[str] = set()
|
||||
|
||||
# Keep track of contexts that we processed so that we will only process
|
||||
# the first service call in a context, and not subsequent calls.
|
||||
context_processed: set[bytes] = set()
|
||||
thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()
|
||||
user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
|
||||
if not user_id_bytes:
|
||||
raise ValueError("Invalid user_id format")
|
||||
|
||||
# Build the main query for events with their data
|
||||
query = (
|
||||
select(
|
||||
Events.context_id_bin,
|
||||
Events.time_fired_ts,
|
||||
EventData.shared_data,
|
||||
)
|
||||
.select_from(Events)
|
||||
.outerjoin(EventData, Events.data_id == EventData.data_id)
|
||||
.outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id)
|
||||
.where(Events.time_fired_ts >= thirty_days_ago_ts)
|
||||
.where(Events.context_user_id_bin == user_id_bytes)
|
||||
.where(EventTypes.event_type == "call_service")
|
||||
.order_by(Events.time_fired_ts)
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
context_id: bytes
|
||||
time_fired_ts: float
|
||||
shared_data: str | None
|
||||
local_time_zone = dt_util.get_default_time_zone()
|
||||
for context_id, time_fired_ts, shared_data in (
|
||||
session.connection().execute(query).all()
|
||||
):
|
||||
for context_id, time_fired_ts, shared_data in data:
|
||||
# Skip if we have already processed an event that was part of this context
|
||||
if context_id in context_processed:
|
||||
continue
|
||||
@@ -153,7 +125,7 @@ def _fetch_and_process_data(
|
||||
context_processed.add(context_id)
|
||||
|
||||
# Parse the event data
|
||||
if not shared_data:
|
||||
if not time_fired_ts or not shared_data:
|
||||
continue
|
||||
|
||||
try:
|
||||
@@ -187,27 +159,26 @@ def _fetch_and_process_data(
|
||||
if not isinstance(entity_ids, list):
|
||||
entity_ids = [entity_ids]
|
||||
|
||||
# Filter out entity IDs that are not in allowed domains
|
||||
entity_ids = [
|
||||
entity_id
|
||||
for entity_id in entity_ids
|
||||
if entity_id.split(".")[0] in ALLOWED_DOMAINS
|
||||
and ((entry := ent_reg.async_get(entity_id)) is None or not entry.hidden)
|
||||
]
|
||||
# Convert to local time for time category determination
|
||||
period = time_category(
|
||||
datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
|
||||
)
|
||||
period_results = results[period]
|
||||
|
||||
if not entity_ids:
|
||||
continue
|
||||
# Count entity usage
|
||||
for entity_id in entity_ids:
|
||||
if entity_id not in allowed_entities or entity_id in hidden_entities:
|
||||
continue
|
||||
|
||||
# Convert timestamp to datetime and determine time category
|
||||
if time_fired_ts:
|
||||
# Convert to local time for time category determination
|
||||
period = time_category(
|
||||
datetime.fromtimestamp(time_fired_ts, local_time_zone).hour
|
||||
)
|
||||
if (
|
||||
entity_id not in period_results
|
||||
and (entry := ent_reg.async_get(entity_id))
|
||||
and entry.hidden
|
||||
):
|
||||
hidden_entities.add(entity_id)
|
||||
continue
|
||||
|
||||
# Count entity usage
|
||||
for entity_id in entity_ids:
|
||||
results[period][entity_id] += 1
|
||||
period_results[entity_id] += 1
|
||||
|
||||
return EntityUsagePredictions(
|
||||
morning=[
|
||||
@@ -226,11 +197,40 @@ def _fetch_and_process_data(
|
||||
)
|
||||
|
||||
|
||||
def _fetch_and_process_data(
|
||||
session: Session, ent_reg: er.EntityRegistry, user_id: str
|
||||
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
|
||||
"""Fetch and process service call events from the database."""
|
||||
thirty_days_ago_ts = (dt_util.utcnow() - timedelta(days=30)).timestamp()
|
||||
user_id_bytes = uuid_hex_to_bytes_or_none(user_id)
|
||||
if not user_id_bytes:
|
||||
raise ValueError("Invalid user_id format")
|
||||
|
||||
# Build the main query for events with their data
|
||||
query = (
|
||||
select(
|
||||
Events.context_id_bin,
|
||||
Events.time_fired_ts,
|
||||
EventData.shared_data,
|
||||
)
|
||||
.select_from(Events)
|
||||
.outerjoin(EventData, Events.data_id == EventData.data_id)
|
||||
.outerjoin(EventTypes, Events.event_type_id == EventTypes.event_type_id)
|
||||
.where(Events.time_fired_ts >= thirty_days_ago_ts)
|
||||
.where(Events.context_user_id_bin == user_id_bytes)
|
||||
.where(EventTypes.event_type == "call_service")
|
||||
.order_by(Events.time_fired_ts)
|
||||
)
|
||||
return session.connection().execute(query).all()
|
||||
|
||||
|
||||
def _fetch_with_session(
|
||||
hass: HomeAssistant,
|
||||
fetch_func: Callable[[Session], EntityUsagePredictions],
|
||||
fetch_func: Callable[
|
||||
[Session], Sequence[Row[tuple[bytes | None, float | None, str | None]]]
|
||||
],
|
||||
*args: object,
|
||||
) -> EntityUsagePredictions:
|
||||
) -> Sequence[Row[tuple[bytes | None, float | None, str | None]]]:
|
||||
"""Execute a fetch function with a database session."""
|
||||
with session_scope(hass=hass, read_only=True) as session:
|
||||
return fetch_func(session, *args)
|
||||
|
||||
@@ -62,9 +62,15 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
|
||||
"""Test function with actual service call events in database."""
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
hass.states.async_set("light.living_room", "off")
|
||||
hass.states.async_set("light.kitchen", "off")
|
||||
hass.states.async_set("climate.thermostat", "off")
|
||||
hass.states.async_set("light.bedroom", "off")
|
||||
hass.states.async_set("lock.front_door", "locked")
|
||||
|
||||
# Create service call events at different times of day
|
||||
# Morning events - use separate service calls to get around context deduplication
|
||||
with freeze_time("2023-07-01 07:00:00+00:00"): # Morning
|
||||
with freeze_time("2023-07-01 07:00:00"): # Morning
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
{
|
||||
@@ -77,7 +83,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Afternoon events
|
||||
with freeze_time("2023-07-01 14:00:00+00:00"): # Afternoon
|
||||
with freeze_time("2023-07-01 14:00:00"): # Afternoon
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
{
|
||||
@@ -90,7 +96,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Evening events
|
||||
with freeze_time("2023-07-01 19:00:00+00:00"): # Evening
|
||||
with freeze_time("2023-07-01 19:00:00"): # Evening
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
{
|
||||
@@ -103,7 +109,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Night events
|
||||
with freeze_time("2023-07-01 23:00:00+00:00"): # Night
|
||||
with freeze_time("2023-07-01 23:00:00"): # Night
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
{
|
||||
@@ -119,7 +125,7 @@ async def test_with_service_calls(hass: HomeAssistant) -> None:
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
# Get predictions - make sure we're still in a reasonable timeframe
|
||||
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
|
||||
results = await async_predict_common_control(hass, user_id)
|
||||
|
||||
# Verify results contain the expected entities in the correct time periods
|
||||
@@ -151,7 +157,12 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
|
||||
suggested_object_id="kitchen",
|
||||
)
|
||||
|
||||
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
|
||||
hass.states.async_set("light.living_room", "off")
|
||||
hass.states.async_set("light.kitchen", "off")
|
||||
hass.states.async_set("light.hallway", "off")
|
||||
hass.states.async_set("not_allowed.domain", "off")
|
||||
|
||||
with freeze_time("2023-07-01 10:00:00"): # Morning
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
{
|
||||
@@ -163,6 +174,7 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
|
||||
"light.kitchen",
|
||||
"light.hallway",
|
||||
"not_allowed.domain",
|
||||
"light.not_in_state_machine",
|
||||
]
|
||||
},
|
||||
},
|
||||
@@ -172,7 +184,7 @@ async def test_multiple_entities_in_one_call(hass: HomeAssistant) -> None:
|
||||
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
|
||||
results = await async_predict_common_control(hass, user_id)
|
||||
|
||||
# Two lights should be counted (10:00 UTC = 02:00 local = night)
|
||||
@@ -189,7 +201,10 @@ async def test_context_deduplication(hass: HomeAssistant) -> None:
|
||||
user_id = str(uuid.uuid4())
|
||||
context = Context(user_id=user_id)
|
||||
|
||||
with freeze_time("2023-07-01 10:00:00+00:00"): # Morning
|
||||
hass.states.async_set("light.living_room", "off")
|
||||
hass.states.async_set("switch.coffee_maker", "off")
|
||||
|
||||
with freeze_time("2023-07-01 10:00:00"): # Morning
|
||||
# Fire multiple events with the same context
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
@@ -215,7 +230,7 @@ async def test_context_deduplication(hass: HomeAssistant) -> None:
|
||||
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
|
||||
results = await async_predict_common_control(hass, user_id)
|
||||
|
||||
# Only the first event should be processed (10:00 UTC = 02:00 local = night)
|
||||
@@ -232,8 +247,11 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None:
|
||||
"""Test that events older than 30 days are excluded."""
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
hass.states.async_set("light.old_event", "off")
|
||||
hass.states.async_set("light.recent_event", "off")
|
||||
|
||||
# Create an old event (35 days ago)
|
||||
with freeze_time("2023-05-27 10:00:00+00:00"): # 35 days before July 1st
|
||||
with freeze_time("2023-05-27 10:00:00"): # 35 days before July 1st
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
{
|
||||
@@ -246,7 +264,7 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None:
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Create a recent event (5 days ago)
|
||||
with freeze_time("2023-06-26 10:00:00+00:00"): # 5 days before July 1st
|
||||
with freeze_time("2023-06-26 10:00:00"): # 5 days before July 1st
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
{
|
||||
@@ -261,7 +279,7 @@ async def test_old_events_excluded(hass: HomeAssistant) -> None:
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
# Query with current time
|
||||
with freeze_time("2023-07-01 10:00:00+00:00"):
|
||||
with freeze_time("2023-07-01 10:00:00"):
|
||||
results = await async_predict_common_control(hass, user_id)
|
||||
|
||||
# Only recent event should be included (10:00 UTC = 02:00 local = night)
|
||||
@@ -278,8 +296,16 @@ async def test_entities_limit(hass: HomeAssistant) -> None:
|
||||
"""Test that only top entities are returned per time category."""
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
hass.states.async_set("light.most_used", "off")
|
||||
hass.states.async_set("light.second", "off")
|
||||
hass.states.async_set("light.third", "off")
|
||||
hass.states.async_set("light.fourth", "off")
|
||||
hass.states.async_set("light.fifth", "off")
|
||||
hass.states.async_set("light.sixth", "off")
|
||||
hass.states.async_set("light.seventh", "off")
|
||||
|
||||
# Create more than 5 different entities in morning
|
||||
with freeze_time("2023-07-01 08:00:00+00:00"):
|
||||
with freeze_time("2023-07-01 08:00:00"):
|
||||
# Create entities with different frequencies
|
||||
entities_with_counts = [
|
||||
("light.most_used", 10),
|
||||
@@ -308,7 +334,7 @@ async def test_entities_limit(hass: HomeAssistant) -> None:
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
with (
|
||||
freeze_time("2023-07-02 10:00:00+00:00"),
|
||||
freeze_time("2023-07-02 10:00:00"),
|
||||
patch(
|
||||
"homeassistant.components.usage_prediction.common_control.RESULTS_TO_INCLUDE",
|
||||
5,
|
||||
@@ -335,7 +361,10 @@ async def test_different_users_separated(hass: HomeAssistant) -> None:
|
||||
user_id_1 = str(uuid.uuid4())
|
||||
user_id_2 = str(uuid.uuid4())
|
||||
|
||||
with freeze_time("2023-07-01 10:00:00+00:00"):
|
||||
hass.states.async_set("light.user1_light", "off")
|
||||
hass.states.async_set("light.user2_light", "off")
|
||||
|
||||
with freeze_time("2023-07-01 10:00:00"):
|
||||
# User 1 events
|
||||
hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE,
|
||||
@@ -363,7 +392,7 @@ async def test_different_users_separated(hass: HomeAssistant) -> None:
|
||||
await async_wait_recording_done(hass)
|
||||
|
||||
# Get results for each user
|
||||
with freeze_time("2023-07-02 10:00:00+00:00"): # Next day, so events are recent
|
||||
with freeze_time("2023-07-02 10:00:00"): # Next day, so events are recent
|
||||
results_user1 = await async_predict_common_control(hass, user_id_1)
|
||||
results_user2 = await async_predict_common_control(hass, user_id_2)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user