Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 80 additions & 13 deletions python/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,26 @@
import pyarrow.flight as flight
import pyodbc
import requests
from gqlalchemy import Memgraph, Neo4j
from gqlalchemy import Memgraph
from neo4j import GraphDatabase
from neo4j.time import DateTime as Neo4jDateTime
from neo4j.time import Date as Neo4jDate


class Constants:
    """Shared string keys and numeric settings used by the migration code."""

    # --- numeric settings -------------------------------------------------
    # Upper bound on the number of rows returned by one batch call.
    BATCH_SIZE = 1000
    # Index constant; presumably the position of the column name inside a
    # cursor-description entry -- TODO confirm against its users.
    I_COLUMN_NAME = 0

    # --- keys looked up in the user-supplied connection config -------------
    DATABASE = "database"
    HOST = "host"
    PASSWORD = "password"
    PORT = "port"
    URI_SCHEME = "uri_scheme"
    USERNAME = "username"

    # --- keys of the per-thread state dictionaries --------------------------
    # (COLUMN_NAMES is assumed to belong here; its users are not shown.)
    COLUMN_NAMES = "column_names"
    CONNECTION = "connection"
    CURSOR = "cursor"
    DRIVER = "driver"
    RESULT = "result"
    SESSION = "session"


Expand Down Expand Up @@ -519,12 +526,27 @@ def init_migrate_neo4j(
if len(config_path) > 0:
config = _combine_config(config=config, config_path=config_path)

neo4j_db = Neo4j(**config)
uri = _build_neo4j_uri(config)
username = config.get(Constants.USERNAME, "neo4j")
password = config.get(Constants.PASSWORD, "password")
database = config.get(Constants.DATABASE, None) # None means default database

driver = GraphDatabase.driver(uri, auth=(username, password))

# Create session with optional database parameter
if database:
session = driver.session(database=database)
else:
session = driver.session()

query = _formulate_cypher_query(label_or_rel_or_query)
cursor = neo4j_db.execute_and_fetch(query, params)
# Neo4j expects params to be a dict or None
cypher_params = params if params is not None else {}
result = session.run(query, parameters=cypher_params)

neo4j_dict[thread_id][Constants.CONNECTION] = neo4j_db
neo4j_dict[thread_id][Constants.CURSOR] = cursor
neo4j_dict[thread_id][Constants.DRIVER] = driver
neo4j_dict[thread_id][Constants.SESSION] = session
neo4j_dict[thread_id][Constants.RESULT] = result


def neo4j(
Expand All @@ -545,21 +567,31 @@ def neo4j(
global neo4j_dict

thread_id = threading.get_native_id()
cursor = neo4j_dict[thread_id][Constants.CURSOR]
result = neo4j_dict[thread_id][Constants.RESULT]

return [
mgp.Record(row=row)
for row in (next(cursor, None) for _ in range(Constants.BATCH_SIZE))
if row is not None
]
# Fetch up to BATCH_SIZE records
batch = []
for record in result:
# Convert neo4j.Record to dict with proper type conversion
batch.append(mgp.Record(row=_convert_neo4j_record(record)))

# Check if we've reached the batch size limit
if len(batch) >= Constants.BATCH_SIZE:
break

return batch


def cleanup_migrate_neo4j():
    """Release the Neo4j resources held for the calling thread.

    Removes this thread's entry from ``neo4j_dict`` and closes whatever it
    holds under the CONNECTION, SESSION and DRIVER keys, in that order (the
    session is closed before the driver that created it).

    The call is a no-op when the thread has no stored state. Every resource
    is attempted even if an earlier ``close()`` fails; the first failure is
    re-raised once all of them have been tried.
    """
    global neo4j_dict

    thread_id = threading.get_native_id()
    # pop() rather than indexing: cleanup for a thread that never stored any
    # state must not raise KeyError, and the entry is dropped even when a
    # close() below fails.
    state = neo4j_dict.pop(thread_id, None)
    if not state:
        return

    first_error = None
    for key in (Constants.CONNECTION, Constants.SESSION, Constants.DRIVER):
        resource = state.get(key)
        if resource is None:
            continue
        try:
            resource.close()
        except Exception as error:  # keep going so nothing is left open
            if first_error is None:
                first_error = error

    if first_error is not None:
        raise first_error


Expand Down Expand Up @@ -1029,3 +1061,38 @@ def _check_params_type(params: Any, types=(dict, list, tuple)) -> None:
raise TypeError(
"Database query parameter values must be passed in a container of type List[Any] (or Map, if migrating from MySQL or Oracle DB)"
)


def _convert_neo4j_value(value):
    """Convert a value returned by the Neo4j driver into plain Python data.

    * ``None`` is returned unchanged.
    * Lists and dicts are rebuilt with every element / value converted
      recursively (dict keys are kept as they are).
    * Objects exposing a callable ``to_native()`` -- the driver's temporal
      types such as Date, DateTime and Time -- are replaced by the result of
      that call.
    * Anything else is returned as is.
    """
    if value is None:
        return None

    # Containers first: nested values need the same treatment.
    if isinstance(value, list):
        return [_convert_neo4j_value(item) for item in value]

    if isinstance(value, dict):
        return {key: _convert_neo4j_value(val) for key, val in value.items()}

    # Duck-typed instead of isinstance checks against individual driver
    # classes, so every temporal type with a native equivalent is covered,
    # not only Date and DateTime.
    to_native = getattr(value, "to_native", None)
    if callable(to_native):
        return to_native()

    return value


def _convert_neo4j_record(record):
    """Build a plain dict from a Neo4j record, converting each field value.

    Field names are taken from ``record.items()`` and kept unchanged; every
    value is passed through ``_convert_neo4j_value``.
    """
    converted = {}
    for field, raw_value in record.items():
        converted[field] = _convert_neo4j_value(raw_value)
    return converted


def _build_neo4j_uri(config: mgp.Map) -> str:
    """Assemble the Neo4j connection URI from the connection config.

    Reads the HOST, PORT and URI_SCHEME entries of ``config``, falling back
    to ``"localhost"``, ``7687`` and ``"bolt"`` respectively, and returns
    ``"<scheme>://<host>:<port>"``.

    An IPv6 literal host (one containing ``:``) is wrapped in square
    brackets unless it already is, because the port separator would
    otherwise be ambiguous.
    """
    host = str(config.get(Constants.HOST, "localhost"))
    port = config.get(Constants.PORT, 7687)
    uri_scheme = config.get(Constants.URI_SCHEME, "bolt")

    # URI syntax requires IPv6 literals in brackets: bolt://[::1]:7687
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"

    return f"{uri_scheme}://{host}:{port}"
3 changes: 2 additions & 1 deletion python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ duckdb==1.2.1
elasticsearch==8.17.0
gekko==1.2.1
gensim==4.3.3
gqlalchemy==1.6.0
gqlalchemy==1.8.0
igraph==0.11.8
mysql-connector-python==9.1.0
neo4j==5.28.2
networkx==2.8.8
oracledb==2.5.1
pandas==2.2.3
Expand Down