Skip to content

Commit

Permalink
Add type annotations for redis provider (#9815)
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 0a2acf0b6542b717f87dee6bbff43397bbb0e83b
  • Loading branch information
scrambldchannel authored and Cloud Composer Team committed Sep 12, 2024
1 parent acffc26 commit 53026f2
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/redis/hooks/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class RedisHook(BaseHook):
Also you can set ssl parameters as:
``{"ssl": true, "ssl_cert_reqs": "require", "ssl_cert_file": "/path/to/cert.pem", etc}``.
"""
def __init__(self, redis_conn_id='redis_default'):
def __init__(self, redis_conn_id: str = 'redis_default') -> None:
"""
Prepares hook to connect to a Redis database.
Expand Down
12 changes: 7 additions & 5 deletions airflow/providers/redis/operators/redis_publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Dict

from airflow.models import BaseOperator
from airflow.providers.redis.hooks.redis import RedisHook
from airflow.utils.decorators import apply_defaults
Expand All @@ -38,17 +40,17 @@ class RedisPublishOperator(BaseOperator):
@apply_defaults
def __init__(
self,
channel,
message,
redis_conn_id='redis_default',
*args, **kwargs):
channel: str,
message: str,
redis_conn_id: str = 'redis_default',
*args, **kwargs) -> None:

super().__init__(*args, **kwargs)
self.redis_conn_id = redis_conn_id
self.channel = channel
self.message = message

def execute(self, context):
def execute(self, context: Dict) -> None:
"""
Publish the message to Redis channel
Expand Down
6 changes: 4 additions & 2 deletions airflow/providers/redis/sensors/redis_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict

from airflow.providers.redis.hooks.redis import RedisHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
Expand All @@ -28,11 +30,11 @@ class RedisKeySensor(BaseSensorOperator):
ui_color = '#f0eee4'

@apply_defaults
def __init__(self, key, redis_conn_id, *args, **kwargs):
def __init__(self, key: str, redis_conn_id: str, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.redis_conn_id = redis_conn_id
self.key = key

def poke(self, context):
def poke(self, context: Dict) -> bool:
self.log.info('Sensor checks for existence of key: %s', self.key)
return RedisHook(self.redis_conn_id).get_conn().exists(self.key)
6 changes: 4 additions & 2 deletions airflow/providers/redis/sensors/redis_pub_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.

from typing import Dict, List, Union

from airflow.providers.redis.hooks.redis import RedisHook
from airflow.sensors.base_sensor_operator import BaseSensorOperator
from airflow.utils.decorators import apply_defaults
Expand All @@ -34,14 +36,14 @@ class RedisPubSubSensor(BaseSensorOperator):
ui_color = '#f0eee4'

@apply_defaults
def __init__(self, channels, redis_conn_id, *args, **kwargs):
def __init__(self, channels: Union[List[str], str], redis_conn_id: str, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.channels = channels
self.redis_conn_id = redis_conn_id
self.pubsub = RedisHook(redis_conn_id=self.redis_conn_id).get_conn().pubsub()
self.pubsub.subscribe(self.channels)

def poke(self, context):
def poke(self, context: Dict) -> bool:
"""
Check for message on subscribed channels and write to xcom the message with key ``message``
Expand Down

0 comments on commit 53026f2

Please sign in to comment.