Make airflow/providers pylint compatible (#7802)
zhongjiajie authored Mar 23, 2020
1 parent a001489 commit 4bde99f
Showing 60 changed files with 427 additions and 304 deletions.
19 changes: 19 additions & 0 deletions UPDATING.md
@@ -61,6 +61,25 @@ https://developers.google.com/style/inclusive-documentation
-->

### Rename parameter name in PinotAdminHook.create_segment

Rename parameter ``format`` to ``segment_format`` in the PinotAdminHook function ``create_segment`` for pylint compatibility.
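
For example (an illustrative sketch: the connection id and the extra keyword arguments below are hypothetical, only the renamed ``segment_format`` parameter comes from this change):

```python
from airflow.providers.apache.pinot.hooks.pinot import PinotAdminHook

hook = PinotAdminHook(conn_id="pinot_admin_default")

# Before: the segment format was passed as ``format``, which shadows the
# Python built-in and triggers pylint's redefined-builtin warning.
# hook.create_segment(table_name="my_table", segment_name="my_table_0", format="csv")

# After: the same value is passed as ``segment_format``.
# The other keyword arguments here are illustrative only.
hook.create_segment(table_name="my_table", segment_name="my_table_0", segment_format="csv")
```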

### Rename parameter name in HiveMetastoreHook.get_partitions

Rename parameter ``filter`` to ``partition_filter`` in the HiveMetastoreHook function ``get_partitions`` for pylint compatibility.
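
A call site would change roughly as follows (a minimal sketch: the metastore connection id, schema, table, and filter expression are illustrative, while the ``schema``, ``table_name``, and ``partition_filter`` argument names come from the new signature):

```python
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook

hook = HiveMetastoreHook(metastore_conn_id="metastore_default")

# Before: the keyword shadowed the ``filter`` built-in.
# partitions = hook.get_partitions("airflow", "static_babynames_partitioned", filter="ds='2020-03-23'")

# After: the same expression is passed as ``partition_filter``.
partitions = hook.get_partitions(
    schema="airflow",
    table_name="static_babynames_partitioned",
    partition_filter="ds='2020-03-23'",
)
```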

### Remove unnecessary parameter in FTPHook.list_directory

Remove the unnecessary parameter ``nlst`` from the FTPHook function ``list_directory`` for pylint compatibility.
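
For example (a sketch; the connection id and remote path are hypothetical):

```python
from airflow.providers.ftp.hooks.ftp import FTPHook

hook = FTPHook(ftp_conn_id="ftp_default")

# Before: callers could pass an ``nlst`` flag that the hook never used.
# files = hook.list_directory("/data/incoming", nlst=True)

# After: only the remote directory path is passed.
files = hook.list_directory("/data/incoming")
print(files)
```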

### Remove unnecessary parameter in PostgresHook function copy_expert

Remove the unnecessary parameter ``open`` from the PostgresHook function ``copy_expert`` for pylint compatibility.
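
For example (a sketch; the connection id, SQL statement, and file path are hypothetical):

```python
from airflow.providers.postgres.hooks.postgres import PostgresHook

hook = PostgresHook(postgres_conn_id="postgres_default")

# Before: the call accepted an unused ``open`` argument that shadowed the built-in.
# hook.copy_expert("COPY my_table TO STDOUT WITH CSV HEADER", "/tmp/my_table.csv", open=open)

# After: only the SQL statement and the target filename are needed.
hook.copy_expert("COPY my_table TO STDOUT WITH CSV HEADER", "/tmp/my_table.csv")
```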

### Change parameter name in OpsgenieAlertOperator

Change the parameter name ``visibleTo`` to ``visible_to`` in OpsgenieAlertOperator for pylint compatibility.
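
For example (a sketch; the DAG settings, task id, message, and team entry are hypothetical, the import path reflects the provider layout at the time of this commit, and only the ``visible_to`` rename comes from this change):

```python
from datetime import datetime

from airflow import DAG
from airflow.providers.opsgenie.operators.opsgenie_alert import OpsgenieAlertOperator

with DAG(dag_id="opsgenie_alert_example", start_date=datetime(2020, 3, 23), schedule_interval=None) as dag:
    notify = OpsgenieAlertOperator(
        task_id="notify_opsgenie",
        message="Example pipeline failed",
        # Previously this argument was named ``visibleTo``.
        visible_to=[{"name": "data-platform", "type": "team"}],
    )
```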

### Use NULL as default value for dag.description

15 changes: 9 additions & 6 deletions airflow/providers/apache/druid/hooks/druid.py
@@ -59,6 +59,9 @@ def get_conn_url(self):
raise ValueError("Druid timeout should be equal or greater than 1")

def get_conn_url(self):
"""
Get Druid connection url
"""
conn = self.get_connection(self.druid_ingest_conn_id)
host = conn.host
port = conn.port
@@ -82,6 +85,9 @@ def get_auth(self):
return None

def submit_indexing_job(self, json_index_spec: str):
"""
Submit Druid ingestion job
"""
url = self.get_conn_url()

self.log.info("Druid ingestion spec: %s", json_index_spec)
@@ -107,7 +113,7 @@ def submit_indexing_job(self, json_index_spec: str):
# ensure that the job gets killed if the max ingestion time is exceeded
requests.post("{0}/{1}/shutdown".format(url, druid_task_id), auth=self.get_auth())
raise AirflowException('Druid ingestion took more than '
'%s seconds', self.max_ingestion_time)
f'{self.max_ingestion_time} seconds')

time.sleep(self.timeout)

@@ -122,7 +128,7 @@ def submit_indexing_job(self, json_index_spec: str):
raise AirflowException('Druid indexing job failed, '
'check console for more info')
else:
raise AirflowException('Could not get status of the job, got %s', status)
raise AirflowException(f'Could not get status of the job, got {status}')

self.log.info('Successful index')

@@ -138,14 +144,11 @@ class DruidDbApiHook(DbApiHook):
default_conn_name = 'druid_broker_default'
supports_autocommit = False

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_conn(self):
"""
Establish a connection to druid broker.
"""
conn = self.get_connection(self.druid_broker_conn_id)
conn = self.get_connection(self.druid_broker_conn_id) # pylint: disable=no-member
druid_broker_conn = connect(
host=conn.host,
port=conn.port,
65 changes: 33 additions & 32 deletions airflow/providers/apache/hive/hooks/hive.py
@@ -88,6 +88,7 @@ def __init__(
self.auth = conn.extra_dejson.get('auth', 'noSasl')
self.conn = conn
self.run_as = run_as
self.sub_process = None

if mapred_queue_priority:
mapred_queue_priority = mapred_queue_priority.upper()
@@ -241,24 +242,24 @@ def run_cli(self, hql, schema=None, verbose=True, hive_conf=None):

if verbose:
self.log.info("%s", " ".join(hive_cmd))
sp = subprocess.Popen(
sub_process = subprocess.Popen(
hive_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=tmp_dir,
close_fds=True)
self.sp = sp
self.sub_process = sub_process
stdout = ''
while True:
line = sp.stdout.readline()
line = sub_process.stdout.readline()
if not line:
break
stdout += line.decode('UTF-8')
if verbose:
self.log.info(line.decode('UTF-8').strip())
sp.wait()
sub_process.wait()

if sp.returncode:
if sub_process.returncode:
raise AirflowException(stdout)

return stdout
@@ -338,7 +339,7 @@ def load_df(
"""

def _infer_field_types_from_df(df):
DTYPE_KIND_HIVE_TYPE = {
dtype_kind_hive_type = {
'b': 'BOOLEAN', # boolean
'i': 'BIGINT', # signed integer
'u': 'BIGINT', # unsigned integer
@@ -351,10 +352,10 @@ def _infer_field_types_from_df(df):
'V': 'STRING' # void
}

d = OrderedDict()
order_type = OrderedDict()
for col, dtype in df.dtypes.iteritems():
d[col] = DTYPE_KIND_HIVE_TYPE[dtype.kind]
return d
order_type[col] = dtype_kind_hive_type[dtype.kind]
return order_type

if pandas_kwargs is None:
pandas_kwargs = {}
@@ -466,12 +467,15 @@ def load_file(
self.run_cli(hql)

def kill(self):
"""
Kill Hive cli command
"""
if hasattr(self, 'sp'):
if self.sp.poll() is None:
if self.sub_process.poll() is None:
print("Killing the Hive job")
self.sp.terminate()
self.sub_process.terminate()
time.sleep(60)
self.sp.kill()
self.sub_process.kill()


class HiveMetastoreHook(BaseHook):
@@ -488,9 +492,9 @@ def __init__(self, metastore_conn_id='metastore_default'):
def __getstate__(self):
# This is for pickling to work despite the thirft hive client not
# being pickable
d = dict(self.__dict__)
del d['metastore']
return d
state = dict(self.__dict__)
del state['metastore']
return state

def __setstate__(self, d):
self.__dict__.update(d)
@@ -504,18 +508,18 @@ def get_metastore_client(self):
from thrift.transport import TSocket, TTransport
from thrift.protocol import TBinaryProtocol

ms = self._find_valid_server()
conn = self._find_valid_server()

if ms is None:
if not conn:
raise AirflowException("Failed to locate the valid server.")

auth_mechanism = ms.extra_dejson.get('authMechanism', 'NOSASL')
auth_mechanism = conn.extra_dejson.get('authMechanism', 'NOSASL')

if conf.get('core', 'security') == 'kerberos':
auth_mechanism = ms.extra_dejson.get('authMechanism', 'GSSAPI')
kerberos_service_name = ms.extra_dejson.get('kerberos_service_name', 'hive')
auth_mechanism = conn.extra_dejson.get('authMechanism', 'GSSAPI')
kerberos_service_name = conn.extra_dejson.get('kerberos_service_name', 'hive')

conn_socket = TSocket.TSocket(ms.host, ms.port)
conn_socket = TSocket.TSocket(conn.host, conn.port)

if conf.get('core', 'security') == 'kerberos' \
and auth_mechanism == 'GSSAPI':
@@ -526,7 +530,7 @@ def get_metastore_client(self):

def sasl_factory():
sasl_client = sasl.Client()
sasl_client.setAttr("host", ms.host)
sasl_client.setAttr("host", conn.host)
sasl_client.setAttr("service", kerberos_service_name)
sasl_client.init()
return sasl_client
@@ -551,6 +555,7 @@ def _find_valid_server(self):
return conn
else:
self.log.info("Could not connect to %s:%s", conn.host, conn.port)
return None

def get_conn(self):
return self.metastore
@@ -577,10 +582,7 @@ def check_for_partition(self, schema, table, partition):
partitions = client.get_partitions_by_filter(
schema, table, partition, 1)

if partitions:
return True
else:
return False
return bool(partitions)

def check_for_named_partition(self, schema, table, partition_name):
"""
@@ -634,8 +636,7 @@ def get_databases(self, pattern='*'):
with self.metastore as client:
return client.get_databases(pattern)

def get_partitions(
self, schema, table_name, filter=None):
def get_partitions(self, schema, table_name, partition_filter=None):
"""
Returns a list of all partitions in a table. Works only
for tables with less than 32767 (java short max val).
@@ -654,10 +655,10 @@ def get_partitions(
if len(table.partitionKeys) == 0:
raise AirflowException("The table isn't partitioned")
else:
if filter:
if partition_filter:
parts = client.get_partitions_by_filter(
db_name=schema, tbl_name=table_name,
filter=filter, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
filter=partition_filter, max_parts=HiveMetastoreHook.MAX_PART_COUNT)
else:
parts = client.get_partitions(
db_name=schema, tbl_name=table_name,
@@ -770,7 +771,7 @@ def table_exists(self, table_name, db='default'):
try:
self.get_table(table_name, db)
return True
except Exception:
except Exception: # pylint: disable=broad-except
return False


@@ -849,7 +850,7 @@ def _get_results(self, hql, schema='default', fetch_size=None, hive_conf=None):
lowered_statement.startswith('show') or
(lowered_statement.startswith('set') and
'=' not in lowered_statement)):
description = [c for c in cur.description]
description = cur.description
if previous_description and previous_description != description:
message = '''The statements are producing different descriptions:
Current: {}
4 changes: 4 additions & 0 deletions airflow/providers/apache/hive/operators/hive.py
@@ -66,6 +66,7 @@ class HiveOperator(BaseOperator):
template_ext = ('.hql', '.sql',)
ui_color = '#f0e4ec'

# pylint: disable=too-many-arguments
@apply_defaults
def __init__(
self,
@@ -104,6 +105,9 @@ def __init__(
self.hook = None

def get_hook(self):
"""
Get Hive cli hook
"""
return HiveCliHook(
hive_cli_conn_id=self.hive_cli_conn_id,
run_as=self.run_as,
33 changes: 18 additions & 15 deletions airflow/providers/apache/hive/operators/hive_stats.py
@@ -86,22 +86,25 @@ def __init__(self,
self.dttm = '{{ execution_date.isoformat() }}'

def get_default_exprs(self, col, col_type):
"""
Get default expressions
"""
if col in self.col_blacklist:
return {}
d = {(col, 'non_null'): "COUNT({col})"}
exp = {(col, 'non_null'): f"COUNT({col})"}
if col_type in ['double', 'int', 'bigint', 'float']:
d[(col, 'sum')] = 'SUM({col})'
d[(col, 'min')] = 'MIN({col})'
d[(col, 'max')] = 'MAX({col})'
d[(col, 'avg')] = 'AVG({col})'
exp[(col, 'sum')] = f'SUM({col})'
exp[(col, 'min')] = f'MIN({col})'
exp[(col, 'max')] = f'MAX({col})'
exp[(col, 'avg')] = f'AVG({col})'
elif col_type == 'boolean':
d[(col, 'true')] = 'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)'
d[(col, 'false')] = 'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)'
exp[(col, 'true')] = f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)'
exp[(col, 'false')] = f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)'
elif col_type in ['string']:
d[(col, 'len')] = 'SUM(CAST(LENGTH({col}) AS BIGINT))'
d[(col, 'approx_distinct')] = 'APPROX_DISTINCT({col})'
exp[(col, 'len')] = f'SUM(CAST(LENGTH({col}) AS BIGINT))'
exp[(col, 'approx_distinct')] = f'APPROX_DISTINCT({col})'

return {k: v.format(col=col) for k, v in d.items()}
return exp

def execute(self, context=None):
metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
@@ -113,12 +116,12 @@ def execute(self, context=None):
}
for col, col_type in list(field_types.items()):
if self.assignment_func:
d = self.assignment_func(col, col_type)
if d is None:
d = self.get_default_exprs(col, col_type)
assign_exprs = self.assignment_func(col, col_type)
if assign_exprs is None:
assign_exprs = self.get_default_exprs(col, col_type)
else:
d = self.get_default_exprs(col, col_type)
exprs.update(d)
assign_exprs = self.get_default_exprs(col, col_type)
exprs.update(assign_exprs)
exprs.update(self.extra_exprs)
exprs = OrderedDict(exprs)
exprs_str = ",\n ".join([
20 changes: 12 additions & 8 deletions airflow/providers/apache/pig/hooks/pig.py
@@ -40,6 +40,7 @@ def __init__(
conn = self.get_connection(pig_cli_conn_id)
self.pig_properties = conn.extra_dejson.get('pig_properties', '')
self.conn = conn
self.sub_process = None

def run_cli(self, pig, pig_opts=None, verbose=True):
"""
@@ -72,27 +73,30 @@ def run_cli(self, pig, pig_opts=None, verbose=True):

if verbose:
self.log.info("%s", " ".join(pig_cmd))
sp = subprocess.Popen(
sub_process = subprocess.Popen(
pig_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=tmp_dir,
close_fds=True)
self.sp = sp
self.sub_process = sub_process
stdout = ''
for line in iter(sp.stdout.readline, b''):
for line in iter(sub_process.stdout.readline, b''):
stdout += line.decode('utf-8')
if verbose:
self.log.info(line.strip())
sp.wait()
sub_process.wait()

if sp.returncode:
if sub_process.returncode:
raise AirflowException(stdout)

return stdout

def kill(self):
if hasattr(self, 'sp'):
if self.sp.poll() is None:
"""
Kill Pig job
"""
if self.sub_process:
if self.sub_process.poll() is None:
print("Killing the Pig job")
self.sp.kill()
self.sub_process.kill()