Skip to content

Commit 7a09284

Browse files
committed
Fix codestyle
1 parent 50c52e1 commit 7a09284

File tree

6 files changed

+33
-45
lines changed

6 files changed

+33
-45
lines changed

graphdatascience/arrow_client/v2/job_client.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def get_summary(client: AuthenticatedArrowClient, job_id: str) -> dict[str, Any]
4545
return deserialize_single(res)
4646

4747
@staticmethod
48-
def stream_results(client: AuthenticatedArrowClient,graph_name: str, job_id: str) -> DataFrame:
48+
def stream_results(client: AuthenticatedArrowClient, graph_name: str, job_id: str) -> DataFrame:
4949
payload = {
5050
"graphName": graph_name,
5151
"jobId": job_id,
@@ -54,15 +54,10 @@ def stream_results(client: AuthenticatedArrowClient,graph_name: str, job_id: str
5454
res = client.do_action_with_retry("v2/results.stream", json.dumps(payload).encode("utf-8"))
5555
export_job_id = JobIdConfig(**deserialize_single(res)).job_id
5656

57-
payload = {
58-
"version": "v2",
59-
"name": export_job_id,
60-
"body": {}
61-
}
57+
stream_payload = {"version": "v2", "name": export_job_id, "body": {}}
6258

63-
ticket = Ticket(json.dumps(payload).encode("utf-8"))
59+
ticket = Ticket(json.dumps(stream_payload).encode("utf-8"))
6460

6561
get = client.get_stream(ticket)
6662
arrow_table = get.read_all()
6763
return arrow_table.to_pandas(types_mapper=ArrowDtype) # type: ignore
68-

graphdatascience/procedure_surface/arrow/arrow_wcc_endpoints.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,7 @@ def write(
174174

175175
return WccWriteResult(**computation_result)
176176

177-
def estimate(self, graph_name: Optional[str] = None,
178-
projection_config: Optional[dict[str, Any]] = None) -> EstimationResult:
177+
def estimate(
178+
self, graph_name: Optional[str] = None, projection_config: Optional[dict[str, Any]] = None
179+
) -> EstimationResult:
179180
pass
180-
181-

graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88
from testcontainers.neo4j import Neo4jContainer
99

1010
from graphdatascience.arrow_client.arrow_authentication import UsernamePasswordAuthentication
11+
from graphdatascience.arrow_client.arrow_info import ArrowInfo
1112
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
12-
from graphdatascience.arrow_client.v2.write_back_client import WriteBackClient
13-
from graphdatascience.query_runner.arrow_info import ArrowInfo
1413

1514

1615
@pytest.fixture(scope="session")
17-
def password_file():
16+
def password_file() -> Generator[str, None, None]:
1817
"""Create a temporary file and return its path."""
1918
temp_dir = tempfile.mkdtemp()
2019
temp_file_path = os.path.join(temp_dir, "password")
@@ -84,12 +83,3 @@ def arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient
8483
auth=UsernamePasswordAuthentication("neo4j", "password"),
8584
encrypted=False,
8685
)
87-
88-
89-
@pytest.fixture
90-
def write_back_client(neo4j_session_container: DockerContainer) -> WriteBackClient:
91-
"""Create a write-back client for the session container."""
92-
host = neo4j_session_container.get_container_host_ip()
93-
port = neo4j_session_container.get_exposed_port(8491)
94-
95-
return WriteBackClient(host=host, port=port, username="neo4j", password="test_password")

graphdatascience/tests/integrationV2/procedure_surface/arrow/test_wcc_arrow_endpoints.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import json
2+
from typing import Generator
23

34
import pytest
45

56
from graphdatascience import Graph
67
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
7-
from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize_single
88
from graphdatascience.procedure_surface.arrow.arrow_wcc_endpoints import WccArrowEndpoints
99

1010

@@ -17,36 +17,36 @@ def name(self) -> str:
1717

1818

1919
@pytest.fixture
20-
def sample_graph(arrow_client: AuthenticatedArrowClient):
20+
def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]:
2121
gdl = """
2222
(a: Node)
2323
(b: Node)
2424
(c: Node)
2525
(a)-[:REL]->(c)
2626
"""
2727

28-
arrow_client.do_action( "v2/graph.fromGDL", json.dumps({"graphName": "g", "gdlGraph": gdl}).encode("utf-8"))
28+
arrow_client.do_action("v2/graph.fromGDL", json.dumps({"graphName": "g", "gdlGraph": gdl}).encode("utf-8"))
2929
yield MockGraph("g")
30-
arrow_client.do_action( "v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8"))
30+
arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8"))
31+
3132

3233
@pytest.fixture
33-
def wcc_endpoints(arrow_client: AuthenticatedArrowClient):
34+
def wcc_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[WccArrowEndpoints, None, None]:
3435
yield WccArrowEndpoints(arrow_client)
3536

3637

37-
def test_wcc_stats(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
38+
def test_wcc_stats(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
3839
"""Test WCC stats operation."""
39-
result = wcc_endpoints.stats(
40-
G=sample_graph
41-
)
40+
result = wcc_endpoints.stats(G=sample_graph)
4241

4342
assert result.component_count == 2
4443
assert result.compute_millis > 0
4544
assert result.pre_processing_millis > 0
4645
assert result.post_processing_millis > 0
4746
assert "p10" in result.component_distribution
4847

49-
def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
48+
49+
def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
5050
"""Test WCC stream operation."""
5151
result_df = wcc_endpoints.stream(
5252
G=sample_graph,
@@ -56,7 +56,8 @@ def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
5656
assert "componentId" in result_df.columns
5757
assert len(result_df.columns) == 2
5858

59-
def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
59+
60+
def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
6061
"""Test WCC mutate operation."""
6162
result = wcc_endpoints.mutate(
6263
G=sample_graph,

graphdatascience/tests/integrationV2/procedure_surface/cypher/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from testcontainers.core.waiting_utils import wait_for_logs
77
from testcontainers.neo4j import Neo4jContainer
88

9+
from graphdatascience import QueryRunner
910
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
1011

1112

@@ -27,7 +28,7 @@ def neo4j_database_container() -> Generator[Neo4jContainer, None, None]:
2728

2829

2930
@pytest.fixture
30-
def query_runner(neo4j_database_container: DockerContainer):
31+
def query_runner(neo4j_database_container: DockerContainer) -> Generator[QueryRunner, None, None]:
3132
yield Neo4jQueryRunner.create_for_db(
3233
f"bolt://localhost:{neo4j_database_container.get_exposed_port(7687)}",
3334
("neo4j", "password"),

graphdatascience/tests/integrationV2/procedure_surface/cypher/test_wcc_cypher_endpoints.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1+
from typing import Generator
2+
13
import pytest
24

35
from graphdatascience import Graph, QueryRunner
46
from graphdatascience.procedure_surface.arrow.arrow_wcc_endpoints import WccArrowEndpoints
57
from graphdatascience.procedure_surface.cypher.wcc_cypher_endpoints import WccCypherEndpoints
6-
from graphdatascience.tests.integrationV2.procedure_surface.cypher.conftest import query_runner
78

89

910
@pytest.fixture
10-
def sample_graph(query_runner: QueryRunner):
11+
def sample_graph(query_runner: QueryRunner) -> Generator[Graph, None, None]:
1112
create_statement = """
1213
CREATE
1314
(a: Node),
@@ -30,24 +31,24 @@ def sample_graph(query_runner: QueryRunner):
3031
query_runner.run_cypher("CALL gds.graph.drop('g')")
3132
query_runner.run_cypher("MATCH (n) DETACH DELETE n")
3233

34+
3335
@pytest.fixture
34-
def wcc_endpoints(query_runner: QueryRunner):
36+
def wcc_endpoints(query_runner: QueryRunner) -> Generator[WccCypherEndpoints, None, None]:
3537
yield WccCypherEndpoints(query_runner)
3638

3739

38-
def test_wcc_stats(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
40+
def test_wcc_stats(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
3941
"""Test WCC stats operation."""
40-
result = wcc_endpoints.stats(
41-
G=sample_graph
42-
)
42+
result = wcc_endpoints.stats(G=sample_graph)
4343

4444
assert result.component_count == 2
4545
assert result.compute_millis > 0
4646
assert result.pre_processing_millis > 0
4747
assert result.post_processing_millis > 0
4848
assert "p10" in result.component_distribution
4949

50-
def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
50+
51+
def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
5152
"""Test WCC stream operation."""
5253
result_df = wcc_endpoints.stream(
5354
G=sample_graph,
@@ -57,7 +58,8 @@ def test_wcc_stream(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
5758
assert "componentId" in result_df.columns
5859
assert len(result_df.columns) == 2
5960

60-
def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph):
61+
62+
def test_wcc_mutate(wcc_endpoints: WccArrowEndpoints, sample_graph: Graph) -> None:
6163
"""Test WCC mutate operation."""
6264
result = wcc_endpoints.mutate(
6365
G=sample_graph,

0 commit comments

Comments
 (0)