Skip to content

Commit

Permalink
xml ssl tags and replace static self.home with self.logical_path
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldg committed Oct 16, 2024
1 parent f0c7801 commit 6d8a2d6
Showing 1 changed file with 63 additions and 59 deletions.
122 changes: 63 additions & 59 deletions lib/galaxy/objectstore/irods.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,16 @@ def parse_config_xml(config_xml):
_config_xml_error("auth")
username = a_xml[0].get("username")
password = a_xml[0].get("password")
# sslfile = a_xml[0].get("sslfile", None)
client_server_negotiation = a_xml[0].get("client_server_negotiation", None)
client_server_policy = a_xml[0].get("client_server_policy", None)
encryption_algorithm = a_xml[0].get("encryption_algorithm", None)
encryption_key_size = int(a_xml[0].get("encryption_key_size", None))
encryption_num_hash_rounds = int(a_xml[0].get("encryption_num_hash_rounds", None))
encryption_salt_size = int(a_xml[0].get("encryption_salt_size", None))
ssl_verify_server = a_xml[0].get("ssl_verify_server", None)
ssl_ca_certificate_file = a_xml[0].get("ssl_ca_certificate_file", None)

s_xml = config_xml.findall("ssl")
client_server_negotiation = s_xml[0].get("client_server_negotiation", None)
client_server_policy = s_xml[0].get("client_server_policy", None)
encryption_algorithm = s_xml[0].get("encryption_algorithm", None)
encryption_key_size = int(s_xml[0].get("encryption_key_size", None))
encryption_num_hash_rounds = int(s_xml[0].get("encryption_num_hash_rounds", None))
encryption_salt_size = int(s_xml[0].get("encryption_salt_size", None))
ssl_verify_server = s_xml[0].get("ssl_verify_server", None)
ssl_ca_certificate_file = s_xml[0].get("ssl_ca_certificate_file", None)

r_xml = config_xml.findall("resource")
if not r_xml:
Expand All @@ -82,6 +83,11 @@ def parse_config_xml(config_xml):
refresh_time = int(c_xml[0].get("refresh_time", 300))
connection_pool_monitor_interval = int(c_xml[0].get("connection_pool_monitor_interval", -1))

l_xml = config_xml.findall("logical")
if not l_xml:
_config_xml_error("logical")
logical_path = l_xml[0].get("path", None)

c_xml = config_xml.findall("cache")
if not c_xml:
_config_xml_error("cache")
Expand All @@ -99,7 +105,8 @@ def parse_config_xml(config_xml):
"auth": {
"username": username,
"password": password,
# "sslfile": sslfile,
},
"ssl": {
"client_server_negotiation": client_server_negotiation,
"client_server_policy": client_server_policy,
"encryption_algorithm": encryption_algorithm,
Expand All @@ -122,6 +129,9 @@ def parse_config_xml(config_xml):
"refresh_time": refresh_time,
"connection_pool_monitor_interval": connection_pool_monitor_interval,
},
"logical": {
"path": logical_path,
},
"cache": {
"size": cache_size,
"path": staging_path,
Expand Down Expand Up @@ -158,15 +168,17 @@ def __init__(self, config, config_dict):
self.password = auth_dict.get("password")
if self.password is None:
_config_dict_error("auth->password")
# self.sslfile = auth_dict.get("sslfile")
self.client_server_negotiation = auth_dict.get("client_server_negotiation")
self.client_server_policy = auth_dict.get("client_server_policy")
self.encryption_algorithm = auth_dict.get("encryption_algorithm")
self.encryption_key_size = auth_dict.get("encryption_key_size")
self.encryption_num_hash_rounds = auth_dict.get("encryption_num_hash_rounds")
self.encryption_salt_size = auth_dict.get("encryption_salt_size")
self.ssl_verify_server = auth_dict.get("ssl_verify_server")
self.ssl_ca_certificate_file = auth_dict.get("ssl_ca_certificate_file")

ssl_dict = config_dict.get("ssl") or {}

self.client_server_negotiation = ssl_dict.get("client_server_negotiation")
self.client_server_policy = ssl_dict.get("client_server_policy")
self.encryption_algorithm = ssl_dict.get("encryption_algorithm")
self.encryption_key_size = ssl_dict.get("encryption_key_size")
self.encryption_num_hash_rounds = ssl_dict.get("encryption_num_hash_rounds")
self.encryption_salt_size = ssl_dict.get("encryption_salt_size")
self.ssl_verify_server = ssl_dict.get("ssl_verify_server")
self.ssl_ca_certificate_file = ssl_dict.get("ssl_ca_certificate_file")

resource_dict = config_dict["resource"]
if resource_dict is None:
Expand Down Expand Up @@ -201,6 +213,11 @@ def __init__(self, config, config_dict):
if self.connection_pool_monitor_interval is None:
_config_dict_error("connection->connection_pool_monitor_interval")

logical_dict = config_dict.get("logical") or {}
self.logical_path = logical_dict.get("path") or f"/{self.zone}/home/{self.username}"
if self.logical_path is None:
_config_dict_error("logical->path")

cache_dict = config_dict.get("cache") or {}
self.cache_size = cache_dict.get("size") or self.config.object_store_cache_path
if self.cache_size is None:
Expand All @@ -218,46 +235,29 @@ def __init__(self, config, config_dict):
if irods is None:
raise Exception(IRODS_IMPORT_MESSAGE)

# self.home = f"/{self.zone}/home/{self.username}"
self.home = "/vsc_galaxy/home/t1_data_2024_04/ingress/dev-paul"

if irods is None:
raise Exception(IRODS_IMPORT_MESSAGE)

session_params = {
'host': self.host,
'port': self.port,
'user': self.username,
'password': self.password,
'zone': self.zone,
'refresh_time': self.refresh_time,
'client_server_negotiation': self.client_server_negotiation,
'client_server_policy': self.client_server_policy,
'encryption_algorithm': self.encryption_algorithm,
'encryption_key_size': self.encryption_key_size,
'encryption_num_hash_rounds': self.encryption_num_hash_rounds,
'encryption_salt_size': self.encryption_salt_size,
'ssl_verify_server': self.ssl_verify_server,
'ssl_ca_certificate_file': self.ssl_ca_certificate_file,
'ssl_context': ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
"host": self.host,
"port": self.port,
"user": self.username,
"password": self.password,
"zone": self.zone,
"refresh_time": self.refresh_time,
"client_server_negotiation": self.client_server_negotiation,
"client_server_policy": self.client_server_policy,
"encryption_algorithm": self.encryption_algorithm,
"encryption_key_size": self.encryption_key_size,
"encryption_num_hash_rounds": self.encryption_num_hash_rounds,
"encryption_salt_size": self.encryption_salt_size,
"ssl_verify_server": self.ssl_verify_server,
"ssl_ca_certificate_file": self.ssl_ca_certificate_file,
"ssl_context": ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH),
}

# Add ssl parameters only if self.sslfile is not None
# if self.sslfile is not None:
# with open(self.sslfile, "r") as file:
# ssl_settings = json.load(file)

# ssl_settings['ssl_context'] = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
# session_params.update(ssl_settings)

self.session = iRODSSession(**session_params)

log.debug("SESSION PARAMS: %s", session_params)

coll = self.session.collections.get("/vsc_galaxy/")
for col in coll.subcollections:
log.debug("COLLECTION: %s", col)

# Set connection timeout
self.session.connection_timeout = self.timeout

Expand Down Expand Up @@ -340,7 +340,8 @@ def _config_to_dict(self):
"auth": {
"username": self.username,
"password": self.password,
# "sslfile": self.sslfile,
},
"ssl": {
"client_server_negotiation": self.client_server_negotiation,
"client_server_policy": self.client_server_policy,
"encryption_algorithm": self.encryption_algorithm,
Expand All @@ -363,6 +364,9 @@ def _config_to_dict(self):
"refresh_time": self.refresh_time,
"connection_pool_monitor_interval": self.connection_pool_monitor_interval,
},
"logical": {
"path": self.logical_path,
},
"cache": {
"size": self.cache_size,
"path": self.staging_path,
Expand All @@ -377,7 +381,7 @@ def _get_remote_size(self, rel_path):
data_object_name = p.stem + p.suffix
subcollection_name = p.parent

collection_path = f"{self.home}/{subcollection_name}"
collection_path = f"{self.logical_path}/{subcollection_name}"
data_object_path = f"{collection_path}/{data_object_name}"
options = {kw.DEST_RESC_NAME_KW: self.resource}

Expand All @@ -397,7 +401,7 @@ def _exists_remotely(self, rel_path):
data_object_name = p.stem + p.suffix
subcollection_name = p.parent

collection_path = f"{self.home}/{subcollection_name}"
collection_path = f"{self.logical_path}/{subcollection_name}"
data_object_path = f"{collection_path}/{data_object_name}"
options = {kw.DEST_RESC_NAME_KW: self.resource}

Expand All @@ -419,7 +423,7 @@ def _download(self, rel_path):
data_object_name = p.stem + p.suffix
subcollection_name = p.parent

collection_path = f"{self.home}/{subcollection_name}"
collection_path = f"{self.logical_path}/{subcollection_name}"
data_object_path = f"{collection_path}/{data_object_name}"
# we need to allow irods to override already existing zero-size output files created
# in object store cache during job setup (see also https://github.com/galaxyproject/galaxy/pull/17025#discussion_r1394517033)
Expand Down Expand Up @@ -460,7 +464,7 @@ def _push_to_storage(self, rel_path, source_file=None, from_string=None):
return False

# Check if the data object exists in iRODS
collection_path = f"{self.home}/{subcollection_name}"
collection_path = f"{self.logical_path}/{subcollection_name}"
data_object_path = f"{collection_path}/{data_object_name}"
exists = False

Expand Down Expand Up @@ -538,7 +542,7 @@ def _delete(self, obj, entire_dir: bool = False, **kwargs) -> bool:
if entire_dir and extra_dir:
shutil.rmtree(self._get_cache_path(rel_path), ignore_errors=True)

col_path = f"{self.home}/{rel_path}"
col_path = f"{self.logical_path}/{rel_path}"
col = None
try:
col = self.session.collections.get(col_path)
Expand Down Expand Up @@ -566,7 +570,7 @@ def _delete(self, obj, entire_dir: bool = False, **kwargs) -> bool:
data_object_name = p.stem + p.suffix
subcollection_name = p.parent

collection_path = f"{self.home}/{subcollection_name}"
collection_path = f"{self.logical_path}/{subcollection_name}"
data_object_path = f"{collection_path}/{data_object_name}"

try:
Expand All @@ -592,7 +596,7 @@ def _get_object_url(self, obj, **kwargs):
data_object_name = p.stem + p.suffix
subcollection_name = p.parent

collection_path = f"{self.home}/{subcollection_name}"
collection_path = f"{self.logical_path}/{subcollection_name}"
data_object_path = f"{collection_path}/{data_object_name}"

return data_object_path
Expand Down

0 comments on commit 6d8a2d6

Please sign in to comment.