diff --git a/simpletransformers/classification/classification_model.py b/simpletransformers/classification/classification_model.py index 7a1497c8..d104c887 100755 --- a/simpletransformers/classification/classification_model.py +++ b/simpletransformers/classification/classification_model.py @@ -80,6 +80,7 @@ from simpletransformers.classification.transformer_models.xlnet_model import XLNetForSequenceClassification from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import ClassificationArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.custom_models.models import ElectraForSequenceClassification try: @@ -149,7 +150,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/classification/multi_label_classification_model.py b/simpletransformers/classification/multi_label_classification_model.py index d14a7bbe..90b6169d 100755 --- a/simpletransformers/classification/multi_label_classification_model.py +++ b/simpletransformers/classification/multi_label_classification_model.py @@ -33,6 +33,7 @@ from simpletransformers.classification import ClassificationModel from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import MultiLabelClassificationArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.custom_models.models import ( AlbertForMultiLabelSequenceClassification, BertForMultiLabelSequenceClassification, @@ -108,7 +109,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in 
sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/classification/multi_modal_classification_model.py b/simpletransformers/classification/multi_modal_classification_model.py index 2f88ffd1..10d6e6ac 100644 --- a/simpletransformers/classification/multi_modal_classification_model.py +++ b/simpletransformers/classification/multi_modal_classification_model.py @@ -49,6 +49,7 @@ from simpletransformers.classification.transformer_models.mmbt_model import MMBTForClassification from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import MultiModalClassificationArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values try: import wandb @@ -105,7 +106,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/config/utils.py b/simpletransformers/config/utils.py new file mode 100644 index 00000000..a00f8d0b --- /dev/null +++ b/simpletransformers/config/utils.py @@ -0,0 +1,9 @@ +def sweep_config_to_sweep_values(sweep_config): + """ + Converts an instance of wandb.Config to a plain values map. + + wandb.Config varies across versions quite significantly, + so we use the `keys` method that works consistently. 
+ """ + + return {key: sweep_config[key] for key in sweep_config.keys()} diff --git a/simpletransformers/conv_ai/conv_ai_model.py b/simpletransformers/conv_ai/conv_ai_model.py index 4b0becbe..e45a36e0 100644 --- a/simpletransformers/conv_ai/conv_ai_model.py +++ b/simpletransformers/conv_ai/conv_ai_model.py @@ -47,6 +47,7 @@ from simpletransformers.classification.classification_utils import InputExample, convert_examples_to_features from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import ConvAIArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.conv_ai.conv_ai_utils import get_dataset try: @@ -100,7 +101,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/language_generation/language_generation_model.py b/simpletransformers/language_generation/language_generation_model.py index 67f3865b..08fb5841 100644 --- a/simpletransformers/language_generation/language_generation_model.py +++ b/simpletransformers/language_generation/language_generation_model.py @@ -29,6 +29,7 @@ from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import LanguageGenerationArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.language_generation.language_generation_utils import PREPROCESSING_FUNCTIONS logger = logging.getLogger(__name__) @@ -71,7 +72,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = 
sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/language_modeling/language_modeling_model.py b/simpletransformers/language_modeling/language_modeling_model.py index c411f3b7..3ae66ec8 100755 --- a/simpletransformers/language_modeling/language_modeling_model.py +++ b/simpletransformers/language_modeling/language_modeling_model.py @@ -69,6 +69,7 @@ from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import LanguageModelingArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.custom_models.models import ElectraForLanguageModelingModel from simpletransformers.language_modeling.language_modeling_utils import SimpleDataset, mask_tokens @@ -132,7 +133,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/language_representation/representation_model.py b/simpletransformers/language_representation/representation_model.py index 7452f6e5..d6cf1b3e 100644 --- a/simpletransformers/language_representation/representation_model.py +++ b/simpletransformers/language_representation/representation_model.py @@ -13,6 +13,7 @@ from transformers import BertConfig, BertTokenizer, GPT2Config, GPT2Tokenizer, RobertaConfig, RobertaTokenizer from simpletransformers.config.model_args import ModelArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.language_representation.transformer_models.bert_model import BertForTextRepresentation from simpletransformers.language_representation.transformer_models.gpt2_model import GPT2ForTextRepresentation @@ -73,7 +74,7 
@@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/ner/ner_model.py b/simpletransformers/ner/ner_model.py index 03f11b27..04cbee2d 100755 --- a/simpletransformers/ner/ner_model.py +++ b/simpletransformers/ner/ner_model.py @@ -60,6 +60,7 @@ from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import NERArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.ner.ner_utils import ( InputExample, LazyNERDataset, @@ -125,7 +126,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/question_answering/question_answering_model.py b/simpletransformers/question_answering/question_answering_model.py index 0be38d4b..582e6c1c 100755 --- a/simpletransformers/question_answering/question_answering_model.py +++ b/simpletransformers/question_answering/question_answering_model.py @@ -65,6 +65,7 @@ from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import QuestionAnsweringArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.custom_models.models import ElectraForQuestionAnswering, XLMRobertaForQuestionAnswering from simpletransformers.question_answering.question_answering_utils import ( LazyQuestionAnsweringDataset, @@ -131,7 +132,7 @@ def __init__(self, model_type, model_name, args=None, 
use_cuda=True, cuda_device if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/seq2seq/seq2seq_model.py b/simpletransformers/seq2seq/seq2seq_model.py index 33846a10..d22f5c9b 100644 --- a/simpletransformers/seq2seq/seq2seq_model.py +++ b/simpletransformers/seq2seq/seq2seq_model.py @@ -58,6 +58,7 @@ from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import Seq2SeqArgs +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.seq2seq.seq2seq_utils import Seq2SeqDataset, SimpleSummarizationDataset try: @@ -139,7 +140,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if self.args.manual_seed: diff --git a/simpletransformers/t5/t5_model.py b/simpletransformers/t5/t5_model.py index b23cee5c..9696ede0 100644 --- a/simpletransformers/t5/t5_model.py +++ b/simpletransformers/t5/t5_model.py @@ -21,6 +21,7 @@ from simpletransformers.config.global_args import global_args from simpletransformers.config.model_args import T5Args +from simpletransformers.config.utils import sweep_config_to_sweep_values from simpletransformers.t5.t5_utils import T5Dataset try: @@ -64,7 +65,7 @@ def __init__( if "sweep_config" in kwargs: sweep_config = kwargs.pop("sweep_config") - sweep_values = {key: value["value"] for key, value in sweep_config.as_dict().items() if key != "_wandb"} + sweep_values = sweep_config_to_sweep_values(sweep_config) self.args.update_from_dict(sweep_values) if 
self.args.manual_seed: