Switching run_mlm to xla, wrapping things
dchaplinsky committed Oct 22, 2023 · 1 parent 28067ec · commit c2305e1
Showing 2 changed files with 4 additions and 1 deletion.
1 change: 1 addition & 0 deletions .gitignore
@@ -11,3 +11,4 @@ eval/preprocess/
 data/
 bruk_corpus/
 .DS_Store
+exps/
4 changes: 3 additions & 1 deletion run_mlm.py
@@ -48,6 +48,7 @@
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 import wandb
+import torch_xla.core.xla_model as xm
 
 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.18.0.dev0")
@@ -291,7 +292,7 @@ def main():
 
     for (name, path) in [("train", data_args.train_file), ("validation", data_args.validation_file)]:
         if path.endswith(".txt"):
-            raw_datasets[name] = datasets.Dataset.from_dict({"text": [open(path).read()], "id": [0]})
+            raw_datasets[name] = datasets.Dataset.from_dict({"text": [open(path).read()], "compound_id": [0], "id": [0]})
         else:
             raw_datasets[name] = datasets.load_from_disk(data_args.train_file)
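The change above adds a compound_id column alongside id when a plain .txt file is wrapped into a dataset. A minimal sketch of what this from_dict call yields (the text value is illustrative):

import datasets

ds = datasets.Dataset.from_dict(
    {"text": ["entire file contents as one string"], "compound_id": [0], "id": [0]}
)
print(ds.column_names)  # ['text', 'compound_id', 'id']
print(ds.num_rows)      # 1 -- the whole file is a single example before tokenization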

@@ -351,6 +352,7 @@ def main():
         model = AutoModelForMaskedLM.from_config(config)
 
     model.resize_token_embeddings(len(tokenizer))
+    model.to(xm.xla_device())
 
     # Preprocessing the datasets.
     # First we tokenize all the texts.
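Calling model.to(xm.xla_device()) places the weights on the XLA device, which is the minimal change the commit title describes. Since XLA tensors are evaluated lazily, a hand-rolled training loop usually also needs an XLA-aware optimizer step; a sketch of the common torch_xla pattern (train_loader and optimizer are hypothetical stand-ins, not from this commit):

import torch_xla.core.xla_model as xm

device = xm.xla_device()
model.to(device)

for batch in train_loader:
    optimizer.zero_grad()
    batch = {k: v.to(device) for k, v in batch.items()}
    loss = model(**batch).loss
    loss.backward()
    # all-reduces gradients across replicas (if any) and runs optimizer.step();
    # barrier=True also marks the step so the lazy XLA graph is compiled and executed
    xm.optimizer_step(optimizer, barrier=True)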
