Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 31 additions & 18 deletions airflow/providers/arenadata/hbase/client/thrift2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
from __future__ import annotations

import logging
import ssl
import time
from typing import Any

from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport
from thrift.transport import TSocket, TTransport, TSSLSocket

from airflow.providers.arenadata.hbase.hbase_thrift2_generated import THBaseService, ttypes

Expand All @@ -36,7 +35,7 @@ class HBaseThrift2Client:
"""Lightweight HBase Thrift2 client."""

def __init__(self, host: str, port: int = 9090, timeout: int = 30000,
ssl_context: ssl.SSLContext | None = None,
ssl_options: dict[str, Any] | None = None,
retry_max_attempts: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0):
Expand All @@ -46,19 +45,21 @@ def __init__(self, host: str, port: int = 9090, timeout: int = 30000,
host: HBase Thrift2 server host
port: HBase Thrift2 server port (default 9090 for Arenadata/Apache HBase)
timeout: Connection timeout in milliseconds
ssl_context: SSL context for secure connections (optional)
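ssl_options: Optional dict of SSL settings. Recognized keys: ca_certs, cert_file, key_file, and
validate (truthy -> ssl.CERT_REQUIRED, falsy -> ssl.CERT_NONE). If None or empty, a plain non-SSL socket is used.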
retry_max_attempts: Maximum number of connection attempts
retry_delay: Initial delay between retry attempts in seconds
retry_backoff_factor: Multiplier for delay after each failed attempt
"""
self.host = host
self.port = port
self.timeout = timeout
self.ssl_context = ssl_context
self.ssl_options = ssl_options
self.retry_max_attempts = retry_max_attempts
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self._client = None

logger.debug("HBaseThrift2Client initialized with ssl_options: %s", ssl_options)

def __enter__(self):
self.open()
Expand All @@ -73,20 +74,31 @@ def open(self):

for attempt in range(self.retry_max_attempts):
try:
# Create socket
socket = TSocket.TSocket(self.host, self.port)
socket.setTimeout(self.timeout)

# Create transport (buffered or SSL)
if self.ssl_context:
# Wrap socket with SSL
# Create socket (SSL or regular)
if self.ssl_options:
import ssl as ssl_module
ssl_socket = self.ssl_context.wrap_socket(
socket.handle,
server_hostname=self.host
)
socket.handle = ssl_socket
# Map our options to TSSLSocket parameters
ssl_params = {
'host': self.host,
'port': self.port,
}
if 'ca_certs' in self.ssl_options:
ssl_params['ca_certs'] = self.ssl_options['ca_certs']
if 'cert_file' in self.ssl_options:
ssl_params['certfile'] = self.ssl_options['cert_file']
if 'key_file' in self.ssl_options:
ssl_params['keyfile'] = self.ssl_options['key_file']
if 'validate' in self.ssl_options:
# Map validate to cert_reqs
ssl_params['cert_reqs'] = ssl_module.CERT_REQUIRED if self.ssl_options['validate'] else ssl_module.CERT_NONE

socket = TSSLSocket.TSSLSocket(**ssl_params)
else:
socket = TSocket.TSocket(self.host, self.port)

socket.setTimeout(self.timeout)

# Create transport
self._transport = TTransport.TBufferedTransport(socket)

# Create protocol
Expand All @@ -98,7 +110,8 @@ def open(self):
# Open transport
self._transport.open()

logger.info("Successfully connected to HBase Thrift2 at %s:%s", self.host, self.port)
logger.info("Successfully connected to HBase Thrift2 at %s:%s (SSL: %s)",
self.host, self.port, bool(self.ssl_options) if self.ssl_options else False)
return

except (ConnectionError, TimeoutError, OSError, Exception) as e:
Expand Down
216 changes: 140 additions & 76 deletions airflow/providers/arenadata/hbase/example_dags/example_hbase_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,27 @@
# specific language governing permissions and limitations
# under the License.
"""
Example DAG showing HBase Thrift2 provider usage with SSL/TLS connection.

This DAG demonstrates secure connections to HBase using Thrift2 protocol with SSL.

Connection Configuration (hbase_thrift2_ssl):
{
"connection_mode": "thrift2",
"host": "localhost",
"port": 9090,
"use_ssl": true,
"ssl_verify_mode": "CERT_REQUIRED",
"ssl_ca_secret": "hbase/ca-cert",
"ssl_cert_secret": "hbase/client-cert",
"ssl_key_secret": "hbase/client-key",
"ssl_min_version": "TLSv1_2"
}
Example DAG showing HBase Thrift2 access over SSL/TLS using HBaseThriftHook.

Before running this DAG, create an Airflow Connection:

Prerequisites:
1. HBase Thrift2 server with SSL enabled
2. SSL certificates stored in Airflow Variables:
- hbase/ca-cert: CA certificate
- hbase/client-cert: Client certificate
- hbase/client-key: Client private key
3. Create Airflow Connection 'hbase_thrift2_ssl' with above config
Connection ID: hbase_thrift2_ssl_mtls
Connection Type: Generic
Host: your-hbase-host (e.g., hbase-thrift.example.com)
Port: 9090 (Thrift2 default port)
Extra: {
"ca_certs": "/etc/ssl/hbase/ca-bundle.crt",
"cert_file": "/etc/ssl/hbase/client.crt",
"key_file": "/etc/ssl/hbase/client.key",
"validate": true
}
"""

from datetime import datetime, timedelta

from airflow import DAG
from airflow.providers.arenadata.hbase.operators.hbase import (
HBaseCreateTableOperator,
HBaseDeleteTableOperator,
HBasePutOperator,
)
from airflow.providers.arenadata.hbase.sensors.hbase import HBaseTableSensor, HBaseRowSensor
from airflow.operators.python import PythonOperator
from airflow.providers.arenadata.hbase.hooks.hbase import HBaseThriftHook

default_args = {
"owner": "airflow",
Expand All @@ -65,74 +51,152 @@
dag = DAG(
"example_hbase_ssl",
default_args=default_args,
description="Example HBase Thrift2 DAG with SSL/TLS connection",
description="Example HBase Thrift2 DAG with SSL",
schedule_interval=None,
catchup=False,
tags=["example", "hbase", "thrift2", "ssl"],
tags=["example", "hbase", "ssl"],
)

# Connection ID for Thrift2 with SSL
HBASE_CONN_ID = "hbase_thrift2_ssl"
TABLE_NAME = "test_table_ssl"

# Delete table if exists for idempotency
delete_table_cleanup = HBaseDeleteTableOperator(
task_id="delete_table_cleanup",
table_name=TABLE_NAME,
hbase_conn_id=HBASE_CONN_ID,
def create_table_task():
    """Recreate the demo table through the Thrift hook.

    Deletes ``test_table_ssl`` first when it already exists, then creates it
    with two column families, ``cf1`` and ``cf2``, each with an empty
    options dict.
    """
    table = "test_table_ssl"
    client = HBaseThriftHook(hbase_conn_id="hbase_thrift2_ssl_mtls")

    # Remove a pre-existing table before creating it again.
    already_there = client.table_exists(table)
    if already_there:
        client.delete_table(table)
        print("Deleted existing table")

    # Both column families are created with no extra options.
    column_families = {family: {} for family in ("cf1", "cf2")}
    client.create_table(table, families=column_families)
    print("Created table: test_table_ssl")


def put_data_task():
    """Write sample rows ``row1`` through ``row5`` via the Thrift hook.

    ``row1`` receives three cells; rows 2-5 each receive one cell in ``cf1``
    and one in ``cf2``.
    """
    table = "test_table_ssl"
    client = HBaseThriftHook(hbase_conn_id="hbase_thrift2_ssl_mtls")

    # First row: three cells spread over both column families.
    first_row_cells = {
        "cf1:col1": "value1",
        "cf1:col2": "value2",
        "cf2:col1": "value3",
    }
    client.put_row(table, "row1", first_row_cells)
    print("Put data for row1")

    # Remaining rows: one cell per column family, values derived from the index.
    for index in range(2, 6):
        cells = {
            "cf1:col1": f"value{index}_1",
            "cf2:col1": f"value{index}_2",
        }
        client.put_row(table, f"row{index}", cells)
    print("Put data for rows 2-5")


def get_data_task():
    """Read ``row1`` twice via the Thrift hook and print both results.

    The first read requests the row without a column filter; the second
    passes an explicit ``columns`` list.
    """
    table = "test_table_ssl"
    row_key = "row1"
    client = HBaseThriftHook(hbase_conn_id="hbase_thrift2_ssl_mtls")

    # Unfiltered read of the row.
    full_row = client.get_row(table, row_key)
    print(f"Got row1: {full_row}")

    # Same row, restricted to two named columns.
    wanted_columns = ["cf1:col1", "cf2:col1"]
    partial_row = client.get_row(table, row_key, columns=wanted_columns)
    print(f"Got row1 (specific columns): {partial_row}")


def scan_table_task():
    """Scan the demo table via the Thrift hook, first fully and then with a limit.

    Prints the number of rows returned by each scan and, for the full scan,
    one line per ``(row_key, data)`` pair with the size of ``data``.
    """
    table = "test_table_ssl"
    client = HBaseThriftHook(hbase_conn_id="hbase_thrift2_ssl_mtls")

    # Full scan: report the total, then one summary line per row.
    all_rows = client.scan_table(table)
    print(f"Scanned {len(all_rows)} rows")
    for key, cells in all_rows:
        print(f" Row: {key}, Columns: {len(cells)}")

    # Second scan, capped with limit=3.
    limited_rows = client.scan_table(table, limit=3)
    print(f"Scanned with limit=3: {len(limited_rows)} rows")


def delete_row_task():
    """Issue two deletes via the Thrift hook: one column-scoped, one whole-row.

    ``row2`` is deleted with ``columns=["cf1:col1"]``; ``row3`` is deleted
    without a column filter.
    """
    table = "test_table_ssl"
    client = HBaseThriftHook(hbase_conn_id="hbase_thrift2_ssl_mtls")

    # Column-scoped delete on row2.
    columns_to_drop = ["cf1:col1"]
    client.delete_row(table, "row2", columns=columns_to_drop)
    print("Deleted cf1:col1 from row2")

    # No column filter: targets row3 as a whole.
    client.delete_row(table, "row3")
    print("Deleted row3")


def cleanup_task():
    """Delete the demo table via the Thrift hook if it is still present.

    Does nothing (and prints nothing) when ``test_table_ssl`` does not exist.
    """
    table = "test_table_ssl"
    client = HBaseThriftHook(hbase_conn_id="hbase_thrift2_ssl_mtls")

    # Nothing to clean up when the table is already gone.
    if not client.table_exists(table):
        return

    client.delete_table(table)
    print("Deleted table: test_table_ssl")


# Define tasks
create_table = PythonOperator(
task_id="create_table",
python_callable=create_table_task,
dag=dag,
)

# Create table using SSL connection
create_table = HBaseCreateTableOperator(
task_id="create_table",
table_name=TABLE_NAME,
families={
"cf1": {}, # Column family 1
"cf2": {}, # Column family 2
},
hbase_conn_id=HBASE_CONN_ID,
put_data = PythonOperator(
task_id="put_data",
python_callable=put_data_task,
dag=dag,
)

check_table = HBaseTableSensor(
task_id="check_table_exists",
table_name=TABLE_NAME,
hbase_conn_id=HBASE_CONN_ID,
timeout=60,
poke_interval=10,
get_data = PythonOperator(
task_id="get_data",
python_callable=get_data_task,
dag=dag,
)

put_data = HBasePutOperator(
task_id="put_data",
table_name=TABLE_NAME,
row_key="ssl_row1",
data={
"cf1:col1": "ssl_value1",
"cf1:col2": "ssl_value2",
"cf2:col1": "ssl_value3",
},
hbase_conn_id=HBASE_CONN_ID,
scan_table = PythonOperator(
task_id="scan_table",
python_callable=scan_table_task,
dag=dag,
)

check_row = HBaseRowSensor(
task_id="check_row_exists",
table_name=TABLE_NAME,
row_key="ssl_row1",
hbase_conn_id=HBASE_CONN_ID,
timeout=60,
poke_interval=10,
delete_row = PythonOperator(
task_id="delete_row",
python_callable=delete_row_task,
dag=dag,
)

delete_table = HBaseDeleteTableOperator(
task_id="delete_table",
table_name=TABLE_NAME,
hbase_conn_id=HBASE_CONN_ID,
cleanup = PythonOperator(
task_id="cleanup",
python_callable=cleanup_task,
dag=dag,
)

# Set dependencies
delete_table_cleanup >> create_table >> check_table >> put_data >> check_row >> delete_table
create_table >> put_data >> get_data >> scan_table >> delete_row >> cleanup
Loading