Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 49 additions & 43 deletions airflow/providers/hbase/client/thrift2_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,12 @@
import logging
import ssl
import time
from pathlib import Path
from typing import Any

import thriftpy2
from thriftpy2.rpc import make_client
from thriftpy2.transport.base import TTransportException
from thrift.protocol import TBinaryProtocol
from thrift.transport import TSocket, TTransport


# Load Thrift2 definitions
THRIFT2_FILE = Path(__file__).parent.parent / "thrift_definitions" / "hbase_thrift2.thrift"
hbase_thrift2 = thriftpy2.load(str(THRIFT2_FILE), module_name="hbase_thrift2_thrift")
from airflow.providers.hbase.hbase_thrift2_generated import THBaseService, ttypes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -78,25 +73,35 @@ def open(self):

for attempt in range(self.retry_max_attempts):
try:
# Build connection parameters
kwargs = {
'timeout': self.timeout
}
# Create socket
socket = TSocket.TSocket(self.host, self.port)
socket.setTimeout(self.timeout)

# Add SSL if context provided
# Create transport (buffered or SSL)
if self.ssl_context:
kwargs['ssl_context'] = self.ssl_context
# Wrap socket with SSL
import ssl as ssl_module
ssl_socket = self.ssl_context.wrap_socket(
socket.handle,
server_hostname=self.host
)
socket.handle = ssl_socket

self._transport = TTransport.TBufferedTransport(socket)

# Create protocol
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)

# Create client
self._client = THBaseService.Client(protocol)

# Open transport
self._transport.open()

self._client = make_client(
hbase_thrift2.THBaseService,
host=self.host,
port=self.port,
**kwargs
)
logger.info("Successfully connected to HBase Thrift2 at %s:%s", self.host, self.port)
return

except (ConnectionError, TimeoutError, TTransportException, OSError) as e:
except (ConnectionError, TimeoutError, OSError, Exception) as e:
last_exception = e
if attempt == self.retry_max_attempts - 1: # Last attempt
logger.error("All %d connection attempts failed. Last error: %s", self.retry_max_attempts, e)
Expand All @@ -115,9 +120,10 @@ def open(self):

def close(self):
"""Close connection."""
if self._client:
self._client.close()
self._client = None
if hasattr(self, '_transport') and self._transport:
self._transport.close()
self._transport = None
self._client = None

def list_tables(self) -> list[str]:
"""List all tables."""
Expand All @@ -126,7 +132,7 @@ def list_tables(self) -> list[str]:

def table_exists(self, table_name: str) -> bool:
"""Check if table exists."""
table_name_obj = hbase_thrift2.TTableName(
table_name_obj = ttypes.TTableName(
ns=b"default",
qualifier=table_name.encode()
)
Expand All @@ -139,19 +145,19 @@ def create_table(self, table_name: str, families: dict[str, dict]) -> None:
table_name: Name of the table
families: Dictionary of column families
"""
table_name_obj = hbase_thrift2.TTableName(
table_name_obj = ttypes.TTableName(
ns=b"default",
qualifier=table_name.encode()
)

column_families = []
for family_name in families.keys():
col_desc = hbase_thrift2.TColumnFamilyDescriptor(
col_desc = ttypes.TColumnFamilyDescriptor(
name=family_name.encode()
)
column_families.append(col_desc)

table_desc = hbase_thrift2.TTableDescriptor(
table_desc = ttypes.TTableDescriptor(
tableName=table_name_obj,
columns=column_families
)
Expand All @@ -164,7 +170,7 @@ def delete_table(self, table_name: str) -> None:
Args:
table_name: Name of the table
"""
table_name_obj = hbase_thrift2.TTableName(
table_name_obj = ttypes.TTableName(
ns=b"default",
qualifier=table_name.encode()
)
Expand All @@ -185,14 +191,14 @@ def put(self, table_name: str, row_key: str, data: dict[str, str]) -> None:
column_values = []
for column, value in data.items():
family, qualifier = column.split(":", 1)
col_val = hbase_thrift2.TColumnValue(
col_val = ttypes.TColumnValue(
family=family.encode(),
qualifier=qualifier.encode(),
value=value.encode() if isinstance(value, str) else value
)
column_values.append(col_val)

tput = hbase_thrift2.TPut(
tput = ttypes.TPut(
row=row_key.encode(),
columnValues=column_values
)
Expand All @@ -212,14 +218,14 @@ def put_multiple(self, table_name: str, puts: list[tuple[str, dict[str, str]]])
column_values = []
for column, value in data.items():
family, qualifier = column.split(":", 1)
col_val = hbase_thrift2.TColumnValue(
col_val = ttypes.TColumnValue(
family=family.encode(),
qualifier=qualifier.encode(),
value=value.encode() if isinstance(value, str) else value
)
column_values.append(col_val)

tput = hbase_thrift2.TPut(
tput = ttypes.TPut(
row=row_key.encode(),
columnValues=column_values
)
Expand All @@ -238,13 +244,13 @@ def get(self, table_name: str, row_key: str, columns: list[str] | None = None) -
Returns:
Dictionary with row data
"""
tget = hbase_thrift2.TGet(row=row_key.encode())
tget = ttypes.TGet(row=row_key.encode())

if columns:
tget.columns = []
for column in columns:
family, qualifier = column.split(":", 1)
tcol = hbase_thrift2.TColumn(
tcol = ttypes.TColumn(
family=family.encode(),
qualifier=qualifier.encode()
)
Expand All @@ -266,13 +272,13 @@ def get_multiple(self, table_name: str, row_keys: list[str], columns: list[str]
"""
tgets = []
for row_key in row_keys:
tget = hbase_thrift2.TGet(row=row_key.encode())
tget = ttypes.TGet(row=row_key.encode())

if columns:
tget.columns = []
for column in columns:
family, qualifier = column.split(":", 1)
tcol = hbase_thrift2.TColumn(
tcol = ttypes.TColumn(
family=family.encode(),
qualifier=qualifier.encode()
)
Expand All @@ -291,13 +297,13 @@ def delete(self, table_name: str, row_key: str, columns: list[str] | None = None
row_key: Row key
columns: List of columns to delete (if None, deletes entire row)
"""
tdelete = hbase_thrift2.TDelete(row=row_key.encode())
tdelete = ttypes.TDelete(row=row_key.encode())

if columns:
tdelete.columns = []
for column in columns:
family, qualifier = column.split(":", 1)
tcol = hbase_thrift2.TColumn(
tcol = ttypes.TColumn(
family=family.encode(),
qualifier=qualifier.encode()
)
Expand All @@ -314,13 +320,13 @@ def delete_multiple(self, table_name: str, deletes: list[tuple[str, list[str] |
"""
tdeletes = []
for row_key, columns in deletes:
tdelete = hbase_thrift2.TDelete(row=row_key.encode())
tdelete = ttypes.TDelete(row=row_key.encode())

if columns:
tdelete.columns = []
for column in columns:
family, qualifier = column.split(":", 1)
tcol = hbase_thrift2.TColumn(
tcol = ttypes.TColumn(
family=family.encode(),
qualifier=qualifier.encode()
)
Expand Down Expand Up @@ -350,7 +356,7 @@ def scan(
Returns:
List of row data dictionaries
"""
tscan = hbase_thrift2.TScan()
tscan = ttypes.TScan()

if start_row:
tscan.startRow = start_row.encode()
Expand All @@ -363,7 +369,7 @@ def scan(
tscan.columns = []
for column in columns:
family, qualifier = column.split(":", 1)
tcol = hbase_thrift2.TColumn(
tcol = ttypes.TColumn(
family=family.encode(),
qualifier=qualifier.encode()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@
Connection ID: hbase_thrift2
Connection Type: HBase
Host: your-hbase-host
Port: 9091
Extra: {"connection_mode": "thrift2"}
Port: 9090
"""

from datetime import datetime, timedelta
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
from airflow.providers.hbase.hooks.hbase import HBaseHook
from airflow.operators.python import PythonOperator

# Connection ID for Thrift2 mode
# Connection extra should contain: {"connection_mode": "thrift2"}
# Connection ID for Thrift2
HBASE_CONN_ID = "hbase_thrift2"
TABLE_NAME = "test_batch_table"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@
from airflow.operators.python import PythonOperator

# Connection IDs
# Single connection: {"connection_mode": "thrift2", "host": "localhost", "port": 9090}
# Pooled connection: {"connection_mode": "thrift2", "host": "localhost", "port": 9090,
# "connection_pool": {"enabled": true, "size": 10}}
# Single connection: no pool
# Pooled connection: {"connection_pool": {"enabled": true, "size": 10}}
SINGLE_CONN_ID = "hbase_thrift2"
POOLED_CONN_ID = "hbase_thrift2_pooled"
TABLE_NAME = "perf_test_table"
Expand Down
Loading