Skip to content

Commit

Permalink
2023/09/07-09:42:26 (Linux sv2111 unknown)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbenner committed Sep 7, 2023
1 parent dab9fd8 commit b3bf0bb
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 14 deletions.
16 changes: 12 additions & 4 deletions coordinationnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,22 @@ def save(self, filename : str) -> None:
## ----------------------------------------------------------------------------

class LitGraphCoordinationFeaturesData(LitDataset):
def __init__(self, data : CoordinationFeaturesData, val_size = 0.2, batch_size = 32, num_workers = 2, seed = 42, verbose = False):
def __init__(self, data : CoordinationFeaturesData, verbose = False, **kwargs):

data = GraphCoordinationData(data, verbose = verbose)
self.data_raw = data

super().__init__(data, val_size = val_size, batch_size = batch_size, num_workers = num_workers, seed = seed)
super().__init__(None, load_cached_data = 'dataset.dill', **kwargs)

def prepare_data(self):
    """Build the graph dataset from the raw features and cache it to disk.

    LightningDataModule ``prepare_data`` hook: converts the raw features
    stored in ``self.data_raw`` into a ``GraphCoordinationData`` object and
    serializes it with ``dill`` to ``self.cache_path``, from where the base
    class ``setup()`` later reloads it.
    """
    # NOTE(review): verbose is hard-coded to True here, although __init__
    # accepted a verbose flag — confirm this is intentional.
    data = GraphCoordinationData(self.data_raw, verbose = True)

    with open(self.cache_path, 'wb') as f:
        dill.dump(data, f)

# Custom method to create a data loader
def get_dataloader(self, data):
    """Return a ``GraphCoordinationFeaturesLoader`` over ``data``.

    Uses the batch size and worker count configured on this data module;
    called by the base class to build train/val loaders.
    """
    return GraphCoordinationFeaturesLoader(data, batch_size = self.batch_size, num_workers = self.num_workers)

## ----------------------------------------------------------------------------
Expand All @@ -184,7 +192,7 @@ def __init__(self, **kwargs):

def fit_scaler(self, data : LitGraphCoordinationFeaturesData):
    """Fit the underlying model's output scaler on all targets in ``data``.

    Concatenates every target tensor yielded by ``data.data_raw`` (an
    iterable of ``(x, y)`` pairs) into one tensor and fits
    ``self.lit_model.model.scaler_outputs`` on it.
    """
    # Fix: the markerless diff left the pre-change statement
    # ``y = torch.cat([ y_batch for _, y_batch in data.get_dataloader(data.data) ])``
    # in front of its replacement; it referenced the removed ``data.data``
    # attribute and its result was immediately shadowed, so it is dropped.
    y = torch.cat([ y_batch for _, y_batch in data.data_raw ])

    self.lit_model.model.scaler_outputs.fit(y)

Expand Down
1 change: 1 addition & 0 deletions coordinationnet/model_gnn_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def __init__(self, dataset, **kwargs) -> None:
super().__init__(dataset, collate_fn=self.collate_fn, **kwargs)

def collate_fn(self, batch):

x = [ item[0] for item in batch ]
y = [ item[1] for item in batch ]

Expand Down
39 changes: 29 additions & 10 deletions coordinationnet/model_lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
## ----------------------------------------------------------------------------

import dill
import shutil
import torch
import pytorch_lightning as pl
Expand Down Expand Up @@ -136,18 +137,35 @@ def state(self):
## ----------------------------------------------------------------------------

class LitDataset(pl.LightningDataModule, ABC):
def __init__(self, data, val_size = 0.2, batch_size = 32, num_workers = 2, seed = 42):
def __init__(self, data, val_size = 0.2, batch_size = 32, num_workers = 2, default_root_dir = None, load_cached_data = None, seed = 42):
super().__init__()
self.num_workers = num_workers
self.val_size = val_size
self.batch_size = batch_size
self.data = data
self.seed = seed
self.num_workers = num_workers
self.val_size = val_size
self.batch_size = batch_size
self.data = data
self.default_root_dir = default_root_dir
self.load_cached_data = load_cached_data
self.seed = seed

@property
def cache_path(self):
    """Full path of the cached dataset file, or ``None`` if caching is off.

    Returns ``None`` when ``load_cached_data`` was not configured.
    Otherwise joins the file name onto ``default_root_dir`` when one is
    set, or returns the bare file name (resolved relative to the current
    working directory).
    """
    # NOTE(review): relies on ``os`` being imported at module level — the
    # visible diff only shows ``dill`` being added; confirm the import.
    if self.load_cached_data is None:
        return None

    if self.default_root_dir is not None:
        return os.path.join(self.default_root_dir, self.load_cached_data)

    return self.load_cached_data

# This function is called by lightning trainer class with
# the corresponding stage option
def setup(self, stage: Optional[str] = None):

if self.cache_path is not None:
with open(self.cache_path, 'rb') as f:
self.data = dill.load(f)

# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage == None:
# Take a piece of the training data for validation
Expand Down Expand Up @@ -223,10 +241,11 @@ def __init__(self,
'plugins' : plugins,
}
self.data_options = {
'val_size' : val_size,
'batch_size' : batch_size,
'num_workers' : num_workers,
'seed' : seed,
'val_size' : val_size,
'batch_size' : batch_size,
'num_workers' : num_workers,
'default_root_dir' : default_root_dir,
'seed' : seed,
}
self.reset_trainer()

Expand Down

0 comments on commit b3bf0bb

Please sign in to comment.