Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def get_conn(self) -> WeaviateClient:
grpc_secure = extras.pop("grpc_secure", False)
return weaviate.connect_to_custom(
http_host=conn.host, # type: ignore[arg-type]
http_port=conn.port or 443 if http_secure else 80,
http_port=conn.port or (443 if http_secure else 80),
http_secure=http_secure,
grpc_host=extras.pop("grpc_host", conn.host),
grpc_port=extras.pop("grpc_port", 443 if grpc_secure else 80),
Expand Down
42 changes: 37 additions & 5 deletions providers/weaviate/tests/unit/weaviate/hooks/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def test_get_conn_with_api_key_in_extra(self, mock_connect_to_custom, mock_auth_
mock_auth_api_key.assert_called_once_with(api_key=self.api_key)
mock_connect_to_custom.assert_called_once_with(
http_host=self.host,
http_port=80,
http_port=8000,
http_secure=False,
grpc_host="localhost",
grpc_port=50051,
Expand All @@ -198,7 +198,7 @@ def test_get_conn_with_token_in_extra(self, mock_connect_to_custom, mock_auth_ap
mock_auth_api_key.assert_called_once_with(api_key=self.api_key)
mock_connect_to_custom.assert_called_once_with(
http_host=self.host,
http_port=80,
http_port=8000,
http_secure=False,
grpc_host="localhost",
grpc_port=50051,
Expand All @@ -216,7 +216,7 @@ def test_get_conn_with_access_token_in_extra(self, mock_connect_to_custom, mock_
)
mock_connect_to_custom.assert_called_once_with(
http_host=self.host,
http_port=80,
http_port=8000,
http_secure=False,
grpc_host="localhost",
grpc_port=50051,
Expand All @@ -236,7 +236,7 @@ def test_get_conn_with_client_secret_in_extra(self, mock_connect_to_custom, mock
)
mock_connect_to_custom.assert_called_once_with(
http_host=self.host,
http_port=80,
http_port=8000,
http_secure=False,
grpc_host="localhost",
grpc_port=50051,
Expand All @@ -252,7 +252,7 @@ def test_get_conn_with_client_password_in_extra(self, mock_connect_to_custom, mo
mock_auth_client_password.assert_called_once_with(username="login", password="password", scope=None)
mock_connect_to_custom.assert_called_once_with(
http_host=self.host,
http_port=80,
http_port=8000,
http_secure=False,
grpc_host="localhost",
grpc_port=50051,
Expand Down Expand Up @@ -964,3 +964,35 @@ def test_replace_option_of_create_or_replace_document_objects(
batch_data.call_args_list[0].kwargs["data"],
df[df["doc"].isin(changed_documents.union(new_documents))],
)


@mock.patch("airflow.providers.weaviate.hooks.weaviate.weaviate.connect_to_custom")
@pytest.mark.parametrize(
"http_secure, port, expected",
[
(False, None, 80),
(True, None, 443),
(False, 8000, 8000),
(True, 8000, 8000),
],
)
def test_get_conn_http_port_logic(connect_to_custom, http_secure, port, expected):
from airflow.models import Connection

conn = Connection(
conn_id="weaviate_http_port_logic",
conn_type="weaviate",
host="localhost",
port=port,
extra={"http_secure": http_secure},
)

with mock.patch.object(WeaviateHook, "get_connection", return_value=conn):
hook = WeaviateHook(conn_id="weaviate_http_port_logic")
hook.get_conn()

# Assert: http_port honors provided port, otherwise 80/443 depending on http_secure
kwargs = connect_to_custom.call_args.kwargs
assert kwargs["http_host"] == "localhost"
assert kwargs["http_port"] == expected
assert kwargs["http_secure"] == http_secure