Skip to content

Commit

Permalink
Fix compatibility with different versions of wandb when sweeping.
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Oct 10, 2020
1 parent 6370a1b commit 15e360c
Show file tree
Hide file tree
Showing 12 changed files with 31 additions and 11 deletions.
3 changes: 2 additions & 1 deletion simpletransformers/classification/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from simpletransformers.classification.transformer_models.xlnet_model import XLNetForSequenceClassification
from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import ClassificationArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.custom_models.models import ElectraForSequenceClassification

try:
Expand Down Expand Up @@ -149,7 +150,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from simpletransformers.classification import ClassificationModel
from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import MultiLabelClassificationArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.custom_models.models import (
AlbertForMultiLabelSequenceClassification,
BertForMultiLabelSequenceClassification,
Expand Down Expand Up @@ -108,7 +109,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from simpletransformers.classification.transformer_models.mmbt_model import MMBTForClassification
from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import MultiModalClassificationArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values

try:
import wandb
Expand Down Expand Up @@ -105,7 +106,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
9 changes: 9 additions & 0 deletions simpletransformers/config/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def sweep_config_to_sweep_values(sweep_config):
    """Convert a wandb.Config instance into a plain dict of values.

    wandb.Config differs substantially between wandb versions, so we
    rely only on its `keys` method, which behaves consistently, and
    item access to read each value.
    """
    values = {}
    for key in sweep_config.keys():
        values[key] = sweep_config[key]
    return values
3 changes: 2 additions & 1 deletion simpletransformers/conv_ai/conv_ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from simpletransformers.classification.classification_utils import InputExample, convert_examples_to_features
from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import ConvAIArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.conv_ai.conv_ai_utils import get_dataset

try:
Expand Down Expand Up @@ -100,7 +101,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import LanguageGenerationArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.language_generation.language_generation_utils import PREPROCESSING_FUNCTIONS

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@

from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import LanguageModelingArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.custom_models.models import ElectraForLanguageModelingModel
from simpletransformers.language_modeling.language_modeling_utils import SimpleDataset, mask_tokens

Expand Down Expand Up @@ -132,7 +133,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from transformers import BertConfig, BertTokenizer, GPT2Config, GPT2Tokenizer, RobertaConfig, RobertaTokenizer

from simpletransformers.config.model_args import ModelArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.language_representation.transformer_models.bert_model import BertForTextRepresentation
from simpletransformers.language_representation.transformer_models.gpt2_model import GPT2ForTextRepresentation

Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
3 changes: 2 additions & 1 deletion simpletransformers/ner/ner_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@

from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import NERArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.ner.ner_utils import (
InputExample,
LazyNERDataset,
Expand Down Expand Up @@ -125,7 +126,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@

from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import QuestionAnsweringArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.custom_models.models import ElectraForQuestionAnswering, XLMRobertaForQuestionAnswering
from simpletransformers.question_answering.question_answering_utils import (
LazyQuestionAnsweringDataset,
Expand Down Expand Up @@ -131,7 +132,7 @@ def __init__(self, model_type, model_name, args=None, use_cuda=True, cuda_device

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
3 changes: 2 additions & 1 deletion simpletransformers/seq2seq/seq2seq_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@

from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import Seq2SeqArgs
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.seq2seq.seq2seq_utils import Seq2SeqDataset, SimpleSummarizationDataset

try:
Expand Down Expand Up @@ -139,7 +140,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down
3 changes: 2 additions & 1 deletion simpletransformers/t5/t5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from simpletransformers.config.global_args import global_args
from simpletransformers.config.model_args import T5Args
from simpletransformers.config.utils import sweep_config_to_sweep_values
from simpletransformers.t5.t5_utils import T5Dataset

try:
Expand Down Expand Up @@ -64,7 +65,7 @@ def __init__(

if "sweep_config" in kwargs:
sweep_config = kwargs.pop("sweep_config")
sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"}
sweep_values = sweep_config_to_sweep_values(sweep_config)
self.args.update_from_dict(sweep_values)

if self.args.manual_seed:
Expand Down

0 comments on commit 15e360c

Please sign in to comment.