mirror of
https://github.com/home-assistant/core.git
synced 2025-12-24 21:06:19 +00:00
Spread async love (#3575)
* Convert Entity.update_ha_state to be async * Make Service.call async * Update entity.py * Add Entity.async_update * Make automation zone trigger async * Fix linting * Reduce flakiness in hass.block_till_done * Make automation.numeric_state async * Make mqtt.subscribe async * Make automation.mqtt async * Make automation.time async * Make automation.sun async * Add async_track_point_in_utc_time * Make helpers.track_sunrise/set async * Add async_track_state_change * Make automation.state async * Clean up helpers/entity.py tests * Lint * Lint * Core.is_state and Core.is_state_attr are async friendly * Lint * Lint
This commit is contained in:
@@ -248,12 +248,16 @@ class HomeAssistant(object):
|
||||
|
||||
def notify_when_done():
|
||||
"""Notify event loop when pool done."""
|
||||
count = 0
|
||||
while True:
|
||||
# Wait for the work queue to empty
|
||||
self.pool.block_till_done()
|
||||
|
||||
# Verify the loop is empty
|
||||
if self._loop_empty():
|
||||
count += 1
|
||||
|
||||
if count == 2:
|
||||
break
|
||||
|
||||
# sleep in the loop executor, this forces execution back into
|
||||
@@ -675,40 +679,29 @@ class StateMachine(object):
|
||||
return list(self._states.values())
|
||||
|
||||
def get(self, entity_id):
|
||||
"""Retrieve state of entity_id or None if not found."""
|
||||
"""Retrieve state of entity_id or None if not found.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
return self._states.get(entity_id.lower())
|
||||
|
||||
def is_state(self, entity_id, state):
|
||||
"""Test if entity exists and is specified state."""
|
||||
return run_callback_threadsafe(
|
||||
self._loop, self.async_is_state, entity_id, state
|
||||
).result()
|
||||
|
||||
def async_is_state(self, entity_id, state):
|
||||
"""Test if entity exists and is specified state.
|
||||
|
||||
This method must be run in the event loop.
|
||||
Async friendly.
|
||||
"""
|
||||
entity_id = entity_id.lower()
|
||||
state_obj = self.get(entity_id)
|
||||
|
||||
return (entity_id in self._states and
|
||||
self._states[entity_id].state == state)
|
||||
return state_obj and state_obj.state == state
|
||||
|
||||
def is_state_attr(self, entity_id, name, value):
|
||||
"""Test if entity exists and has a state attribute set to value."""
|
||||
return run_callback_threadsafe(
|
||||
self._loop, self.async_is_state_attr, entity_id, name, value
|
||||
).result()
|
||||
|
||||
def async_is_state_attr(self, entity_id, name, value):
|
||||
"""Test if entity exists and has a state attribute set to value.
|
||||
|
||||
This method must be run in the event loop.
|
||||
Async friendly.
|
||||
"""
|
||||
entity_id = entity_id.lower()
|
||||
state_obj = self.get(entity_id)
|
||||
|
||||
return (entity_id in self._states and
|
||||
self._states[entity_id].attributes.get(name, None) == value)
|
||||
return state_obj and state_obj.attributes.get(name, None) == value
|
||||
|
||||
def remove(self, entity_id):
|
||||
"""Remove the state of an entity.
|
||||
@@ -799,7 +792,8 @@ class StateMachine(object):
|
||||
class Service(object):
|
||||
"""Represents a callable service."""
|
||||
|
||||
__slots__ = ['func', 'description', 'fields', 'schema']
|
||||
__slots__ = ['func', 'description', 'fields', 'schema',
|
||||
'iscoroutinefunction']
|
||||
|
||||
def __init__(self, func, description, fields, schema):
|
||||
"""Initialize a service."""
|
||||
@@ -807,6 +801,7 @@ class Service(object):
|
||||
self.description = description or ''
|
||||
self.fields = fields or {}
|
||||
self.schema = schema
|
||||
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
|
||||
|
||||
def as_dict(self):
|
||||
"""Return dictionary representation of this service."""
|
||||
@@ -815,19 +810,6 @@ class Service(object):
|
||||
'fields': self.fields,
|
||||
}
|
||||
|
||||
def __call__(self, call):
|
||||
"""Execute the service."""
|
||||
try:
|
||||
if self.schema:
|
||||
call.data = self.schema(call.data)
|
||||
call.data = MappingProxyType(call.data)
|
||||
|
||||
self.func(call)
|
||||
except vol.MultipleInvalid as ex:
|
||||
_LOGGER.error('Invalid service data for %s.%s: %s',
|
||||
call.domain, call.service,
|
||||
humanize_error(call.data, ex))
|
||||
|
||||
|
||||
# pylint: disable=too-few-public-methods
|
||||
class ServiceCall(object):
|
||||
@@ -839,7 +821,7 @@ class ServiceCall(object):
|
||||
"""Initialize a service call."""
|
||||
self.domain = domain.lower()
|
||||
self.service = service.lower()
|
||||
self.data = data or {}
|
||||
self.data = MappingProxyType(data or {})
|
||||
self.call_id = call_id
|
||||
|
||||
def __repr__(self):
|
||||
@@ -983,9 +965,9 @@ class ServiceRegistry(object):
|
||||
fut = asyncio.Future(loop=self._loop)
|
||||
|
||||
@asyncio.coroutine
|
||||
def service_executed(call):
|
||||
def service_executed(event):
|
||||
"""Callback method that is called when service is executed."""
|
||||
if call.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
|
||||
fut.set_result(True)
|
||||
|
||||
unsub = self._bus.async_listen(EVENT_SERVICE_EXECUTED,
|
||||
@@ -1000,9 +982,10 @@ class ServiceRegistry(object):
|
||||
unsub()
|
||||
return success
|
||||
|
||||
@asyncio.coroutine
|
||||
def _event_to_service_call(self, event):
|
||||
"""Callback for SERVICE_CALLED events from the event bus."""
|
||||
service_data = event.data.get(ATTR_SERVICE_DATA)
|
||||
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
|
||||
domain = event.data.get(ATTR_DOMAIN).lower()
|
||||
service = event.data.get(ATTR_SERVICE).lower()
|
||||
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
|
||||
@@ -1014,19 +997,41 @@ class ServiceRegistry(object):
|
||||
return
|
||||
|
||||
service_handler = self._services[domain][service]
|
||||
|
||||
def fire_service_executed():
|
||||
"""Fire service executed event."""
|
||||
if not call_id:
|
||||
return
|
||||
|
||||
data = {ATTR_SERVICE_CALL_ID: call_id}
|
||||
|
||||
if service_handler.iscoroutinefunction:
|
||||
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
|
||||
else:
|
||||
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
|
||||
|
||||
try:
|
||||
if service_handler.schema:
|
||||
service_data = service_handler.schema(service_data)
|
||||
except vol.Invalid as ex:
|
||||
_LOGGER.error('Invalid service data for %s.%s: %s',
|
||||
domain, service, humanize_error(service_data, ex))
|
||||
fire_service_executed()
|
||||
return
|
||||
|
||||
service_call = ServiceCall(domain, service, service_data, call_id)
|
||||
|
||||
# Add a job to the pool that calls _execute_service
|
||||
self._add_job(self._execute_service, service_handler, service_call,
|
||||
priority=JobPriority.EVENT_SERVICE)
|
||||
if not service_handler.iscoroutinefunction:
|
||||
def execute_service():
|
||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||
service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
def _execute_service(self, service, call):
|
||||
"""Execute a service and fires a SERVICE_EXECUTED event."""
|
||||
service(call)
|
||||
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
|
||||
return
|
||||
|
||||
if call.call_id is not None:
|
||||
self._bus.fire(
|
||||
EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id})
|
||||
yield from service_handler.func(service_call)
|
||||
fire_service_executed()
|
||||
|
||||
def _generate_unique_id(self):
|
||||
"""Generate a unique service call id."""
|
||||
|
||||
Reference in New Issue
Block a user