Skip to content

Commit 0cd0b49

Browse files
author
PJEstrada
authored
feat: add signer param to dataset object (#50)
1 parent a07ec9d commit 0cd0b49

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

sdk/diffgram/core/diffgram_dataset_iterator.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
from threading import Thread
77
from concurrent.futures import ThreadPoolExecutor
8+
from typing import Callable
89

910

1011
class DiffgramDatasetIterator:
@@ -15,13 +16,15 @@ class DiffgramDatasetIterator:
1516
file_cache: dict
1617
_internal_file_list: list
1718
current_file_index: int
19+
custom_signer_fn: Callable
1820

1921
def __init__(self,
2022
project,
2123
diffgram_file_id_list,
2224
validate_ids = True,
2325
max_size_cache = 1073741824,
24-
max_num_concurrent_fetches = 25):
26+
max_num_concurrent_fetches = 25,
27+
custom_signer_fn = None):
2528
"""
2629
2730
:param project (sdk.core.core.Project): A Project object from the Diffgram SDK
@@ -30,6 +33,7 @@ def __init__(self,
3033
self.diffgram_file_id_list = []
3134
self.max_size_cache = 1073741824
3235
self.pool = None
36+
self.custom_signer_fn = custom_signer_fn
3337
self.file_cache = {}
3438
self._internal_file_list = []
3539
self.current_file_index = 0
@@ -118,16 +122,24 @@ def __validate_file_ids(self):
118122
raise Exception(
119123
'Some file IDs do not belong to the project. Please provide only files from the same project.')
120124

125+
def set_custom_url_signer(self, signer_fn: Callable):
126+
self.custom_signer_fn = signer_fn
127+
121128
def get_image_data(self, diffgram_file):
122129
MAX_RETRIES = 10
123130
image = None
124131
if hasattr(diffgram_file, 'image'):
125132
for i in range(0, MAX_RETRIES):
126133
try:
134+
url = None
127135
if diffgram_file.image:
128136
url = diffgram_file.image.get('url_signed')
129-
if url:
130-
image = imread(diffgram_file.image.get('url_signed'))
137+
if diffgram_file.image and self.custom_signer_fn is not None:
138+
blob_path = diffgram_file.image['url_signed_blob_path']
139+
bucket_name = diffgram_file.image['bucket_name']
140+
url = self.custom_signer_fn(blob_path, bucket_name)
141+
if url:
142+
image = imread(url)
131143
break
132144
except Exception as e:
133145
if i < MAX_RETRIES - 1:

0 commit comments

Comments
 (0)