Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Allow for Snowflake key-pair authentication with --dbt #450

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions data_diff/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from typing import List, Optional, Dict, Tuple
from pathlib import Path

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric import dsa
from cryptography.hazmat.primitives import serialization

import requests


Expand Down Expand Up @@ -366,22 +371,61 @@ def _get_connection_creds(self) -> Tuple[Dict[str, str], str]:

return credentials, conn_type

@staticmethod
def _get_snowflake_private_key(credentials):
"""Get Snowflake private key by path, from a Base64 encoded DER bytestring or None."""
if credentials.get("private_key") and credentials.get("private_key_path"):
raise Exception("Cannot specify both `private_key` and `private_key_path`")
if credentials.get("private_key_passphrase"):
encoded_passphrase = credentials.get("private_key_passphrase").encode()
else:
encoded_passphrase = None

if credentials.get("private_key"):
p_key = serialization.load_der_private_key(
base64.b64decode(credentials.get("private_key")),
password=encoded_passphrase,
backend=default_backend(),
)
elif credentials.get("private_key_path"):
with open(credentials.get("private_key_path"), "rb") as key:
p_key = serialization.load_pem_private_key(
key.read(), password=encoded_passphrase, backend=default_backend()
)
else:
return None

return p_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption(),
)

def set_connection(self):
credentials, conn_type = self._get_connection_creds()

if conn_type == "snowflake":
if credentials.get("password") is None or credentials.get("private_key_path") is not None:
raise Exception("Only password authentication is currently supported for Snowflake.")
if credentials.get("authenticator") is not None:
raise Exception("Federated authentication is not currently supported for Snowflake.")
conn_info = {
"driver": conn_type,
"user": credentials.get("user"),
"password": credentials.get("password"),
"account": credentials.get("account"),
"database": credentials.get("database"),
"warehouse": credentials.get("warehouse"),
"role": credentials.get("role"),
"schema": credentials.get("schema"),
}

if credentials.get("private_key") is not None or credentials.get("private_key_path") is not None:
if credentials.get("password") is not None:
raise Exception("Cannot use password and key at the same time")
conn_info["private_key"] = self._get_snowflake_private_key(credentials)
elif credentials.get("password") is not None:
conn_info["password"] = credentials.get("password")
else:
raise Exception("Password or key authentication not provided.")

self.threads = credentials.get("threads")
self.requires_upper = True
elif conn_type == "bigquery":
Expand Down
49 changes: 47 additions & 2 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def test_set_project_dict(self, mock_open):
self.assertEqual(mock_self.project_dict, expected_dict)
mock_open.assert_called_once_with(Path(PROJECT_FILE))

def test_set_connection_snowflake_success(self):
def test_set_connection_snowflake_password_success(self):
expected_driver = "snowflake"
expected_credentials = {"user": "user", "password": "password"}
mock_self = Mock()
Expand All @@ -158,7 +158,52 @@ def test_set_connection_snowflake_success(self):
self.assertEqual(mock_self.connection.get("password"), expected_credentials["password"])
self.assertEqual(mock_self.requires_upper, True)

def test_set_connection_snowflake_no_password(self):
def test_set_connection_snowflake_private_key_success(self):
expected_driver = "snowflake"
expected_credentials = {"user": "user", "private_key": "password", "private_key_passphrase": "pass"}
expected_connection = {"user": "user", "private_key": "password"}
mock_self = Mock()
mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver)
mock_self._get_snowflake_private_key.return_value = expected_connection["private_key"]

DbtParser.set_connection(mock_self)

self.assertIsInstance(mock_self.connection, dict)
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
self.assertEqual(mock_self.connection.get("user"), expected_connection["user"])
self.assertEqual(mock_self.connection.get("private_key"), expected_connection["private_key"])
self.assertEqual(mock_self.connection.get("private_key_passphrase"), None)
self.assertEqual(mock_self.requires_upper, True)

def test_set_connection_snowflake_private_key_path_success(self):
expected_driver = "snowflake"
expected_credentials = {"user": "user", "private_key_path": "password", "private_key_passphrase": "pass"}
expected_connection = {"user": "user", "private_key": "password"}
mock_self = Mock()
mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver)
mock_self._get_snowflake_private_key.return_value = expected_connection["private_key"]

DbtParser.set_connection(mock_self)

self.assertIsInstance(mock_self.connection, dict)
self.assertEqual(mock_self.connection.get("driver"), expected_driver)
self.assertEqual(mock_self.connection.get("user"), expected_connection["user"])
self.assertEqual(mock_self.connection.get("private_key"), expected_connection["private_key"])
self.assertEqual(mock_self.connection.get("private_key_passphrase"), None)
self.assertEqual(mock_self.requires_upper, True)

def test_set_connection_snowflake_multiple_authentication(self):
expected_driver = "snowflake"
expected_credentials = {"user": "user", "password": "password", "private_key": "password"}
mock_self = Mock()
mock_self._get_connection_creds.return_value = (expected_credentials, expected_driver)

with self.assertRaises(Exception):
DbtParser.set_connection(mock_self)

self.assertNotIsInstance(mock_self.connection, dict)

def test_set_connection_snowflake_no_authentication(self):
expected_driver = "snowflake"
expected_credentials = {"user": "user"}
mock_self = Mock()
Expand Down