Skip to content

Commit

Permalink
Created dataset.py
Browse files Browse the repository at this point in the history
Up-directories the module
  • Loading branch information
cpuguy96 committed Dec 2, 2023
1 parent cafc860 commit 1de0b30
Show file tree
Hide file tree
Showing 6 changed files with 439 additions and 345 deletions.
4 changes: 2 additions & 2 deletions stepcovnet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import numpy as np
from sklearn.model_selection import train_test_split

from stepcovnet import dataset
from stepcovnet.common.constants import NUM_ARROW_COMBS
from stepcovnet.common.utils import get_channel_scalers
from stepcovnet.dataset.ModelDataset import ModelDataset
from stepcovnet.training.TrainingHyperparameters import TrainingHyperparameters


Expand Down Expand Up @@ -54,7 +54,7 @@ class TrainingConfig(AbstractConfig):
def __init__(
self,
dataset_path: str,
dataset_type: Type[ModelDataset],
dataset_type: Type[dataset.ModelDataset],
dataset_config,
hyperparameters: TrainingHyperparameters,
all_scalers=None,
Expand Down
7 changes: 3 additions & 4 deletions stepcovnet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

from transformers import GPT2Tokenizer

from stepcovnet.dataset.DistributedModelDataset import DistributedModelDataset
from stepcovnet.dataset.ModelDataset import ModelDataset
from stepcovnet import dataset


class Tokenizers(Enum):
GPT2 = GPT2Tokenizer.from_pretrained("gpt2")


class ModelDatasetTypes(Enum):
SINGULAR_DATASET = ModelDataset
DISTRIBUTED_DATASET = DistributedModelDataset
SINGULAR_DATASET = dataset.ModelDataset
DISTRIBUTED_DATASET = dataset.DistributedModelDataset
Loading

0 comments on commit 1de0b30

Please sign in to comment.