mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 21:06:19 +00:00
Context (#15674)
* Add context * Add context to switch/light services * Test set_state API * Lint * Fix tests * Do not include context yet in comparison * Do not pass in loop * Fix Z-Wave tests * Add websocket test without user
This commit is contained in:
committed by
Jason Hu
parent
867f80715e
commit
c7f4bdafc0
@@ -15,6 +15,7 @@ import re
|
||||
import sys
|
||||
import threading
|
||||
from time import monotonic
|
||||
import uuid
|
||||
|
||||
from types import MappingProxyType
|
||||
# pylint: disable=unused-import
|
||||
@@ -23,12 +24,13 @@ from typing import ( # NOQA
|
||||
TYPE_CHECKING, Awaitable, Iterator)
|
||||
|
||||
from async_timeout import timeout
|
||||
import attr
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.const import (
|
||||
ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE,
|
||||
ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
||||
ATTR_SERVICE_DATA, EVENT_CALL_SERVICE,
|
||||
EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP,
|
||||
EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED,
|
||||
EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE,
|
||||
@@ -191,7 +193,7 @@ class HomeAssistant:
|
||||
try:
|
||||
# Only block for EVENT_HOMEASSISTANT_START listener
|
||||
self.async_stop_track_tasks()
|
||||
with timeout(TIMEOUT_EVENT_START, loop=self.loop):
|
||||
with timeout(TIMEOUT_EVENT_START):
|
||||
await self.async_block_till_done()
|
||||
except asyncio.TimeoutError:
|
||||
_LOGGER.warning(
|
||||
@@ -201,7 +203,7 @@ class HomeAssistant:
|
||||
', '.join(self.config.components))
|
||||
|
||||
# Allow automations to set up the start triggers before changing state
|
||||
await asyncio.sleep(0, loop=self.loop)
|
||||
await asyncio.sleep(0)
|
||||
self.state = CoreState.running
|
||||
_async_create_timer(self)
|
||||
|
||||
@@ -307,16 +309,16 @@ class HomeAssistant:
|
||||
async def async_block_till_done(self) -> None:
|
||||
"""Block till all pending work is done."""
|
||||
# To flush out any call_soon_threadsafe
|
||||
await asyncio.sleep(0, loop=self.loop)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
while self._pending_tasks:
|
||||
pending = [task for task in self._pending_tasks
|
||||
if not task.done()]
|
||||
self._pending_tasks.clear()
|
||||
if pending:
|
||||
await asyncio.wait(pending, loop=self.loop)
|
||||
await asyncio.wait(pending)
|
||||
else:
|
||||
await asyncio.sleep(0, loop=self.loop)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop Home Assistant and shuts down all threads."""
|
||||
@@ -343,6 +345,27 @@ class HomeAssistant:
|
||||
self.loop.stop()
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class Context:
|
||||
"""The context that triggered something."""
|
||||
|
||||
user_id = attr.ib(
|
||||
type=str,
|
||||
default=None,
|
||||
)
|
||||
id = attr.ib(
|
||||
type=str,
|
||||
default=attr.Factory(lambda: uuid.uuid4().hex),
|
||||
)
|
||||
|
||||
def as_dict(self) -> dict:
|
||||
"""Return a dictionary representation of the context."""
|
||||
return {
|
||||
'id': self.id,
|
||||
'user_id': self.user_id,
|
||||
}
|
||||
|
||||
|
||||
class EventOrigin(enum.Enum):
|
||||
"""Represent the origin of an event."""
|
||||
|
||||
@@ -357,16 +380,18 @@ class EventOrigin(enum.Enum):
|
||||
class Event:
|
||||
"""Representation of an event within the bus."""
|
||||
|
||||
__slots__ = ['event_type', 'data', 'origin', 'time_fired']
|
||||
__slots__ = ['event_type', 'data', 'origin', 'time_fired', 'context']
|
||||
|
||||
def __init__(self, event_type: str, data: Optional[Dict] = None,
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
time_fired: Optional[int] = None) -> None:
|
||||
time_fired: Optional[int] = None,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Initialize a new event."""
|
||||
self.event_type = event_type
|
||||
self.data = data or {}
|
||||
self.origin = origin
|
||||
self.time_fired = time_fired or dt_util.utcnow()
|
||||
self.context = context or Context()
|
||||
|
||||
def as_dict(self) -> Dict:
|
||||
"""Create a dict representation of this Event.
|
||||
@@ -378,6 +403,7 @@ class Event:
|
||||
'data': dict(self.data),
|
||||
'origin': str(self.origin),
|
||||
'time_fired': self.time_fired,
|
||||
'context': self.context.as_dict()
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -425,14 +451,16 @@ class EventBus:
|
||||
).result()
|
||||
|
||||
def fire(self, event_type: str, event_data: Optional[Dict] = None,
|
||||
origin: EventOrigin = EventOrigin.local) -> None:
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Fire an event."""
|
||||
self._hass.loop.call_soon_threadsafe(
|
||||
self.async_fire, event_type, event_data, origin)
|
||||
self.async_fire, event_type, event_data, origin, context)
|
||||
|
||||
@callback
|
||||
def async_fire(self, event_type: str, event_data: Optional[Dict] = None,
|
||||
origin: EventOrigin = EventOrigin.local) -> None:
|
||||
origin: EventOrigin = EventOrigin.local,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Fire an event.
|
||||
|
||||
This method must be run in the event loop.
|
||||
@@ -445,7 +473,7 @@ class EventBus:
|
||||
event_type != EVENT_HOMEASSISTANT_CLOSE):
|
||||
listeners = match_all_listeners + listeners
|
||||
|
||||
event = Event(event_type, event_data, origin)
|
||||
event = Event(event_type, event_data, origin, None, context)
|
||||
|
||||
if event_type != EVENT_TIME_CHANGED:
|
||||
_LOGGER.info("Bus:Handling %s", event)
|
||||
@@ -569,15 +597,17 @@ class State:
|
||||
attributes: extra information on entity and state
|
||||
last_changed: last time the state was changed, not the attributes.
|
||||
last_updated: last time this object was updated.
|
||||
context: Context in which it was created
|
||||
"""
|
||||
|
||||
__slots__ = ['entity_id', 'state', 'attributes',
|
||||
'last_changed', 'last_updated']
|
||||
'last_changed', 'last_updated', 'context']
|
||||
|
||||
def __init__(self, entity_id: str, state: Any,
|
||||
attributes: Optional[Dict] = None,
|
||||
last_changed: Optional[datetime.datetime] = None,
|
||||
last_updated: Optional[datetime.datetime] = None) -> None:
|
||||
last_updated: Optional[datetime.datetime] = None,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Initialize a new state."""
|
||||
state = str(state)
|
||||
|
||||
@@ -596,6 +626,7 @@ class State:
|
||||
self.attributes = MappingProxyType(attributes or {})
|
||||
self.last_updated = last_updated or dt_util.utcnow()
|
||||
self.last_changed = last_changed or self.last_updated
|
||||
self.context = context or Context()
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
@@ -626,7 +657,8 @@ class State:
|
||||
'state': self.state,
|
||||
'attributes': dict(self.attributes),
|
||||
'last_changed': self.last_changed,
|
||||
'last_updated': self.last_updated}
|
||||
'last_updated': self.last_updated,
|
||||
'context': self.context.as_dict()}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, json_dict: Dict) -> Any:
|
||||
@@ -650,8 +682,13 @@ class State:
|
||||
if isinstance(last_updated, str):
|
||||
last_updated = dt_util.parse_datetime(last_updated)
|
||||
|
||||
context = json_dict.get('context')
|
||||
if context:
|
||||
context = Context(**context)
|
||||
|
||||
return cls(json_dict['entity_id'], json_dict['state'],
|
||||
json_dict.get('attributes'), last_changed, last_updated)
|
||||
json_dict.get('attributes'), last_changed, last_updated,
|
||||
context)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Return the comparison of the state."""
|
||||
@@ -662,11 +699,11 @@ class State:
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation of the states."""
|
||||
attr = "; {}".format(util.repr_helper(self.attributes)) \
|
||||
if self.attributes else ""
|
||||
attrs = "; {}".format(util.repr_helper(self.attributes)) \
|
||||
if self.attributes else ""
|
||||
|
||||
return "<state {}={}{} @ {}>".format(
|
||||
self.entity_id, self.state, attr,
|
||||
self.entity_id, self.state, attrs,
|
||||
dt_util.as_local(self.last_changed).isoformat())
|
||||
|
||||
|
||||
@@ -761,7 +798,8 @@ class StateMachine:
|
||||
|
||||
def set(self, entity_id: str, new_state: Any,
|
||||
attributes: Optional[Dict] = None,
|
||||
force_update: bool = False) -> None:
|
||||
force_update: bool = False,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Set the state of an entity, add entity if it does not exist.
|
||||
|
||||
Attributes is an optional dict to specify attributes of this state.
|
||||
@@ -772,12 +810,14 @@ class StateMachine:
|
||||
run_callback_threadsafe(
|
||||
self._loop,
|
||||
self.async_set, entity_id, new_state, attributes, force_update,
|
||||
context,
|
||||
).result()
|
||||
|
||||
@callback
|
||||
def async_set(self, entity_id: str, new_state: Any,
|
||||
attributes: Optional[Dict] = None,
|
||||
force_update: bool = False) -> None:
|
||||
force_update: bool = False,
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Set the state of an entity, add entity if it does not exist.
|
||||
|
||||
Attributes is an optional dict to specify attributes of this state.
|
||||
@@ -804,13 +844,17 @@ class StateMachine:
|
||||
if same_state and same_attr:
|
||||
return
|
||||
|
||||
state = State(entity_id, new_state, attributes, last_changed)
|
||||
if context is None:
|
||||
context = Context()
|
||||
|
||||
state = State(entity_id, new_state, attributes, last_changed, None,
|
||||
context)
|
||||
self._states[entity_id] = state
|
||||
self._bus.async_fire(EVENT_STATE_CHANGED, {
|
||||
'entity_id': entity_id,
|
||||
'old_state': old_state,
|
||||
'new_state': state,
|
||||
})
|
||||
}, EventOrigin.local, context)
|
||||
|
||||
|
||||
class Service:
|
||||
@@ -818,7 +862,8 @@ class Service:
|
||||
|
||||
__slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction']
|
||||
|
||||
def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None:
|
||||
def __init__(self, func: Callable, schema: Optional[vol.Schema],
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Initialize a service."""
|
||||
self.func = func
|
||||
self.schema = schema
|
||||
@@ -829,23 +874,25 @@ class Service:
|
||||
class ServiceCall:
|
||||
"""Representation of a call to a service."""
|
||||
|
||||
__slots__ = ['domain', 'service', 'data', 'call_id']
|
||||
__slots__ = ['domain', 'service', 'data', 'context']
|
||||
|
||||
def __init__(self, domain: str, service: str, data: Optional[Dict] = None,
|
||||
call_id: Optional[str] = None) -> None:
|
||||
context: Optional[Context] = None) -> None:
|
||||
"""Initialize a service call."""
|
||||
self.domain = domain.lower()
|
||||
self.service = service.lower()
|
||||
self.data = MappingProxyType(data or {})
|
||||
self.call_id = call_id
|
||||
self.context = context or Context()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return the representation of the service."""
|
||||
if self.data:
|
||||
return "<ServiceCall {}.{}: {}>".format(
|
||||
self.domain, self.service, util.repr_helper(self.data))
|
||||
return "<ServiceCall {}.{} (c:{}): {}>".format(
|
||||
self.domain, self.service, self.context.id,
|
||||
util.repr_helper(self.data))
|
||||
|
||||
return "<ServiceCall {}.{}>".format(self.domain, self.service)
|
||||
return "<ServiceCall {}.{} (c:{})>".format(
|
||||
self.domain, self.service, self.context.id)
|
||||
|
||||
|
||||
class ServiceRegistry:
|
||||
@@ -857,15 +904,6 @@ class ServiceRegistry:
|
||||
self._hass = hass
|
||||
self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE]
|
||||
|
||||
def _gen_unique_id() -> Iterator[str]:
|
||||
cur_id = 1
|
||||
while True:
|
||||
yield '{}-{}'.format(id(self), cur_id)
|
||||
cur_id += 1
|
||||
|
||||
gen = _gen_unique_id()
|
||||
self._generate_unique_id = lambda: next(gen)
|
||||
|
||||
@property
|
||||
def services(self) -> Dict[str, Dict[str, Service]]:
|
||||
"""Return dictionary with per domain a list of available services."""
|
||||
@@ -957,7 +995,8 @@ class ServiceRegistry:
|
||||
|
||||
def call(self, domain: str, service: str,
|
||||
service_data: Optional[Dict] = None,
|
||||
blocking: bool = False) -> Optional[bool]:
|
||||
blocking: bool = False,
|
||||
context: Optional[Context] = None) -> Optional[bool]:
|
||||
"""
|
||||
Call a service.
|
||||
|
||||
@@ -975,13 +1014,14 @@ class ServiceRegistry:
|
||||
the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data.
|
||||
"""
|
||||
return run_coroutine_threadsafe( # type: ignore
|
||||
self.async_call(domain, service, service_data, blocking),
|
||||
self.async_call(domain, service, service_data, blocking, context),
|
||||
self._hass.loop
|
||||
).result()
|
||||
|
||||
async def async_call(self, domain: str, service: str,
|
||||
service_data: Optional[Dict] = None,
|
||||
blocking: bool = False) -> Optional[bool]:
|
||||
blocking: bool = False,
|
||||
context: Optional[Context] = None) -> Optional[bool]:
|
||||
"""
|
||||
Call a service.
|
||||
|
||||
@@ -1000,44 +1040,42 @@ class ServiceRegistry:
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
call_id = self._generate_unique_id()
|
||||
|
||||
context = context or Context()
|
||||
event_data = {
|
||||
ATTR_DOMAIN: domain.lower(),
|
||||
ATTR_SERVICE: service.lower(),
|
||||
ATTR_SERVICE_DATA: service_data,
|
||||
ATTR_SERVICE_CALL_ID: call_id,
|
||||
}
|
||||
|
||||
if blocking:
|
||||
fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future
|
||||
if not blocking:
|
||||
self._hass.bus.async_fire(
|
||||
EVENT_CALL_SERVICE, event_data, EventOrigin.local, context)
|
||||
return None
|
||||
|
||||
@callback
|
||||
def service_executed(event: Event) -> None:
|
||||
"""Handle an executed service."""
|
||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
fut.set_result(True)
|
||||
fut = asyncio.Future() # type: asyncio.Future
|
||||
|
||||
unsub = self._hass.bus.async_listen(
|
||||
EVENT_SERVICE_EXECUTED, service_executed)
|
||||
@callback
|
||||
def service_executed(event: Event) -> None:
|
||||
"""Handle an executed service."""
|
||||
if event.context == context:
|
||||
fut.set_result(True)
|
||||
|
||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
|
||||
unsub = self._hass.bus.async_listen(
|
||||
EVENT_SERVICE_EXECUTED, service_executed)
|
||||
|
||||
done, _ = await asyncio.wait(
|
||||
[fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT)
|
||||
success = bool(done)
|
||||
unsub()
|
||||
return success
|
||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data,
|
||||
EventOrigin.local, context)
|
||||
|
||||
self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data)
|
||||
return None
|
||||
done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT)
|
||||
success = bool(done)
|
||||
unsub()
|
||||
return success
|
||||
|
||||
async def _event_to_service_call(self, event: Event) -> None:
|
||||
"""Handle the SERVICE_CALLED events from the EventBus."""
|
||||
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
||||
domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore
|
||||
service = event.data.get(ATTR_SERVICE).lower() # type: ignore
|
||||
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||
|
||||
if not self.has_service(domain, service):
|
||||
if event.origin == EventOrigin.local:
|
||||
@@ -1049,16 +1087,13 @@ class ServiceRegistry:
|
||||
|
||||
def fire_service_executed() -> None:
|
||||
"""Fire service executed event."""
|
||||
if not call_id:
|
||||
return
|
||||
|
||||
data = {ATTR_SERVICE_CALL_ID: call_id}
|
||||
|
||||
if (service_handler.is_coroutinefunction or
|
||||
service_handler.is_callback):
|
||||
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data)
|
||||
self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, {},
|
||||
EventOrigin.local, event.context)
|
||||
else:
|
||||
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data)
|
||||
self._hass.bus.fire(EVENT_SERVICE_EXECUTED, {},
|
||||
EventOrigin.local, event.context)
|
||||
|
||||
try:
|
||||
if service_handler.schema:
|
||||
@@ -1069,7 +1104,8 @@ class ServiceRegistry:
|
||||
fire_service_executed()
|
||||
return
|
||||
|
||||
service_call = ServiceCall(domain, service, service_data, call_id)
|
||||
service_call = ServiceCall(
|
||||
domain, service, service_data, event.context)
|
||||
|
||||
try:
|
||||
if service_handler.is_callback:
|
||||
|
||||
Reference in New Issue
Block a user