[Feature Request] Padding Dataset to max_seq_length #1416

Closed
loretoparisi opened this issue Aug 27, 2024 · 7 comments

Comments

@loretoparisi

loretoparisi commented Aug 27, 2024

When training Llama3 I wish to pad my unstructured text to the same length. This has been addressed by #1394.
Anyway, this means that the token sequence length will be the maximum tensor length found in that specific dataset, because this is how the padded_collate function works.
For Llama3 I instead want my custom torch Dataset to use a specific length, defined externally, like:

def load_dataset(seq_length=2048):
    dataset = text_completion_dataset(
        tokenizer,
        source="text",
        column="text",
        data_files="t8.shakespeare.txt",
        split="train",
        max_seq_len=seq_length,
        packed=False
    )
    return dataset

def get_text_completion_dataset_tokens(seq_length, batch_size):
    from torchtune.utils import padded_collate
    dataset = load_dataset(seq_length=seq_length)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=False, collate_fn=padded_collate)
    tokens = []
    for batch in dataloader:
        # collect the padded token rows from every batch
        for row in batch['tokens'].tolist():
            tokens.append(row)
    return tokens

class RandomTokenDataset(Dataset):
    def __init__(self, vocab_size: int, seq_length: int, batch_size:int):
        self.vocab_size = vocab_size
        self.seq_length = seq_length # 8
        self.batch_size = batch_size # 128
        
        self.tokens = get_text_completion_dataset_tokens(seq_length,batch_size)
        
    def __len__(self) -> int:
        return self.seq_length

    def __getitem__(self, item: int):
        return self.tokens[item]

and in Llama3 side

class Llama3(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()
        self.model_args = ModelArgs(vocab_size=32000)
        self.model = Transformer(self.model_args)
        self.save_hyperparameters() # save to logging
        
    def on_train_start(self) -> None:
        self.model.init_weights()

    def training_step(self, batch):
        inputs = batch[:, :-1] if torch.is_tensor(batch) else batch['tokens'][:, :-1]
        labels =  batch[:, 1:] if torch.is_tensor(batch) else batch['tokens'][:, 1:]
        
        output = self.model(inputs)
        
        with loss_parallel():
            loss = F.cross_entropy(output.reshape(-1, output.size(-1)), labels.reshape(-1))
            return loss
            
    def on_train_batch_end(self, outputs, batch, batch_idx):
        loss = outputs['loss']
        self.log('train_loss', loss, sync_dist=True, on_step=True, on_epoch=True, prog_bar=True, logger=self.logger)
        
    def backward(self, *args, **kwargs):
        with loss_parallel():
            super().backward(*args, **kwargs)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3, foreach=True)

    def train_dataloader(self):
        
        seq_length = 128
        batch_size = 2
        num_workers = 4
        
        dataset = RandomTokenDataset(vocab_size=self.model_args.vocab_size, seq_length=seq_length, batch_size=batch_size)
        return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)

So, how can I use the padding collate in torchtune to pad to a given sequence length rather than to the max tensor length?
The code above breaks because the tensor passed to output = self.model(inputs) has size [2, 35] (the longest sequence found in that dataset is 35 tokens) while the sequence length is 128, so it only works if the tensor size is [2, 128]. After that I get the CUDA error Assertion srcIndex < srcSelectDimSize failed, because the tokenizer output does not match the model's vocabulary size.

To be more specific, the dimensionality issue happens in the forward pass of the Transformer block of the Llama3 class, where the embeddings are looked up from the tokens, h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens:

class Transformer(nn.Module):
    def __init__(self, model_args: ModelArgs):
        super().__init__()
        self.model_args = model_args
        self.vocab_size = model_args.vocab_size
        self.n_layers = model_args.n_layers

        # vocab_size=32000, dim=3200
        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
    
        self.layers = torch.nn.ModuleDict()
        for layer_id in range(model_args.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)

        self.norm = RMSNorm(dim=model_args.dim, eps=model_args.norm_eps)

        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
        self.init_weights()

# ....

    def forward(self, tokens: torch.Tensor):
        """Perform a forward pass through the Transformer model.

        Args:
            tokens (torch.Tensor): Input token indices.

        Returns:
            torch.Tensor: Output logits after applying the Transformer model.

        """
        
        #  error here: my tokenizer output did not match model vocabulary size. 
        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

        for layer in self.layers.values():
            h = layer(h, self.freqs_cis)

        h = self.norm(h) if self.norm else h
        return self.output(h).float() if self.output else h
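
For context, the srcIndex < srcSelectDimSize assertion comes from the embedding lookup itself: any token id greater than or equal to num_embeddings trips it. A minimal sketch of the failure mode (toy values, not the actual training code):

import torch
import torch.nn as nn

# ModelArgs(vocab_size=32000) builds a 32000-row embedding table, but the
# Llama3 tokenizer emits ids up to 128255 (e.g. the special tokens 128000 and 128001).
emb = nn.Embedding(num_embeddings=32000, embedding_dim=16)

ok = torch.tensor([[1, 42, 31999]])        # all ids < 32000
print(emb(ok).shape)                       # torch.Size([1, 3, 16])

bad = torch.tensor([[128000, 262, 1628]])  # contains ids >= 32000
try:
    emb(bad)
except (IndexError, RuntimeError) as e:
    # On CPU: "index out of range in self"; on CUDA the kernel raises the
    # srcIndex < srcSelectDimSize device-side assertion instead.
    print(e)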
@loretoparisi loretoparisi changed the title [Feature Request] Padding Dataset to max_seq_size [Feature Request] Padding Dataset to max_seq_length Aug 27, 2024
@ebsmothers
Contributor

ebsmothers commented Aug 28, 2024

Hi @loretoparisi thanks for creating the issue. padded_collate will pad to the length of the longest sequence in the batch, so in general it's expected that your tensor shapes will vary across batches.

If you are just padding to a fixed size, you do not even have to do it in the collate function, strictly speaking. The collate function is typically for operations that depend on the entire batch, but since you want to pad every sample to the same fixed length, you can even do it at the dataset level. E.g.

class MyTextCompletionDataset(Dataset):
    def __init__(self):
        # same as in your existing code
        self.dataset = text_completion_dataset(
            tokenizer,
            source="text",
            column="text",
            data_files="t8.shakespeare.txt",
            split="train",
            max_seq_len=seq_length,
            packed=False
        )
        self.max_seq_len = seq_length

    def __getitem__(self, index: int):
        unpadded = self.dataset[index]
        tokens = unpadded["tokens"]
        labels = unpadded["labels"]
        pad_amounts = self.max_seq_len - len(tokens)
        if pad_amounts > 0:
            tokens = tokens + [pad_idx] * pad_amounts
            labels = labels + [CROSS_ENTROPY_IGNORE_IDX] * pad_amounts
        else:
            pass  # already at max_seq_len, nothing to pad
        return {"tokens": tokens, "labels": labels}


Btw I don't claim this will work 100%, but this should be the gist of it.

@loretoparisi
Author

loretoparisi commented Aug 28, 2024

@ebsmothers I would say it works; I've just added a default value of 0 for the padding index used in __getitem__:

class PaddedTextCompletionDataset(Dataset):
    def __init__(self, tokenizer: object, seq_length: int, data_files: str, padding_idx: int = 0):
        self.dataset = text_completion_dataset(
            tokenizer,
            source="text",
            column="text",
            data_files=data_files,
            split="train",
            max_seq_len=seq_length,
            packed=False
        )
        self.padding_idx = padding_idx
        self.max_seq_len = seq_length
        
    def __len__(self) -> int:
        return self.max_seq_len
    
    def __getitem__(self, index: int):
        unpadded = self.dataset[index]
        tokens = unpadded["tokens"]
        labels = unpadded["labels"]
        pad_amounts = self.max_seq_len - len(tokens)
        if pad_amounts > 0:
            tokens = tokens + [self.padding_idx] * pad_amounts
            labels = labels + [self.padding_idx] * pad_amounts
        else:
            pass  # sequence already at max_seq_len, nothing to pad
        return {"tokens": tokens, "labels": labels}

💯
Not sure if there could be any effect on the cross entropy, or whether we have to specify CROSS_ENTROPY_IGNORE_IDX = -100 here as well, as is done in the source code of padded_collate:

input_ids = pad_sequence(
    [torch.tensor(x["tokens"]) for x in batch],
    batch_first=True,
    padding_value=padding_idx,
)
labels = pad_sequence(
    [torch.tensor(x["labels"]) for x in batch],
    batch_first=True,
    padding_value=ignore_idx,
)

shall we use the ignore_idx for the labels in our new wrapper PaddedTextCompletionDataset?

Thanks!

@ebsmothers
Contributor

@loretoparisi oops good catch! That's my mistake, you should use CROSS_ENTROPY_IGNORE_IDX when padding the labels and the tokenizer's pad_id when padding the input IDs. I'll update my original comment to reflect this
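
The reason is that F.cross_entropy defaults to ignore_index=-100, so positions whose label is -100 are excluded from the loss entirely, whereas a label equal to the pad id would be treated as a real target. A quick illustration with toy shapes (not the recipe code):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 4, 10)               # (batch, seq, vocab)
labels = torch.tensor([[3, 7, -100, -100]])  # last two positions are padding

# Positions labelled -100 are skipped (default ignore_index=-100), so padding
# contributes nothing to the loss; padding the labels with 0 instead would
# train the model to predict token 0 at every padded position.
loss = F.cross_entropy(logits.reshape(-1, 10), labels.reshape(-1))
print(loss)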

@loretoparisi
Author

@loretoparisi oops good catch! That's my mistake, you should use CROSS_ENTROPY_IGNORE_IDX when padding the labels and the tokenizer's pad_id when padding the input IDs. I'll update my original comment to reflect this

Thank you. For completeness, I have removed the else branch since it is not needed:

    def __getitem__(self, index: int):
        unpadded = self.dataset[index]
        tokens = unpadded["tokens"]
        labels = unpadded["labels"]
        pad_amounts = self.max_seq_len - len(tokens)
        if pad_amounts > 0:
            tokens = tokens + [self.padding_idx] * pad_amounts
            labels = labels + [self.ignore_idx] * pad_amounts
        return {"tokens": tokens, "labels": labels}

@loretoparisi loretoparisi reopened this Aug 29, 2024
@loretoparisi
Author

loretoparisi commented Aug 29, 2024

@ebsmothers there is a format problem with the inputs and labels. While I would expect this format:

{'tokens': tensor([[128000,    262,   1628,    358,    311,  31238,    369,   1077,      0,
            311,   3821,    369,   1077,      0, 128001,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0],
        [128000,    262,  14636,  79760,    374,    719,    279,   1709,   2230,
          31284, 128001,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0],
        [128000,    220,  19333,   1406,   2078,     13,   3639,   1071,  10466,
           4008,    355,  82162,     30, 128001,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0],
        [128000,    262,    763,   1778,    459,  34662,     26,   1268,   1253,
            358,  23528,    433,     11, 128001,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0],
        [128000,    262,   2052,  49972,  16296,    617,  38617,    757,   8617,
             25, 128001,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0],
        [128000,    262,   6914,    856,  87945,    539,  73093,  26236,  64391,
             11, 128001,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0],
        [     0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0],
        [128000,    262,   3296,   1938,    596,   5603,   1427,    311,    387,
          12263,     13, 128001,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0]],
       device='cuda:2'), 'labels': tensor([[128000,    262,   1628,    358,    311,  31238,    369,   1077,      0,
            311,   3821,    369,   1077,      0, 128001,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100],
        [128000,    262,  14636,  79760,    374,    719,    279,   1709,   2230,
          31284, 128001,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100],
        [128000,    220,  19333,   1406,   2078,     13,   3639,   1071,  10466,
           4008,    355,  82162,     30, 128001,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100],
        [128000,    262,    763,   1778,    459,  34662,     26,   1268,   1253,
            358,  23528,    433,     11, 128001,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100],
        [128000,    262,   2052,  49972,  16296,    617,  38617,    757,   8617,
             25, 128001,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100],
        [128000,    262,   6914,    856,  87945,    539,  73093,  26236,  64391,
             11, 128001,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100],
        [  -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100],
        [128000,    262,   3296,   1938,    596,   5603,   1427,    311,    387,
          12263,     13, 128001,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100,
           -100,   -100,   -100,   -100,   -100,   -100,   -100,   -100]],
       device='cuda:2')}

I'm getting from MyTextCompletionDataset a list-based format for both tokens and labels:

{'tokens': [tensor([128000, 128000, 128000, 128000, 128000, 128000, 128000, 128000]), tensor([ 220, 3909,  220,  220,  220,  220,  220,  220]), tensor([22037,   220, 29970, 30538, 14968, 31685,  9086,  3011]), tensor([  596,    19,   596, 26236, 49972,   297, 11307, 28592]), tensor([   387, 128001,     83,    659,   5992,  94678,    901,  13444]), tensor([  724,     0, 26236, 26236, 61735,  1355,  4325,   596]), tensor([ 6835,     0,  3177, 13444,   813,  3409,   568, 16392]), tensor([ 4400,     0,   596,   596, 13444,   291,   312,  2643]), tensor([  719,     0, 35678, 20160,  2103,   323, 10274,  2646]), tensor([  656,     0,   449,    30,    11, 12743,   339,  2815]), tensor([   339,      0,    659, 128001, 128001,   2136,    505,     11]), tensor([ 39580,      0,  18451,      0,      0,   1475,    279, 128001]), tensor([   11,     0, 77057,     0,     0,  1405,  1938,     0]), tensor([128001,      0,  10633,      0,      0,     25,     11,      0]), tensor([     0,      0,     11,      0,      0, 128001, 128001,      0]), tensor([     0,      0, 128001,      0,      0,      0,      0,      0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 
0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 0, 0, 0, 0])], 'labels': [tensor([128000, 128000, 128000, 128000, 128000, 128000, 128000, 128000]), tensor([ 220, 3909,  220,  220,  220,  220,  220,  220]), tensor([22037,   220, 29970, 30538, 14968, 31685,  9086,  3011]), tensor([  596,    19,   596, 26236, 49972,   297, 11307, 28592]), tensor([   387, 128001,     83,    659,   5992,  94678,    901,  13444]), tensor([  724,  -100, 26236, 26236, 61735,  1355,  4325,   596]), tensor([ 6835,  -100,  3177, 13444,   813,  3409,   568, 16392]), tensor([ 4400,  -100,   596,   596, 13444,   291,   312,  2643]), tensor([  719,  -100, 35678, 20160,  2103,   323, 10274,  2646]), tensor([  656,  -100,   449,    30,    11, 12743,   339,  2815]), tensor([   339,   -100,    659, 128001, 128001,   2136,    505,     11]), tensor([ 39580,   -100,  18451,   -100,   -100,   1475,    279, 128001]), tensor([   11,  -100, 77057,  -100,  -100,  1405,  1938,  -100]), tensor([128001,   -100,  10633,   -100,   -100,     25,     11,   -100]), tensor([  -100,   -100,     11,   -100,   -100, 128001, 128001,   -100]), tensor([  -100,   -100, 128001,   -100,   -100,   -100,   -100,   -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, 
-100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, 
-100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100]), tensor([-100, -100, -100, -100, -100, -100, -100, -100])]}
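
For what it's worth, the list-of-tensors shape comes from the default collate function: when __getitem__ returns plain Python lists of ints, torch.utils.data.default_collate transposes them position by position, producing a list of length seq_len whose elements are tensors of shape (batch_size,). A tiny repro sketch (ToyDataset is hypothetical, just to show the collate behaviour):

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    """Returns plain Python lists, like the padded dataset above."""
    def __len__(self):
        return 4

    def __getitem__(self, index):
        return {"tokens": [1, 2, 3, 0, 0]}  # list of ints, length 5

loader = DataLoader(ToyDataset(), batch_size=2)
batch = next(iter(loader))
# default_collate yields a list of 5 tensors, each of shape (2,),
# instead of a single (2, 5) tensor.
print(type(batch["tokens"]), len(batch["tokens"]), batch["tokens"][0].shape)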

@loretoparisi
Author

[UPDATE]
This can be addressed by adding a conversion to a Tensor in __getitem__:

    def __getitem__(self, index: int):
        unpadded = self.dataset[index]
        tokens = unpadded["tokens"]
        labels = unpadded["labels"]
        pad_amounts = self.max_seq_len - len(tokens)
        if pad_amounts > 0:
            tokens = tokens + [self.padding_idx] * pad_amounts
            labels = labels + [self.ignore_idx] * pad_amounts
        tokens = torch.tensor(tokens, dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.long)
        return {"tokens": tokens, "labels": labels}

now the output tensor is correct:

{'tokens': tensor([[     0,      0,      0,  ...,      0,      0,      0],
        [128000,    220,   5112,  ...,      0,      0,      0],
        [128000,    262,   1115,  ...,      0,      0,      0],
        ...,
        [128000,    220,  22037,  ...,      0,      0,      0],
        [128000,    220,   2030,  ...,      0,      0,      0],
        [128000,    220,   5046,  ...,      0,      0,      0]]), 'labels': tensor([[  -100,   -100,   -100,  ...,   -100,   -100,   -100],
        [128000,    220,   5112,  ...,   -100,   -100,   -100],
        [128000,    262,   1115,  ...,   -100,   -100,   -100],
        ...,
        [128000,    220,  22037,  ...,   -100,   -100,   -100],
        [128000,    220,   2030,  ...,   -100,   -100,   -100],
        [128000,    220,   5046,  ...,   -100,   -100,   -100]])}
batch_size: 8
128 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
128   Then were not summer's distillation left
128     This were to be new made when thou art old,
128 !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
128   If ten of thine ten times refigured thee:
128   Nature's bequest gives nothing but doth lend,
128   But thou contracted to thine own bright eyes,
128   Of his self-love to stop posterity? 

@ebsmothers
Contributor

Thanks for the updates @loretoparisi! I am gonna close this issue but please feel free to reopen if you run into any other difficulties here
