1
0
mirror of https://github.com/home-assistant/core.git synced 2025-12-24 21:06:19 +00:00

Spread async love (#3575)

* Convert Entity.update_ha_state to be async

* Make Service.call async

* Update entity.py

* Add Entity.async_update

* Make automation zone trigger async

* Fix linting

* Reduce flakiness in hass.block_till_done

* Make automation.numeric_state async

* Make mqtt.subscribe async

* Make automation.mqtt async

* Make automation.time async

* Make automation.sun async

* Add async_track_point_in_utc_time

* Make helpers.track_sunrise/set async

* Add async_track_state_change

* Make automation.state async

* Clean up helpers/entity.py tests

* Lint

* Lint

* Core.is_state and Core.is_state_attr are async friendly

* Lint

* Lint
This commit is contained in:
Paulus Schoutsen
2016-09-30 12:57:24 -07:00
committed by GitHub
parent 7e50ccd32a
commit b650b2b0db
17 changed files with 323 additions and 151 deletions

View File

@@ -248,12 +248,16 @@ class HomeAssistant(object):
def notify_when_done():
"""Notify event loop when pool done."""
count = 0
while True:
# Wait for the work queue to empty
self.pool.block_till_done()
# Verify the loop is empty
if self._loop_empty():
count += 1
if count == 2:
break
# sleep in the loop executor, this forces execution back into
@@ -675,40 +679,29 @@ class StateMachine(object):
return list(self._states.values())
def get(self, entity_id):
"""Retrieve state of entity_id or None if not found."""
"""Retrieve state of entity_id or None if not found.
Async friendly.
"""
return self._states.get(entity_id.lower())
def is_state(self, entity_id, state):
"""Test if entity exists and is specified state."""
return run_callback_threadsafe(
self._loop, self.async_is_state, entity_id, state
).result()
def async_is_state(self, entity_id, state):
"""Test if entity exists and is specified state.
This method must be run in the event loop.
Async friendly.
"""
entity_id = entity_id.lower()
state_obj = self.get(entity_id)
return (entity_id in self._states and
self._states[entity_id].state == state)
return state_obj and state_obj.state == state
def is_state_attr(self, entity_id, name, value):
"""Test if entity exists and has a state attribute set to value."""
return run_callback_threadsafe(
self._loop, self.async_is_state_attr, entity_id, name, value
).result()
def async_is_state_attr(self, entity_id, name, value):
"""Test if entity exists and has a state attribute set to value.
This method must be run in the event loop.
Async friendly.
"""
entity_id = entity_id.lower()
state_obj = self.get(entity_id)
return (entity_id in self._states and
self._states[entity_id].attributes.get(name, None) == value)
return state_obj and state_obj.attributes.get(name, None) == value
def remove(self, entity_id):
"""Remove the state of an entity.
@@ -799,7 +792,8 @@ class StateMachine(object):
class Service(object):
"""Represents a callable service."""
__slots__ = ['func', 'description', 'fields', 'schema']
__slots__ = ['func', 'description', 'fields', 'schema',
'iscoroutinefunction']
def __init__(self, func, description, fields, schema):
"""Initialize a service."""
@@ -807,6 +801,7 @@ class Service(object):
self.description = description or ''
self.fields = fields or {}
self.schema = schema
self.iscoroutinefunction = asyncio.iscoroutinefunction(func)
def as_dict(self):
"""Return dictionary representation of this service."""
@@ -815,19 +810,6 @@ class Service(object):
'fields': self.fields,
}
def __call__(self, call):
"""Execute the service."""
try:
if self.schema:
call.data = self.schema(call.data)
call.data = MappingProxyType(call.data)
self.func(call)
except vol.MultipleInvalid as ex:
_LOGGER.error('Invalid service data for %s.%s: %s',
call.domain, call.service,
humanize_error(call.data, ex))
# pylint: disable=too-few-public-methods
class ServiceCall(object):
@@ -839,7 +821,7 @@ class ServiceCall(object):
"""Initialize a service call."""
self.domain = domain.lower()
self.service = service.lower()
self.data = data or {}
self.data = MappingProxyType(data or {})
self.call_id = call_id
def __repr__(self):
@@ -983,9 +965,9 @@ class ServiceRegistry(object):
fut = asyncio.Future(loop=self._loop)
@asyncio.coroutine
def service_executed(call):
def service_executed(event):
"""Callback method that is called when service is executed."""
if call.data[ATTR_SERVICE_CALL_ID] == call_id:
if event.data[ATTR_SERVICE_CALL_ID] == call_id:
fut.set_result(True)
unsub = self._bus.async_listen(EVENT_SERVICE_EXECUTED,
@@ -1000,9 +982,10 @@ class ServiceRegistry(object):
unsub()
return success
@asyncio.coroutine
def _event_to_service_call(self, event):
"""Callback for SERVICE_CALLED events from the event bus."""
service_data = event.data.get(ATTR_SERVICE_DATA)
service_data = event.data.get(ATTR_SERVICE_DATA) or {}
domain = event.data.get(ATTR_DOMAIN).lower()
service = event.data.get(ATTR_SERVICE).lower()
call_id = event.data.get(ATTR_SERVICE_CALL_ID)
@@ -1014,19 +997,41 @@ class ServiceRegistry(object):
return
service_handler = self._services[domain][service]
def fire_service_executed():
"""Fire service executed event."""
if not call_id:
return
data = {ATTR_SERVICE_CALL_ID: call_id}
if service_handler.iscoroutinefunction:
self._bus.async_fire(EVENT_SERVICE_EXECUTED, data)
else:
self._bus.fire(EVENT_SERVICE_EXECUTED, data)
try:
if service_handler.schema:
service_data = service_handler.schema(service_data)
except vol.Invalid as ex:
_LOGGER.error('Invalid service data for %s.%s: %s',
domain, service, humanize_error(service_data, ex))
fire_service_executed()
return
service_call = ServiceCall(domain, service, service_data, call_id)
# Add a job to the pool that calls _execute_service
self._add_job(self._execute_service, service_handler, service_call,
priority=JobPriority.EVENT_SERVICE)
if not service_handler.iscoroutinefunction:
def execute_service():
"""Execute a service and fires a SERVICE_EXECUTED event."""
service_handler.func(service_call)
fire_service_executed()
def _execute_service(self, service, call):
"""Execute a service and fires a SERVICE_EXECUTED event."""
service(call)
self._add_job(execute_service, priority=JobPriority.EVENT_SERVICE)
return
if call.call_id is not None:
self._bus.fire(
EVENT_SERVICE_EXECUTED, {ATTR_SERVICE_CALL_ID: call.call_id})
yield from service_handler.func(service_call)
fire_service_executed()
def _generate_unique_id(self):
"""Generate a unique service call id."""