Skip to content

Commit 3bcf803

Browse files
committed
Get pygraphistry as a singleton, remove static methods
1 parent 72ce717 commit 3bcf803

File tree

7 files changed

+456
-450
lines changed

7 files changed

+456
-450
lines changed

graphistry/PlotterBase.py

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
import copy, hashlib, numpy as np, pandas as pd, pyarrow as pa, sys, uuid
55
from functools import lru_cache
66
from weakref import WeakValueDictionary
7+
import warnings
78

89
from graphistry.privacy import Privacy, Mode
9-
from graphistry.client_session import ClientSession
10+
from graphistry.client_session import ClientSession, SessionManagerProtocol, DatasetInfo
1011

1112
from .constants import SRC, DST, NODE
1213
from .plugins_types.kusto_types import KustoConfig
@@ -121,8 +122,17 @@ def reset_caches(self):
121122
def __init__(self, *args: Any, **kwargs: Any) -> None:
122123
super().__init__(*args, **kwargs)
123124

124-
from .pygraphistry import PyGraphistry
125-
self._session = kwargs.get('graphistry_client_session', PyGraphistry._session)
125+
# Use late import to avoid circular dependency
126+
pygraphistry = kwargs.get('pygraphistry_session', None)
127+
if pygraphistry is None:
128+
from .pygraphistry import PyGraphistry
129+
pygraphistry = PyGraphistry
130+
131+
if self is PyGraphistry:
132+
# NOTE: This may use global session when that isn't desired.
133+
warnings.warn("Plotter initialized without pygraphistry_session, falling back to global PyGraphistry", UserWarning)
134+
self._pygraphistry: SessionManagerProtocol = pygraphistry
135+
self._session: ClientSession = self._pygraphistry._session
126136

127137
# Bindings
128138
self._edges : Any = None
@@ -1502,21 +1512,19 @@ def server(self, v: Optional[str] = None) -> str:
15021512
15031513
Note that sets are global as PyGraphistry._config entries, so be careful in multi-user environments.
15041514
"""
1505-
from .pygraphistry import PyGraphistry
15061515
if v is not None:
1507-
PyGraphistry._session.hostname = v
1508-
return PyGraphistry._session.hostname
1516+
self._session.hostname = v
1517+
return self._session.hostname
15091518

15101519
def protocol(self, v: Optional[str] = None) -> str:
15111520
"""
15121521
Get or set the server protocol, e.g., "https"
15131522
15141523
Note that sets are global as PyGraphistry._config entries, so be careful in multi-user environments.
15151524
"""
1516-
from .pygraphistry import PyGraphistry
15171525
if v is not None:
1518-
PyGraphistry._session.protocol = v
1519-
return PyGraphistry._session.protocol
1526+
self._session.protocol = v
1527+
return self._session.protocol
15201528

15211529
def client_protocol_hostname(self, v: Optional[str] = None) -> str:
15221530
"""
@@ -1526,18 +1534,15 @@ def client_protocol_hostname(self, v: Optional[str] = None) -> str:
15261534
15271535
Note that sets are global as PyGraphistry._config entries, so be careful in multi-user environments.
15281536
"""
1529-
from .pygraphistry import PyGraphistry
15301537
if v is not None:
1531-
PyGraphistry._session.client_protocol_hostname = v
1532-
return PyGraphistry._session.client_protocol_hostname or f"{PyGraphistry.protocol()}://{PyGraphistry.server()}"
1538+
self._session.client_protocol_hostname = v
1539+
return self._session.client_protocol_hostname or f"{self.protocol()}://{self.server()}"
15331540

15341541
def base_url_server(self, v: Optional[str] = None) -> str:
1535-
from .pygraphistry import PyGraphistry
1536-
return "%s://%s" % (PyGraphistry.protocol(), PyGraphistry.server())
1542+
return "%s://%s" % (self.protocol(), self.server())
15371543

15381544
def base_url_client(self, v: Optional[str] = None) -> str:
1539-
from .pygraphistry import PyGraphistry
1540-
return PyGraphistry.client_protocol_hostname()
1545+
return self.client_protocol_hostname()
15411546

15421547
def upload(
15431548
self,
@@ -1658,8 +1663,7 @@ def plot(
16581663
.plot(es)
16591664
16601665
"""
1661-
from .pygraphistry import PyGraphistry
1662-
logger.debug("1. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name()))
1666+
logger.debug("1. @PloatterBase plot: _pygraphistry.org_name: {}".format(self._session.org_name))
16631667

16641668
if graph is None:
16651669
if self._edges is None:
@@ -1673,36 +1677,34 @@ def plot(
16731677

16741678
self._check_mandatory_bindings(not isinstance(n, type(None)))
16751679

1676-
# from .pygraphistry import PyGraphistry
1677-
api_version = PyGraphistry.api_version()
1678-
logger.debug("2. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name()))
1680+
logger.debug("2. @PloatterBase plot: self._pygraphistry.org_name: {}".format(self._session.org_name))
16791681
dataset: Union[ArrowUploader, Dict[str, Any], None] = None
1680-
if api_version == 1:
1682+
if self._session.api_version == 1:
16811683
dataset = self._plot_dispatch(g, n, name, description, 'json', self._style, memoize)
16821684
if skip_upload:
16831685
return dataset
1684-
info = PyGraphistry._etl1(dataset)
1685-
elif api_version == 3:
1686-
logger.debug("3. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name()))
1687-
PyGraphistry.refresh()
1688-
logger.debug("4. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name()))
1686+
info = self._pygraphistry._etl1(dataset)
1687+
elif self._session.api_version == 3:
1688+
logger.debug("3. @PloatterBase plot: self._pygraphistry.org_name: {}".format(self._session.org_name))
1689+
self._pygraphistry.refresh()
1690+
logger.debug("4. @PloatterBase plot: self._pygraphistry.org_name: {}".format(self._session.org_name))
16891691

16901692
uploader = dataset = self._plot_dispatch_arrow(g, n, name, description, self._style, memoize)
16911693
assert uploader is not None
16921694
if skip_upload:
16931695
return uploader
1694-
uploader.token = PyGraphistry.api_token() # type: ignore[assignment]
1696+
uploader.token = self._session.api_token # type: ignore[assignment]
16951697
uploader.post(as_files=as_files, memoize=memoize, validate=validate, erase_files_on_fail=erase_files_on_fail)
16961698
uploader.maybe_post_share_link(self)
1697-
info = {
1699+
info: DatasetInfo = {
16981700
'name': uploader.dataset_id,
16991701
'type': 'arrow',
17001702
'viztoken': str(uuid.uuid4())
17011703
}
17021704

1703-
viz_url = PyGraphistry._viz_url(info, self._url_params)
1704-
cfg_client_protocol_hostname = PyGraphistry._session.client_protocol_hostname
1705-
full_url = ('%s:%s' % (PyGraphistry._session.protocol, viz_url)) if cfg_client_protocol_hostname is None else viz_url
1705+
viz_url = self._pygraphistry._viz_url(info, self._url_params)
1706+
cfg_client_protocol_hostname = self._session.client_protocol_hostname
1707+
full_url = ('%s:%s' % (self._session.protocol, viz_url)) if cfg_client_protocol_hostname is None else viz_url
17061708

17071709
render_mode = resolve_render_mode(self, render)
17081710
if render_mode == "url":
@@ -2245,8 +2247,6 @@ def _make_dataset(self, edges, nodes, name, description, mode, metadata=None, me
22452247
# Main helper for creating ETL1 payload
22462248
def _make_json_dataset(self, edges, nodes, name) -> Dict[str, Any]:
22472249

2248-
from .pygraphistry import PyGraphistry
2249-
22502250
def flatten_categorical(df):
22512251
# Avoid cat_col.where(...)-related exceptions
22522252
df2 = df.copy()
@@ -2260,7 +2260,7 @@ def flatten_categorical(df):
22602260

22612261
bindings = {'idField': self._node or PlotterBase._defaultNodeId,
22622262
'destinationField': self._destination, 'sourceField': self._source}
2263-
dataset: Dict[str, Any] = {'name': PyGraphistry._session.dataset_prefix + name,
2263+
dataset: Dict[str, Any] = {'name': self._session.dataset_prefix + name,
22642264
'bindings': bindings, 'type': 'edgelist', 'graph': edict}
22652265

22662266
if nlist is not None:
@@ -2271,20 +2271,20 @@ def flatten_categorical(df):
22712271

22722272
def _make_arrow_dataset(self, edges: pa.Table, nodes: pa.Table, name: str, description: str, metadata: Optional[Dict[str, Any]]) -> ArrowUploader:
22732273

2274-
from .pygraphistry import PyGraphistry
22752274
au : ArrowUploader = ArrowUploader(
2276-
server_base_path=PyGraphistry.protocol() + '://' + PyGraphistry.server(),
2275+
client_session=self._session,
2276+
server_base_path=self._session.protocol + '://' + self._session.hostname,
22772277
edges=edges, nodes=nodes,
22782278
name=name, description=description,
22792279
metadata={
2280-
'usertag': PyGraphistry._tag,
2281-
'key': PyGraphistry.api_key(),
2280+
'usertag': self._session._tag,
2281+
'key': self._session.api_key,
22822282
'agent': 'pygraphistry',
22832283
'apiversion' : '3',
22842284
'agentversion': sys.modules['graphistry'].__version__, # type: ignore
22852285
**(metadata or {})
22862286
},
2287-
certificate_validation=PyGraphistry.certificate_validation())
2287+
certificate_validation=self._pygraphistry.certificate_validation())
22882288

22892289
au.edge_encodings = au.g_to_edge_encodings(self)
22902290
au.node_encodings = au.g_to_node_encodings(self)
@@ -2485,10 +2485,8 @@ def cypher(self, query: str, params: Dict[str, Any] = {}) -> Plottable:
24852485
24862486
"""
24872487

2488-
from .pygraphistry import PyGraphistry
2489-
24902488
res = copy.copy(self)
2491-
driver = self._bolt_driver or PyGraphistry._session.bolt_driver
2489+
driver = self._bolt_driver or self._session.bolt_driver
24922490
if driver is None:
24932491
raise ValueError("BOLT connection information not provided. Must first call graphistry.register(bolt=...) or g.bolt(...).")
24942492
with driver.session() as session:

graphistry/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
name,
1515
description,
1616
bind,
17+
client,
1718
style,
1819
addStyle,
1920
edges,

graphistry/client_session.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
2-
from typing import Any, Optional, Type, TypeVar, Union, overload, Literal, cast
2+
from typing import Any, Optional, Type, TypeVar, Union, overload, Literal, cast, Protocol, TypedDict, Dict
33
from typing_extensions import deprecated
44
from functools import lru_cache
55
import json
66

77
from graphistry.privacy import Privacy
88
from . import util
9+
from .plugins_types.spanner_types import SpannerSession
10+
from .plugins_types.kusto_types import KustoSession
911

1012

1113

@@ -22,6 +24,9 @@
2224
class ClientSession:
2325

2426
def __init__(self) -> None:
27+
self._is_authenticated: bool = False
28+
self._tag = util.fingerprint()
29+
2530
self.api_key: Optional[str] = get_from_env(ENV_GRAPHISTRY_API_KEY, str)
2631
self.api_token: Optional[str] = get_from_env("GRAPHISTRY_API_TOKEN", str)
2732
# self.api_token_refresh_ms: Optional[int] = None
@@ -55,6 +60,33 @@ def __init__(self) -> None:
5560

5661
# TODO: Migrate to a pattern like Kusto or Spanner
5762
self.bolt_driver: Optional[Any] = None
63+
64+
# Plugin sessions
65+
self.kusto: KustoSession = KustoSession()
66+
self.spanner: SpannerSession = SpannerSession()
67+
68+
69+
class DatasetInfo(TypedDict):
70+
name: str
71+
viztoken: str
72+
type: Literal["arrow", "vgraph"]
73+
74+
75+
76+
class SessionManagerProtocol(Protocol):
77+
_session: ClientSession
78+
79+
def _etl1(self, dataset: Any) -> DatasetInfo:
80+
...
81+
82+
def refresh(self, token: Optional[str] = None, fail_silent: bool = False) -> Optional[str]:
83+
...
84+
85+
def _viz_url(self, info: DatasetInfo, url_params: Dict[str, Any]) -> str:
86+
...
87+
88+
def certificate_validation(self, value: Optional[bool] = None) -> bool:
89+
...
5890

5991

6092
@deprecated("Use the session pattern instead")

graphistry/plugins/kustograph.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from graphistry.Plottable import Plottable
1111
from graphistry.util import setup_logger
12-
from graphistry.pygraphistry import PyGraphistry
1312
from graphistry.plugins_types.kusto_types import KustoConfig, KustoConnectionError, KustoQueryResult, KustoSession
13+
from graphistry.client_session import ClientSession
1414

1515
logger = setup_logger(__name__)
1616

@@ -27,14 +27,13 @@ class KustoGraph:
2727
def __init__(
2828
self,
2929
*args,
30-
kusto_session: Optional[KustoSession] = None,
3130
**kwargs: Any,
3231
) -> None:
3332
# NOTE: Cooperative Mixin initialization passes args and kwargs along
34-
kwargs['kusto_session'] = kusto_session
3533
super().__init__(*args, **kwargs)
3634

37-
self._kusto_session = kusto_session or KustoSession()
35+
session = kwargs.get('pygraphistry_session', {}).get('_session', None)
36+
self._kusto_session = session.kusto if isinstance(session, ClientSession) else KustoSession()
3837

3938
def from_kusto_client(self, client: KustoClient, database: str) -> 'KustoGraph':
4039
self._kusto_session.client = client
@@ -173,7 +172,8 @@ def kql_graph(self, graph_name: str, snap_name: Optional[str] = None) -> Plottab
173172
::
174173
g = graphistry.kusto_query_graph("HoneypotNetwork").plot()
175174
"""
176-
g = self if isinstance(self, Plottable) else PyGraphistry.bind()
175+
from ..plotter import Plotter
176+
g = self if isinstance(self, Plottable) else Plotter() # type: ignore
177177

178178
if snap_name:
179179
graph_query = f'graph("{graph_name}", "{snap_name}" | graph-to-table nodes as N with_node_id=NodeId, edges as E with_source_id=src with_target_id=dst; N;E'
@@ -184,7 +184,7 @@ def kql_graph(self, graph_name: str, snap_name: Optional[str] = None) -> Plottab
184184
raise ValueError(f"Expected 2 results, got {len(results)}")
185185
nodes = pd.DataFrame(results[0].data, columns=results[0].column_names)
186186
edges = pd.DataFrame(results[1].data, columns=results[1].column_names)
187-
return g.nodes(nodes, node='NodeId').edges(edges, source='src', destination='dst')
187+
return g.nodes(nodes, node='NodeId').edges(edges, source='src', destination='dst') # type: ignore
188188

189189

190190
def _kql(self, query: str) -> List[KustoQueryResult]:

graphistry/plugins/spannergraph.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55
import pandas as pd
66
from graphistry.Plottable import Plottable
7-
from graphistry.pygraphistry import PyGraphistry
87
from graphistry.util import setup_logger
98
from graphistry.plugins_types.spanner_types import (
109
SpannerConfig,
1110
SpannerConnectionError,
1211
SpannerQueryResult,
1312
SpannerSession,
1413
)
14+
from graphistry.client_session import ClientSession
1515

1616
if TYPE_CHECKING:
1717
from google.cloud.spanner_dbapi.connection import Connection
@@ -32,13 +32,13 @@ class SpannerGraph:
3232
def __init__(
3333
self,
3434
*args,
35-
spanner_session: Optional[SpannerSession] = None,
3635
**kwargs: Any,
3736
) -> None:
3837
# NOTE: Cooperative Mixin initialization passes args and kwargs along
39-
kwargs["spanner_session"] = spanner_session
4038
super().__init__(*args, **kwargs)
41-
self._spanner_session = spanner_session or SpannerSession()
39+
40+
session = kwargs.get('pygraphistry_session', {}).get('_session', None)
41+
self._spanner_session = session.spanner if isinstance(session, ClientSession) else SpannerSession()
4242

4343
def from_spanner_client(self, client: Connection, database: str) -> "SpannerGraph":
4444
self._spanner_session.client = client
@@ -237,7 +237,8 @@ def gql_to_graph(self, query: str) -> Plottable:
237237
g.plot()
238238
239239
"""
240-
g = self if isinstance(self, Plottable) else PyGraphistry.bind()
240+
from ..plotter import Plotter
241+
g = self if isinstance(self, Plottable) else Plotter() # type: ignore
241242
query_result = self._gql(query)
242243

243244
# convert json result set to a list
@@ -249,7 +250,7 @@ def gql_to_graph(self, query: str) -> Plottable:
249250
edges_df = self.get_edges_df(json_data)
250251

251252
# TODO(tcook): add more error handling here if nodes or edges are empty
252-
return g.nodes(nodes_df, 'identifier').edges(edges_df, 'source', 'destination')
253+
return g.nodes(nodes_df, 'identifier').edges(edges_df, 'source', 'destination') # type: ignore
253254

254255
def gql_to_df(self, query: str) -> pd.DataFrame:
255256
"""

0 commit comments

Comments
 (0)