diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md
index a81d202c6634af..946940cb019481 100644
--- a/docs/source/en/internal/generation_utils.md
+++ b/docs/source/en/internal/generation_utils.md
@@ -185,6 +185,9 @@ generation.
 [[autodoc]] SuppressTokensLogitsProcessor
     - __call__
 
+[[autodoc]] SynthIDTextWatermarkLogitsProcessor
+    - __call__
+
 [[autodoc]] TemperatureLogitsWarper
     - __call__
 
@@ -418,5 +421,20 @@ A [`Constraint`] can be used to force the generation to include specific tokens
 
 ## Watermark Utils
 
+[[autodoc]] WatermarkingConfig
+    - __call__
+
 [[autodoc]] WatermarkDetector
     - __call__
+
+[[autodoc]] BayesianDetectorConfig
+    - __call__
+
+[[autodoc]] BayesianDetectorModel
+    - __call__
+
+[[autodoc]] SynthIDTextWatermarkingConfig
+    - __call__
+
+[[autodoc]] SynthIDTextWatermarkDetector
+    - __call__
diff --git a/docs/source/en/main_classes/text_generation.md b/docs/source/en/main_classes/text_generation.md
index 574e4c75a6ac8a..76a0f1381cd6bc 100644
--- a/docs/source/en/main_classes/text_generation.md
+++ b/docs/source/en/main_classes/text_generation.md
@@ -41,8 +41,6 @@ like token streaming.
     - validate
     - get_generation_mode
 
-[[autodoc]] generation.WatermarkingConfig
-
 ## GenerationMixin
 
 [[autodoc]] GenerationMixin
diff --git a/examples/research_projects/synthid_text/README.md b/examples/research_projects/synthid_text/README.md
new file mode 100644
index 00000000000000..30ab999037374d
--- /dev/null
+++ b/examples/research_projects/synthid_text/README.md
@@ -0,0 +1,34 @@
+# SynthID Text
+
+This project showcases the use of SynthID Text for watermarking LLM output. The code in this repo also
+demonstrates how to train a detector for such watermarked text. The trained detector can be uploaded to
+a private HF Hub repo (private for security reasons) and initialized again through pretrained model loading, as also shown in this script.
+
+See our blog post: https://huggingface.co/blog/synthid-text
+
+
+## Python version
+
+You need Python 3.9 to run this example.
+
+## Installation and running
+
+Once you have installed transformers, install the requirements for this project from the requirements.txt provided in this folder:
+
+```
+pip install -r requirements.txt
+```
+
+## To run the detector training
+
+```
+python detector_training.py --model_name=google/gemma-7b-it
+```
+
+Check the script for more parameters that are tunable, and see the paper at
+https://www.nature.com/articles/s41586-024-08025-4 for more information on these parameters.
+
+## Caveat
+
+Make sure to run the training of the detector and the detection itself on the same type of hardware
+(CPU, GPU, or TPU) to get consistent results (we use deterministic randomness, which is hardware dependent).
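+
+## Using a trained detector
+
+The sketch below shows how a trained detector can be loaded back and applied to fresh generations. It assumes a
+detector was already trained and uploaded with `--save_model_to_hf_hub`; the repo name
+`your-username/gemma-detector` is a placeholder. It mirrors the `--eval_detector_on_prompts` path of the script:
+
+```python
+import torch
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BayesianDetectorModel,
+    SynthIDTextWatermarkDetector,
+    SynthIDTextWatermarkingConfig,
+    SynthIDTextWatermarkLogitsProcessor,
+)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+# Load the trained Bayesian detector and rebuild the logits processor from the
+# watermarking config stored on its config.
+detector_model = BayesianDetectorModel.from_pretrained("your-username/gemma-detector").to(device)
+logits_processor = SynthIDTextWatermarkLogitsProcessor(
+    **detector_model.config.watermarking_config, device=device
+)
+tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
+tokenizer.pad_token = tokenizer.eos_token
+detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
+
+# Generate watermarked text with the same watermarking config and score it.
+model = AutoModelForCausalLM.from_pretrained(detector_model.config.model_name).to(device)
+inputs = tokenizer(["Why is the sky blue?"], return_tensors="pt", padding=True).to(device)
+outputs = model.generate(
+    **inputs,
+    watermarking_config=SynthIDTextWatermarkingConfig(**detector_model.config.watermarking_config),
+    do_sample=True,
+    max_new_tokens=256,
+)
+# The detector returns the probability that each sequence is watermarked.
+print(detector(outputs[:, inputs["input_ids"].shape[1] :]))
+```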
diff --git a/examples/research_projects/synthid_text/detector_training.py b/examples/research_projects/synthid_text/detector_training.py
new file mode 100644
index 00000000000000..35d0ea22f42b23
--- /dev/null
+++ b/examples/research_projects/synthid_text/detector_training.py
@@ -0,0 +1,502 @@
+# coding=utf-8
+# Copyright 2024 Google DeepMind.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import dataclasses
+import enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    BayesianDetectorConfig,
+    BayesianDetectorModel,
+    SynthIDTextWatermarkDetector,
+    SynthIDTextWatermarkingConfig,
+    SynthIDTextWatermarkLogitsProcessor,
+)
+from utils import (
+    get_tokenized_uwm_outputs,
+    get_tokenized_wm_outputs,
+    process_raw_model_outputs,
+    update_fn_if_fpr_tpr,
+    upload_model_to_hf,
+)
+
+
+@enum.unique
+class ValidationMetric(enum.Enum):
+    """Metric used to validate the detector during training."""
+
+    TPR_AT_FPR = "tpr_at_fpr"
+    CROSS_ENTROPY = "cross_entropy"
+
+
+@dataclasses.dataclass
+class TrainingArguments:
+    """Training arguments pertaining to the training loop itself."""
+
+    eval_metric: Optional[ValidationMetric] = dataclasses.field(
+        default=ValidationMetric.TPR_AT_FPR, metadata={"help": "The evaluation metric used."}
+    )
+
+
+def train_detector(
+    detector: torch.nn.Module,
+    g_values: torch.Tensor,
+    mask: torch.Tensor,
+    watermarked: torch.Tensor,
+    epochs: int = 250,
+    learning_rate: float = 1e-3,
+    minibatch_size: int = 64,
+    seed: int = 0,
+    l2_weight: float = 0.0,
+    shuffle: bool = True,
+    g_values_val: Optional[torch.Tensor] = None,
+    mask_val: Optional[torch.Tensor] = None,
+    watermarked_val: Optional[torch.Tensor] = None,
+    verbose: bool = False,
+    validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR,
+) -> Tuple[Dict[str, Any], float]:
+    """Trains a Bayesian detector model.
+
+    Args:
+        detector: The detector model to train.
+        g_values: g-values of shape [num_train, seq_len, watermarking_depth].
+        mask: A binary array of shape [num_train, seq_len] indicating which g-values
+            should be used. g-values with mask value 0 are discarded.
+        watermarked: A binary array of shape [num_train] indicating whether the
+            example is watermarked (0: unwatermarked, 1: watermarked).
+        epochs: Number of epochs to train for.
+        learning_rate: Learning rate for optimizer.
+        minibatch_size: Minibatch size for training. Note that a minibatch
+            requires ~ 32 * minibatch_size * seq_len * watermarking_depth *
+            watermarking_depth bits of memory.
+        seed: Seed for parameter initialization.
+        l2_weight: Weight to apply to L2 regularization for delta parameters.
+        shuffle: Whether to shuffle before training.
+        g_values_val: Validation g-values of shape [num_val, seq_len,
+            watermarking_depth].
+        mask_val: Validation mask of shape [num_val, seq_len].
+        watermarked_val: Validation watermark labels of shape [num_val].
+        verbose: Boolean indicating verbosity of training. If true, the loss will
+            be printed. Defaults to False.
+        validation_metric: Metric to use for validation. If TPR_AT_FPR, use
+            TPR@FPR=1% as the validation metric; otherwise, use cross-entropy loss.
+
+    Returns:
+        Tuple of
+            training_history: Training history keyed by epoch number, where the
+                values are dictionaries containing the loss, validation loss, and
+                model parameters, keyed by 'loss', 'val_loss', and 'params',
+                respectively.
+            min_val_loss: Minimum validation loss achieved during training.
+ """ + + # Set the random seed for reproducibility + torch.manual_seed(seed) + + # Shuffle the data if required + if shuffle: + indices = torch.randperm(len(g_values)) + g_values = g_values[indices] + mask = mask[indices] + watermarked = watermarked[indices] + + # Initialize optimizer + optimizer = torch.optim.Adam(detector.parameters(), lr=learning_rate) + history = {} + min_val_loss = float("inf") + + for epoch in range(epochs): + losses = [] + detector.train() + num_batches = len(g_values) // minibatch_size + for i in range(0, len(g_values), minibatch_size): + end = i + minibatch_size + if end > len(g_values): + break + loss_batch_weight = l2_weight / num_batches + + optimizer.zero_grad() + loss = detector( + g_values=g_values[i:end], + mask=mask[i:end], + labels=watermarked[i:end], + loss_batch_weight=loss_batch_weight, + )[1] + loss.backward() + optimizer.step() + losses.append(loss.item()) + train_loss = sum(losses) / len(losses) + + val_losses = [] + if g_values_val is not None: + detector.eval() + if validation_metric == ValidationMetric.TPR_AT_FPR: + val_loss = update_fn_if_fpr_tpr( + detector, + g_values_val, + mask_val, + watermarked_val, + minibatch_size=minibatch_size, + ) + else: + for i in range(0, len(g_values_val), minibatch_size): + end = i + minibatch_size + if end > len(g_values_val): + break + with torch.no_grad(): + v_loss = detector( + g_values=g_values_val[i:end], + mask=mask_val[i:end], + labels=watermarked_val[i:end], + loss_batch_weight=0, + )[1] + val_losses.append(v_loss.item()) + val_loss = sum(val_losses) / len(val_losses) + + # Store training history + history[epoch + 1] = {"loss": train_loss, "val_loss": val_loss} + if verbose: + if val_loss is not None: + print(f"Epoch {epoch}: loss {loss} (train), {val_loss} (val)") + else: + print(f"Epoch {epoch}: loss {loss} (train)") + + if val_loss is not None and val_loss < min_val_loss: + min_val_loss = val_loss + best_val_epoch = epoch + + if verbose: + print(f"Best val Epoch: {best_val_epoch}, min_val_loss: {min_val_loss}") + + return history, min_val_loss + + +def train_best_detector( + tokenized_wm_outputs: Union[List[np.ndarray], np.ndarray], + tokenized_uwm_outputs: Union[List[np.ndarray], np.ndarray], + logits_processor: SynthIDTextWatermarkLogitsProcessor, + tokenizer: Any, + torch_device: torch.device, + test_size: float = 0.3, + pos_truncation_length: Optional[int] = 200, + neg_truncation_length: Optional[int] = 100, + max_padded_length: int = 2300, + n_epochs: int = 50, + learning_rate: float = 2.1e-2, + l2_weights: np.ndarray = np.logspace(-3, -2, num=4), + verbose: bool = False, + validation_metric: ValidationMetric = ValidationMetric.TPR_AT_FPR, +): + """Train and return the best detector given range of hyperparameters. + + In practice, we have found that tuning pos_truncation_length, + neg_truncation_length, n_epochs, learning_rate and l2_weights can help + improve the performance of the detector. We reccommend tuning these + parameters for your data. 
+ """ + l2_weights = list(l2_weights) + + ( + train_g_values, + train_masks, + train_labels, + cv_g_values, + cv_masks, + cv_labels, + ) = process_raw_model_outputs( + logits_processor, + tokenizer, + pos_truncation_length, + neg_truncation_length, + max_padded_length, + tokenized_wm_outputs, + test_size, + tokenized_uwm_outputs, + torch_device, + ) + + best_detector = None + lowest_loss = float("inf") + val_losses = [] + for l2_weight in l2_weights: + config = BayesianDetectorConfig(watermarking_depth=len(logits_processor.keys)) + detector = BayesianDetectorModel(config).to(torch_device) + _, min_val_loss = train_detector( + detector=detector, + g_values=train_g_values, + mask=train_masks, + watermarked=train_labels, + g_values_val=cv_g_values, + mask_val=cv_masks, + watermarked_val=cv_labels, + learning_rate=learning_rate, + l2_weight=l2_weight, + epochs=n_epochs, + verbose=verbose, + validation_metric=validation_metric, + ) + val_losses.append(min_val_loss) + if min_val_loss < lowest_loss: + lowest_loss = min_val_loss + best_detector = detector + return best_detector, lowest_loss + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_name", + type=str, + default="google/gemma-2b-it", + help=("LM model to train the detector for."), + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help=("Temperature to sample from the model."), + ) + parser.add_argument( + "--top_k", + type=int, + default=40, + help=("Top K for sampling."), + ) + parser.add_argument( + "--top_p", + type=float, + default=1.0, + help=("Top P for sampling."), + ) + parser.add_argument( + "--num_negatives", + type=int, + default=10000, + help=("Number of negatives for detector training."), + ) + parser.add_argument( + "--pos_batch_size", + type=int, + default=32, + help=("Batch size of watermarked positives while sampling."), + ) + parser.add_argument( + "--num_pos_batch", + type=int, + default=313, + help=("Number of positive batches for training."), + ) + parser.add_argument( + "--generation_length", + type=int, + default=512, + help=("Generation length for sampling."), + ) + parser.add_argument( + "--save_model_to_hf_hub", + action="store_true", + help=("Whether to save the trained model HF hub. By default it will be a private repo."), + ) + parser.add_argument( + "--load_from_hf_hub", + action="store_true", + help=( + "Whether to load trained detector model from HF Hub, make sure its the model trained on the same model " + "we are loading in the script." + ), + ) + parser.add_argument( + "--hf_hub_model_name", + type=str, + default=None, + help=("HF hub model name for loading of saving the model."), + ) + parser.add_argument( + "--eval_detector_on_prompts", + action="store_true", + help=("Evaluate detector on a prompt and print probability of watermark."), + ) + + args = parser.parse_args() + model_name = args.model_name + temperature = args.temperature + top_k = args.top_k + top_p = args.top_p + num_negatives = args.num_negatives + pos_batch_size = args.pos_batch_size + num_pos_batch = args.num_pos_batch + if num_pos_batch < 10: + raise ValueError("--num_pos_batch should be greater than 10.") + generation_length = args.generation_length + save_model_to_hf_hub = args.save_model_to_hf_hub + load_from_hf_hub = args.load_from_hf_hub + repo_name = args.hf_hub_model_name + eval_detector_on_prompts = args.eval_detector_on_prompts + + NEG_BATCH_SIZE = 32 + + # Truncate outputs to this length for training. 
+    POS_TRUNCATION_LENGTH = 200
+    NEG_TRUNCATION_LENGTH = 100
+    # Pad truncated outputs to this length for equal shape across all batches.
+    MAX_PADDED_LENGTH = 1000
+
+    DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+    if DEVICE.type not in ("cuda", "tpu"):
+        raise ValueError("We have found the training stable only on GPUs and TPUs; we are working on a fix for CPUs.")
+
+    # Change this to make your watermark unique. Check the documentation in the paper to understand the
+    # impact of these parameters. Defined here (outside the branch below) because it is also needed when
+    # a pretrained detector is loaded from the Hub.
+    DEFAULT_WATERMARKING_CONFIG = {
+        "ngram_len": 5,  # This corresponds to H=4 context window size in the paper.
+        "keys": [
+            654,
+            400,
+            836,
+            123,
+            340,
+            443,
+            597,
+            160,
+            57,
+            29,
+            590,
+            639,
+            13,
+            715,
+            468,
+            990,
+            966,
+            226,
+            324,
+            585,
+            118,
+            504,
+            421,
+            521,
+            129,
+            669,
+            732,
+            225,
+            90,
+            960,
+        ],
+        "sampling_table_size": 2**16,
+        "sampling_table_seed": 0,
+        "context_history_size": 1024,
+    }
+
+    model = None
+    if not load_from_hf_hub:
+        watermark_config = SynthIDTextWatermarkingConfig(**DEFAULT_WATERMARKING_CONFIG)
+
+        model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(**DEFAULT_WATERMARKING_CONFIG, device=DEVICE)
+        tokenized_wm_outputs = get_tokenized_wm_outputs(
+            model,
+            tokenizer,
+            watermark_config,
+            num_pos_batch,
+            pos_batch_size,
+            temperature,
+            generation_length,
+            top_k,
+            top_p,
+            DEVICE,
+        )
+        tokenized_uwm_outputs = get_tokenized_uwm_outputs(num_negatives, NEG_BATCH_SIZE, tokenizer, DEVICE)
+
+        best_detector, lowest_loss = train_best_detector(
+            tokenized_wm_outputs=tokenized_wm_outputs,
+            tokenized_uwm_outputs=tokenized_uwm_outputs,
+            logits_processor=logits_processor,
+            tokenizer=tokenizer,
+            torch_device=DEVICE,
+            test_size=0.3,
+            pos_truncation_length=POS_TRUNCATION_LENGTH,
+            neg_truncation_length=NEG_TRUNCATION_LENGTH,
+            max_padded_length=MAX_PADDED_LENGTH,
+            n_epochs=100,
+            learning_rate=3e-3,
+            l2_weights=[
+                0,
+            ],
+            verbose=True,
+            validation_metric=ValidationMetric.TPR_AT_FPR,
+        )
+    else:
+        if repo_name is None:
+            raise ValueError("When loading a pretrained detector, the detector model name cannot be None.")
+        best_detector = BayesianDetectorModel.from_pretrained(repo_name).to(DEVICE)
+
+    best_detector.config.set_detector_information(
+        model_name=model_name, watermarking_config=DEFAULT_WATERMARKING_CONFIG
+    )
+    if save_model_to_hf_hub:
+        upload_model_to_hf(best_detector, repo_name)
+
+    # Evaluate a model response with the detector.
+    if eval_detector_on_prompts:
+        model_name = best_detector.config.model_name
+        watermark_config_dict = best_detector.config.watermarking_config
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermark_config_dict, device=DEVICE)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        tokenizer.pad_token = tokenizer.eos_token
+        synthid_text_detector = SynthIDTextWatermarkDetector(best_detector, logits_processor, tokenizer)
+
+        if model is None:
+            model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)
+        watermarking_config = SynthIDTextWatermarkingConfig(**watermark_config_dict)
+
+        prompts = ["Write an essay on cats."]
+        inputs = tokenizer(
+            prompts,
+            return_tensors="pt",
+            padding=True,
+        ).to(DEVICE)
+
+        _, inputs_len = inputs["input_ids"].shape
+
+        outputs = model.generate(
+            **inputs,
+            watermarking_config=watermarking_config,
+            do_sample=True,
+            max_length=inputs_len + generation_length,
+            temperature=temperature,
+            top_k=40,
+            top_p=1.0,
+        )
+        outputs = outputs[:, inputs_len:]
+        result = synthid_text_detector(outputs)
+
+        # You should set these thresholds based on the expected fpr (false positive rate) and tpr (true
+        # positive rate). Check our demo at HF Spaces for more info.
+        upper_threshold = 0.95
+        lower_threshold = 0.12
+        if result[0][0] > upper_threshold:
+            print("The text is watermarked.")
+        elif lower_threshold < result[0][0] < upper_threshold:
+            print("It is hard to determine if the text is watermarked or not.")
+        else:
+            print("The text is not watermarked.")
diff --git a/examples/research_projects/synthid_text/requirements.txt b/examples/research_projects/synthid_text/requirements.txt
new file mode 100644
index 00000000000000..9e40a93ee08f09
--- /dev/null
+++ b/examples/research_projects/synthid_text/requirements.txt
@@ -0,0 +1,5 @@
+tensorflow-datasets>=4.9.3
+torch>=1.3
+datasets
+scikit-learn
+tensorflow
diff --git a/examples/research_projects/synthid_text/utils.py b/examples/research_projects/synthid_text/utils.py
new file mode 100644
index 00000000000000..abcb6ca2f28255
--- /dev/null
+++ b/examples/research_projects/synthid_text/utils.py
@@ -0,0 +1,408 @@
+# coding=utf-8
+# Copyright 2024 Google DeepMind.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+from typing import Any, List, Optional, Tuple
+
+import datasets
+import numpy as np
+import tensorflow as tf
+import tensorflow_datasets as tfds
+import torch
+import tqdm
+from huggingface_hub import HfApi, create_repo
+from huggingface_hub.utils import RepositoryNotFoundError
+from sklearn import model_selection
+
+import transformers
+
+
+def pad_to_len(
+    arr: torch.Tensor,
+    target_len: int,
+    left_pad: bool,
+    eos_token: int,
+    device: torch.device,
+) -> torch.Tensor:
+    """Pad or truncate array to given length."""
+    if arr.shape[1] < target_len:
+        shape_for_ones = list(arr.shape)
+        shape_for_ones[1] = target_len - shape_for_ones[1]
+        padded = (
+            torch.ones(
+                shape_for_ones,
+                device=device,
+                dtype=torch.long,
+            )
+            * eos_token
+        )
+        if not left_pad:
+            arr = torch.concatenate((arr, padded), dim=1)
+        else:
+            arr = torch.concatenate((padded, arr), dim=1)
+    else:
+        arr = arr[:, :target_len]
+    return arr
+
+
+def filter_and_truncate(
+    outputs: torch.Tensor,
+    truncation_length: Optional[int],
+    eos_token_mask: torch.Tensor,
+) -> torch.Tensor:
+    """Filter and truncate outputs to given length.
+
+    Args:
+        outputs: output tensor of shape [batch_size, output_len]
+        truncation_length: Length to truncate the final output.
+        eos_token_mask: EOS token mask of shape [batch_size, output_len]
+
+    Returns:
+        output tensor of shape [batch_size, truncation_length].
+ """ + if truncation_length: + outputs = outputs[:, :truncation_length] + truncation_mask = torch.sum(eos_token_mask, dim=1) >= truncation_length + return outputs[truncation_mask, :] + return outputs + + +def process_outputs_for_training( + all_outputs: List[torch.Tensor], + logits_processor: transformers.generation.SynthIDTextWatermarkLogitsProcessor, + tokenizer: Any, + pos_truncation_length: Optional[int], + neg_truncation_length: Optional[int], + max_length: int, + is_cv: bool, + is_pos: bool, + torch_device: torch.device, +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """Process raw model outputs into format understandable by the detector. + + Args: + all_outputs: sequence of outputs of shape [batch_size, output_len]. + logits_processor: logits processor used for watermarking. + tokenizer: tokenizer used for the model. + pos_truncation_length: Length to truncate wm outputs. + neg_truncation_length: Length to truncate uwm outputs. + max_length: Length to pad truncated outputs so that all processed entries. + have same shape. + is_cv: Process given outputs for cross validation. + is_pos: Process given outputs for positives. + torch_device: torch device to use. + + Returns: + Tuple of + all_masks: list of masks of shape [batch_size, max_length]. + all_g_values: list of g_values of shape [batch_size, max_length, depth]. + """ + all_masks = [] + all_g_values = [] + for outputs in tqdm.tqdm(all_outputs): + # outputs is of shape [batch_size, output_len]. + # output_len can differ from batch to batch. + eos_token_mask = logits_processor.compute_eos_token_mask( + input_ids=outputs, + eos_token_id=tokenizer.eos_token_id, + ) + if is_pos or is_cv: + # filter with length for positives for both train and CV. + # We also filter for length when CV negatives are processed. + outputs = filter_and_truncate(outputs, pos_truncation_length, eos_token_mask) + elif not is_pos and not is_cv: + outputs = filter_and_truncate(outputs, neg_truncation_length, eos_token_mask) + + # If no filtered outputs skip this batch. + if outputs.shape[0] == 0: + continue + + # All outputs are padded to max-length with eos-tokens. + outputs = pad_to_len(outputs, max_length, False, tokenizer.eos_token_id, torch_device) + # outputs shape [num_filtered_entries, max_length] + + eos_token_mask = logits_processor.compute_eos_token_mask( + input_ids=outputs, + eos_token_id=tokenizer.eos_token_id, + ) + + context_repetition_mask = logits_processor.compute_context_repetition_mask( + input_ids=outputs, + ) + + # context_repetition_mask of shape [num_filtered_entries, max_length - + # (ngram_len - 1)]. + context_repetition_mask = pad_to_len(context_repetition_mask, max_length, True, 0, torch_device) + # We pad on left to get same max_length shape. + # context_repetition_mask of shape [num_filtered_entries, max_length]. + combined_mask = context_repetition_mask * eos_token_mask + + g_values = logits_processor.compute_g_values( + input_ids=outputs, + ) + + # g_values of shape [num_filtered_entries, max_length - (ngram_len - 1), + # depth]. + g_values = pad_to_len(g_values, max_length, True, 0, torch_device) + + # We pad on left to get same max_length shape. + # g_values of shape [num_filtered_entries, max_length, depth]. 
+        all_masks.append(combined_mask)
+        all_g_values.append(g_values)
+    return all_masks, all_g_values
+
+
+def tpr_at_fpr(detector, detector_inputs, w_true, minibatch_size, target_fpr=0.01) -> torch.Tensor:
+    """Calculates true positive rate (TPR) at false positive rate (FPR)=target_fpr."""
+    positive_idxs = w_true == 1
+    negative_idxs = w_true == 0
+    num_samples = detector_inputs[0].size(0)
+
+    w_preds = []
+    for start in range(0, num_samples, minibatch_size):
+        end = start + minibatch_size
+        detector_inputs_ = (
+            detector_inputs[0][start:end],
+            detector_inputs[1][start:end],
+        )
+        with torch.no_grad():
+            w_pred = detector(*detector_inputs_)[0]
+        w_preds.append(w_pred)
+
+    w_pred = torch.cat(w_preds, dim=0)  # Concatenate predictions.
+    positive_scores = w_pred[positive_idxs]
+    negative_scores = w_pred[negative_idxs]
+
+    # Calculate the FPR threshold.
+    # Note: percentile -> quantile.
+    fpr_threshold = torch.quantile(negative_scores, 1 - target_fpr)
+    # Note: need to switch to FP32 since torch.mean doesn't work with torch.bool.
+    return torch.mean((positive_scores >= fpr_threshold).to(dtype=torch.float32)).item()  # TPR
+
+
+def update_fn_if_fpr_tpr(detector, g_values_val, mask_val, watermarked_val, minibatch_size):
+    """Loss function using negative TPR@FPR=1% as the validation loss."""
+    tpr_ = tpr_at_fpr(
+        detector=detector,
+        detector_inputs=(g_values_val, mask_val),
+        w_true=watermarked_val,
+        minibatch_size=minibatch_size,
+    )
+    return -tpr_
+
+
+def process_raw_model_outputs(
+    logits_processor,
+    tokenizer,
+    pos_truncation_length,
+    neg_truncation_length,
+    max_padded_length,
+    tokenized_wm_outputs,
+    test_size,
+    tokenized_uwm_outputs,
+    torch_device,
+):
+    # Split data into train and CV.
+    train_wm_outputs, cv_wm_outputs = model_selection.train_test_split(tokenized_wm_outputs, test_size=test_size)
+
+    train_uwm_outputs, cv_uwm_outputs = model_selection.train_test_split(tokenized_uwm_outputs, test_size=test_size)
+
+    process_kwargs = {
+        "logits_processor": logits_processor,
+        "tokenizer": tokenizer,
+        "pos_truncation_length": pos_truncation_length,
+        "neg_truncation_length": neg_truncation_length,
+        "max_length": max_padded_length,
+        "torch_device": torch_device,
+    }
+
+    # Process both train and CV data for training.
+    wm_masks_train, wm_g_values_train = process_outputs_for_training(
+        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_wm_outputs],
+        is_pos=True,
+        is_cv=False,
+        **process_kwargs,
+    )
+    wm_masks_cv, wm_g_values_cv = process_outputs_for_training(
+        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_wm_outputs],
+        is_pos=True,
+        is_cv=True,
+        **process_kwargs,
+    )
+    uwm_masks_train, uwm_g_values_train = process_outputs_for_training(
+        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in train_uwm_outputs],
+        is_pos=False,
+        is_cv=False,
+        **process_kwargs,
+    )
+    uwm_masks_cv, uwm_g_values_cv = process_outputs_for_training(
+        [torch.tensor(outputs, device=torch_device, dtype=torch.long) for outputs in cv_uwm_outputs],
+        is_pos=False,
+        is_cv=True,
+        **process_kwargs,
+    )
+
+    # We get lists of data; here we concatenate them together to be passed to the detector.
+    def pack(mask, g_values):
+        mask = torch.cat(mask, dim=0)
+        g = torch.cat(g_values, dim=0)
+        return mask, g
+
+    wm_masks_train, wm_g_values_train = pack(wm_masks_train, wm_g_values_train)
+    # Note: Use float instead of bool. Otherwise, the entropy calculation doesn't work.
+    wm_labels_train = torch.ones((wm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
+
+    wm_masks_cv, wm_g_values_cv = pack(wm_masks_cv, wm_g_values_cv)
+    wm_labels_cv = torch.ones((wm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
+
+    uwm_masks_train, uwm_g_values_train = pack(uwm_masks_train, uwm_g_values_train)
+    uwm_labels_train = torch.zeros((uwm_masks_train.shape[0],), dtype=torch.float, device=torch_device)
+
+    uwm_masks_cv, uwm_g_values_cv = pack(uwm_masks_cv, uwm_g_values_cv)
+    uwm_labels_cv = torch.zeros((uwm_masks_cv.shape[0],), dtype=torch.float, device=torch_device)
+
+    # Concatenate positive and negative data together.
+    train_g_values = torch.cat((wm_g_values_train, uwm_g_values_train), dim=0).squeeze()
+    train_labels = torch.cat((wm_labels_train, uwm_labels_train), dim=0).squeeze()
+    train_masks = torch.cat((wm_masks_train, uwm_masks_train), dim=0).squeeze()
+
+    cv_g_values = torch.cat((wm_g_values_cv, uwm_g_values_cv), dim=0).squeeze()
+    cv_labels = torch.cat((wm_labels_cv, uwm_labels_cv), dim=0).squeeze()
+    cv_masks = torch.cat((wm_masks_cv, uwm_masks_cv), dim=0).squeeze()
+
+    # Shuffle the training data.
+    shuffled_idx = torch.randperm(train_g_values.shape[0])  # Use torch for GPU compatibility.
+
+    train_g_values = train_g_values[shuffled_idx]
+    train_labels = train_labels[shuffled_idx]
+    train_masks = train_masks[shuffled_idx]
+
+    # Shuffle the cross-validation data.
+    shuffled_idx_cv = torch.randperm(cv_g_values.shape[0])  # Use torch for GPU compatibility.
+    cv_g_values = cv_g_values[shuffled_idx_cv]
+    cv_labels = cv_labels[shuffled_idx_cv]
+    cv_masks = cv_masks[shuffled_idx_cv]
+
+    # Delete some variables so we free up GPU memory.
+    del (
+        wm_g_values_train,
+        wm_labels_train,
+        wm_masks_train,
+        wm_g_values_cv,
+        wm_labels_cv,
+        wm_masks_cv,
+    )
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return train_g_values, train_masks, train_labels, cv_g_values, cv_masks, cv_labels
+
+
+def get_tokenized_uwm_outputs(num_negatives, neg_batch_size, tokenizer, device):
+    dataset, info = tfds.load("wikipedia/20230601.en", split="train", with_info=True)
+    dataset = dataset.take(num_negatives)
+
+    # Convert the dataset to a DataFrame.
+    df = tfds.as_dataframe(dataset, info)
+    ds = tf.data.Dataset.from_tensor_slices(dict(df))
+    tf.random.set_seed(0)
+    ds = ds.shuffle(buffer_size=10_000)
+    ds = ds.batch(batch_size=neg_batch_size)
+
+    tokenized_uwm_outputs = []
+    # Pad to this length (on the right) for batching.
+    padded_length = 1000
+    for i, batch in tqdm.tqdm(enumerate(ds)):
+        responses = [val.decode() for val in batch["text"].numpy()]
+        inputs = tokenizer(
+            responses,
+            return_tensors="pt",
+            padding=True,
+        ).to(device)
+        inputs = inputs["input_ids"].cpu().numpy()
+        if inputs.shape[1] >= padded_length:
+            inputs = inputs[:, :padded_length]
+        else:
+            # Use the actual batch size here; the last batch can be smaller than neg_batch_size.
+            inputs = np.concatenate(
+                [inputs, np.ones((inputs.shape[0], padded_length - inputs.shape[1])) * tokenizer.eos_token_id], axis=1
+            )
+        tokenized_uwm_outputs.append(inputs)
+        if len(tokenized_uwm_outputs) * neg_batch_size > num_negatives:
+            break
+    return tokenized_uwm_outputs
+
+
+def get_tokenized_wm_outputs(
+    model,
+    tokenizer,
+    watermark_config,
+    num_pos_batches,
+    pos_batch_size,
+    temperature,
+    max_output_len,
+    top_k,
+    top_p,
+    device,
+):
+    eli5_prompts = datasets.load_dataset("Pavithree/eli5")
+
+    wm_outputs = []
+
+    for batch_id in tqdm.tqdm(range(num_pos_batches)):
+        prompts = eli5_prompts["train"]["title"][batch_id * pos_batch_size : (batch_id + 1) * pos_batch_size]
+        prompts = [prompt.strip('"') for prompt in prompts]
+        inputs = tokenizer(
+            prompts,
+            return_tensors="pt",
+            padding=True,
+        ).to(device)
+        _, inputs_len = inputs["input_ids"].shape
+
+        outputs = model.generate(
+            **inputs,
+            watermarking_config=watermark_config,
+            do_sample=True,
+            max_length=inputs_len + max_output_len,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+        )
+
+        wm_outputs.append(outputs[:, inputs_len:].cpu().detach())
+
+        del outputs, inputs, prompts
+        gc.collect()
+
+    gc.collect()
+    torch.cuda.empty_cache()
+    return wm_outputs
+
+
+def upload_model_to_hf(model, hf_repo_name: str, private: bool = True):
+    api = HfApi()
+
+    # Check if the repository exists.
+    try:
+        api.repo_info(repo_id=hf_repo_name, use_auth_token=True)
+        print(f"Repository '{hf_repo_name}' already exists.")
+    except RepositoryNotFoundError:
+        # If the repository does not exist, create it.
+        print(f"Repository '{hf_repo_name}' not found. Creating it...")
+        create_repo(repo_id=hf_repo_name, private=private, use_auth_token=True)
+        print(f"Repository '{hf_repo_name}' created successfully.")
+
+    # Push the model to the Hugging Face Hub.
+    print(f"Uploading model to Hugging Face repo '{hf_repo_name}'...")
+    model.push_to_hub(repo_id=hf_repo_name, use_auth_token=True)
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 7f408859c539b0..771e3e8f0ae8b8 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1301,6 +1301,8 @@
     _import_structure["generation"].extend(
         [
             "AlternatingCodebooksLogitsProcessor",
+            "BayesianDetectorConfig",
+            "BayesianDetectorModel",
             "BeamScorer",
             "BeamSearchScorer",
             "ClassifierFreeGuidanceLogitsProcessor",
@@ -1339,6 +1341,9 @@
             "StopStringCriteria",
             "SuppressTokensAtBeginLogitsProcessor",
             "SuppressTokensLogitsProcessor",
+            "SynthIDTextWatermarkDetector",
+            "SynthIDTextWatermarkingConfig",
+            "SynthIDTextWatermarkLogitsProcessor",
             "TemperatureLogitsWarper",
             "TopKLogitsWarper",
             "TopPLogitsWarper",
@@ -6213,6 +6218,8 @@
         )
         from .generation import (
             AlternatingCodebooksLogitsProcessor,
+            BayesianDetectorConfig,
+            BayesianDetectorModel,
             BeamScorer,
             BeamSearchScorer,
             ClassifierFreeGuidanceLogitsProcessor,
@@ -6251,6 +6258,9 @@
             StopStringCriteria,
             SuppressTokensAtBeginLogitsProcessor,
             SuppressTokensLogitsProcessor,
+            SynthIDTextWatermarkDetector,
+            SynthIDTextWatermarkingConfig,
+            SynthIDTextWatermarkLogitsProcessor,
             TemperatureLogitsWarper,
             TopKLogitsWarper,
             TopPLogitsWarper,
diff --git a/src/transformers/generation/__init__.py b/src/transformers/generation/__init__.py
index 2bea00261951c7..b487fa3c7fe6ec 100644
--- a/src/transformers/generation/__init__.py
+++ b/src/transformers/generation/__init__.py
@@ -18,7 +18,13 @@
 
 _import_structure = {
-    "configuration_utils": ["GenerationConfig", "GenerationMode", "WatermarkingConfig"],
+    "configuration_utils": [
+        "BaseWatermarkingConfig",
+        "GenerationConfig",
+        "GenerationMode",
+        "SynthIDTextWatermarkingConfig",
+        "WatermarkingConfig",
+    ],
     "streamers": ["TextIteratorStreamer", "TextStreamer"],
 }
 
@@ -71,6 +77,7 @@
         "SequenceBiasLogitsProcessor",
         "SuppressTokensLogitsProcessor",
         "SuppressTokensAtBeginLogitsProcessor",
+        "SynthIDTextWatermarkLogitsProcessor",
         "TemperatureLogitsWarper",
         "TopKLogitsWarper",
         "TopPLogitsWarper",
@@ -110,6 +117,9 @@
     _import_structure["watermarking"] = [
         "WatermarkDetector",
         "WatermarkDetectorOutput",
+        "BayesianDetectorModel",
+        "BayesianDetectorConfig",
+        "SynthIDTextWatermarkDetector",
     ]
 
 try:
@@ -179,7 +189,13 @@
 ]
 
 if TYPE_CHECKING:
-    from .configuration_utils import GenerationConfig, GenerationMode, WatermarkingConfig
+    from .configuration_utils import (
+        BaseWatermarkingConfig,
+        GenerationConfig,
+        GenerationMode,
+        SynthIDTextWatermarkingConfig,
+        WatermarkingConfig,
+    )
     from .streamers import TextIteratorStreamer, TextStreamer
 
     try:
@@ -217,6 +233,7 @@
             SequenceBiasLogitsProcessor,
             SuppressTokensAtBeginLogitsProcessor,
             SuppressTokensLogitsProcessor,
+            SynthIDTextWatermarkLogitsProcessor,
             TemperatureLogitsWarper,
             TopKLogitsWarper,
             TopPLogitsWarper,
@@ -254,6 +271,9 @@
             SampleEncoderDecoderOutput,
         )
         from .watermarking import (
+            BayesianDetectorConfig,
+            BayesianDetectorModel,
+            SynthIDTextWatermarkDetector,
             WatermarkDetector,
             WatermarkDetectorOutput,
         )
diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py
index 37d57248c46a17..c460a19885afc5 100644
--- a/src/transformers/generation/configuration_utils.py
+++ b/src/transformers/generation/configuration_utils.py
@@ -18,8 +18,9 @@
 import json
 import os
 import warnings
+from abc import ABC, abstractmethod
 from dataclasses import dataclass, is_dataclass
-from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 from .. import __version__
 from ..configuration_utils import PretrainedConfig
@@ -59,6 +60,7 @@
         StaticCache,
         StaticCacheConfig,
     )
+    from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor
 
     NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
     NEEDS_CACHE_CONFIG["static"] = StaticCacheConfig
@@ -280,23 +282,10 @@ class GenerationConfig(PushToHubMixin):
         low_memory (`bool`, *optional*):
             Switch to sequential beam search and sequential topk for contrastive search to reduce peak memory.
             Used with beam search and contrastive search.
-        watermarking_config (`WatermarkingConfig` or `dict`, *optional*):
-            Arguments used to watermark the model outputs by adding a small bias to randomly selected set of "green" tokens.
-            If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
-            See [this paper](https://arxiv.org/abs/2306.04634) for more details. Accepts the following keys:
-            - greenlist_ratio (`float`):
-                Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
-            - bias (`float`):
-                Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
-            - hashing_key (`int`):
-                Hahsing key used for watermarking. Defaults to 15485863 (the millionth prime).
-            - seeding_scheme (`str`):
-                Algorithm to use for watermarking. Accepts values:
-                    - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
-                    - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
-                        The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
-            - context_width (`int`):
-                The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
+        watermarking_config (`BaseWatermarkingConfig` or `dict`, *optional*):
+            Arguments used to watermark the model outputs by adding a small bias to a randomly selected set of
+            "green" tokens. See the docs of [`SynthIDTextWatermarkingConfig`] and [`WatermarkingConfig`] for more
+            details. If passed as `Dict`, it will be converted to a `WatermarkingConfig` internally.
 
         > Parameters that define the output variables of generate
 
@@ -430,7 +419,7 @@ def __init__(self, **kwargs):
         watermarking_config = kwargs.pop("watermarking_config", None)
         if watermarking_config is None:
             self.watermarking_config = None
-        elif isinstance(watermarking_config, WatermarkingConfig):
+        elif isinstance(watermarking_config, BaseWatermarkingConfig):
             self.watermarking_config = watermarking_config
         else:
             self.watermarking_config = WatermarkingConfig.from_dict(watermarking_config)
@@ -766,7 +755,15 @@ def validate(self, is_init=False):
 
         # 6. check watermarking arguments
         if self.watermarking_config is not None:
-            if not isinstance(self.watermarking_config, WatermarkingConfig):
+            if not (
+                isinstance(self.watermarking_config, WatermarkingConfig)
+                or isinstance(self.watermarking_config, SynthIDTextWatermarkingConfig)
+            ):
+                warnings.warn(
+                    "`watermarking_config` as a dict is deprecated. Please construct the `watermarking_config` with "
+                    "the `WatermarkingConfig` or `SynthIDTextWatermarkingConfig` class.",
+                    FutureWarning,
+                )
                 self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
             self.watermarking_config.validate()
 
@@ -1287,52 +1284,20 @@ def update(self, **kwargs):
 
 
 @dataclass
-class WatermarkingConfig:
-    """
-    Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
-    See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
-
-    Accepts the following keys:
-    - greenlist_ratio (`float`):
-        Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
-    - bias (`float`):
-        Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
-    - hashing_key (`int`):
-        Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
-    - seeding_scheme (`str`):
-        Algorithm to use for watermarking. Accepts values:
-            - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
-            - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
-                The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
-    - context_width(`int`):
-        The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
-    """
-
-    def __init__(
-        self,
-        greenlist_ratio: Optional[float] = 0.25,
-        bias: Optional[float] = 2.0,
-        hashing_key: Optional[int] = 15485863,
-        seeding_scheme: Optional[str] = "lefthash",
-        context_width: Optional[int] = 1,
-    ):
-        self.greenlist_ratio = greenlist_ratio
-        self.bias = bias
-        self.hashing_key = hashing_key
-        self.seeding_scheme = seeding_scheme
-        self.context_width = context_width
+class BaseWatermarkingConfig(ABC):
+    """Generic watermarking config."""
 
     @classmethod
     def from_dict(cls, config_dict, **kwargs):
         """
-        Constructs a WatermarkingConfig instance from a dictionary of parameters.
+        Constructs a BaseWatermarkingConfig instance from a dictionary of parameters.
 
         Args:
            config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
            **kwargs: Additional keyword arguments to override dictionary values.
 
        Returns:
-            WatermarkingConfig: Instance of WatermarkingConfig constructed from the dictionary.
+            BaseWatermarkingConfig: Instance of BaseWatermarkingConfig constructed from the dictionary.
         """
         config = cls(**config_dict)
         to_remove = []
@@ -1394,6 +1359,49 @@ def update(self, **kwargs):
             if hasattr(self, key):
                 setattr(self, key, value)
 
+    @abstractmethod
+    def validate(self): ...
+
+    @abstractmethod
+    def construct_processor(self, vocab_size): ...
+
+
+@dataclass
+class WatermarkingConfig(BaseWatermarkingConfig):
+    """
+    Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
+    See [this paper](https://arxiv.org/abs/2306.04634) for more details on the arguments.
+
+    Accepts the following keys:
+    - greenlist_ratio (`float`):
+        Used for watermarking. The ratio of "green" tokens used to the vocabulary size. Defaults to 0.25.
+    - bias (`float`):
+        Used with watermarking. The bias added to the selected "green" tokens' logits. Defaults to 2.0.
+    - hashing_key (`int`):
+        Hashing key used for watermarking. Defaults to 15485863 (the millionth prime).
+    - seeding_scheme (`str`):
+        Algorithm to use for watermarking. Accepts values:
+            - "lefthash" (default): "green" tokens selection depend on the last token (Algorithm 2 from the paper)
+            - "selfhash": "green" tokens selection depends on the current token itself (Algorithm 3 from the paper)
+                The downside of this scheme is that it considers all possible next tokens and can be slower than "lefthash".
+    - context_width (`int`):
+        The context length of previous tokens to use in seeding. Higher context length makes watermarking more robust.
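+
+    Example (a minimal sketch; the model name is illustrative):
+
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, WatermarkingConfig
+
+    >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
+    >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
+
+    >>> watermarking_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
+    >>> inputs = tokenizer(["Hello"], return_tensors="pt")
+    >>> out = model.generate(**inputs, watermarking_config=watermarking_config, do_sample=True)
+    ```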
+    """
+
+    def __init__(
+        self,
+        greenlist_ratio: Optional[float] = 0.25,
+        bias: Optional[float] = 2.0,
+        hashing_key: Optional[int] = 15485863,
+        seeding_scheme: Optional[str] = "lefthash",
+        context_width: Optional[int] = 1,
+    ):
+        self.greenlist_ratio = greenlist_ratio
+        self.bias = bias
+        self.hashing_key = hashing_key
+        self.seeding_scheme = seeding_scheme
+        self.context_width = context_width
+
     def validate(self):
         watermark_missing_arg_msg = (
             "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
             "but found {found_value}"
         )
@@ -1423,3 +1431,104 @@ def validate(self):
                     found_value=self.context_width,
                 ),
             )
+
+    def construct_processor(self, vocab_size: int, device) -> "WatermarkLogitsProcessor":
+        return WatermarkLogitsProcessor(
+            vocab_size=vocab_size,
+            device=device,
+            greenlist_ratio=self.greenlist_ratio,
+            bias=self.bias,
+            hashing_key=self.hashing_key,
+            seeding_scheme=self.seeding_scheme,
+            context_width=self.context_width,
+        )
+
+
+@dataclass
+class SynthIDTextWatermarkingConfig(BaseWatermarkingConfig):
+    """
+    Class that holds arguments for watermark generation and should be passed into `GenerationConfig` during `generate`.
+    See [this paper](https://www.nature.com/articles/s41586-024-08025-4) for more details on the arguments.
+
+    Args:
+        ngram_len (`int`):
+            Ngram length.
+        keys (`List[int]`):
+            A sequence of watermarking keys, one for each depth.
+        context_history_size (`int`, *optional*, defaults to 1024):
+            Size of the tensor to keep track of seen contexts.
+        sampling_table_seed (`int`, *optional*, defaults to 0):
+            Random seed to generate the sampling table.
+        sampling_table_size (`int`, *optional*, defaults to 65536):
+            Size of the sampling table.
+        skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
+            Whether to skip the first ngram calls.
+        debug_mode (`bool`, *optional*, defaults to `False`):
+            Logits are modified to be uniform before the watermarking modification is applied. This is to test the
+            implementation.
+
+    Examples:
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
+
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
+    >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
+
+    >>> # SynthID Text configuration
+    >>> watermarking_config = SynthIDTextWatermarkingConfig(
+    ...     keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
+    ...     ngram_len=5,
+    ... )
+
+    >>> # Generation with watermarking
+    >>> tokenized_prompts = tokenizer(["your prompts here"])
+    >>> output_sequences = model.generate(
+    ...     **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
+    ... )
+    >>> watermarked_text = tokenizer.batch_decode(output_sequences)
+    ```
+    """
+
+    def __init__(
+        self,
+        ngram_len: int,
+        keys: List[int],
+        context_history_size: int = 1024,
+        sampling_table_seed: int = 0,
+        sampling_table_size: int = 2**16,
+        skip_first_ngram_calls: bool = False,
+        debug_mode: bool = False,
+    ):
+        self.ngram_len = ngram_len
+        self.keys = keys
+        self.sampling_table_size = sampling_table_size
+        self.sampling_table_seed = sampling_table_seed
+        self.context_history_size = context_history_size
+        self.skip_first_ngram_calls = skip_first_ngram_calls
+        self.debug_mode = debug_mode
+
+    def validate(self):
+        watermark_missing_arg_msg = (
+            "Some of the keys in `watermarking_config` are defined incorrectly. `{key}` should be {correct_value}` "
+            "but found {found_value}"
+        )
+        if self.sampling_table_size > 2**24:
+            raise ValueError(
+                watermark_missing_arg_msg.format(
+                    key="sampling_table_size",
+                    correct_value="< 2**24",
+                    found_value=self.sampling_table_size,
+                ),
+            )
+
+    def construct_processor(self, vocab_size: int, device) -> "SynthIDTextWatermarkLogitsProcessor":
+        return SynthIDTextWatermarkLogitsProcessor(
+            ngram_len=self.ngram_len,
+            keys=self.keys,
+            sampling_table_size=self.sampling_table_size,
+            sampling_table_seed=self.sampling_table_seed,
+            context_history_size=self.context_history_size,
+            device=device,
+            skip_first_ngram_calls=self.skip_first_ngram_calls,
+            debug_mode=self.debug_mode,
+        )
diff --git a/src/transformers/generation/logits_process.py b/src/transformers/generation/logits_process.py
index d88c7a17d892d4..fde95c7a85652f 100644
--- a/src/transformers/generation/logits_process.py
+++ b/src/transformers/generation/logits_process.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team
+# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -2460,6 +2460,7 @@ def _score_rejection_sampling(self, input_seq: torch.LongTensor, scores: torch.F
                 final_greenlist.append(greedy_predictions[i])
         return torch.tensor(final_greenlist, device=input_seq.device)
 
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if input_ids.shape[-1] < self.context_width:
             logger.warning(
@@ -2477,3 +2478,478 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
             scores_processed[b_idx, greenlist_ids] = scores_processed[b_idx, greenlist_ids] + self.bias
 
         return scores_processed
+
+
+class SynthIDTextWatermarkState:
+    """SynthID watermarking state."""
+
+    def __init__(
+        self,
+        batch_size: int,
+        ngram_len: int,
+        context_history_size: int,
+        device: torch.device,
+    ):
+        """Initializes the state.
+
+        Args:
+            batch_size (`int`): Batch size.
+            ngram_len (`int`): Ngram length.
+            context_history_size (`int`): Size of the tensor to keep track of seen contexts.
+            device (`torch.device`): Device to use.
+        """
+        self.context = torch.zeros(
+            (batch_size, ngram_len - 1),
+            dtype=torch.int64,
+            device=device,
+        )
+        self.context_history = torch.zeros(
+            (batch_size, context_history_size),
+            dtype=torch.int64,
+            device=device,
+        )
+        self.num_calls = 0
+
+
+class SynthIDTextWatermarkLogitsProcessor(LogitsProcessor):
+    r"""
+    Logits processor that implements watermarking techniques for text generation models.
+
+    This class facilitates the application of SynthID text watermarking, a method for embedding imperceptible signals
+    into generated text to aid in detecting synthetic content. It operates by subtly manipulating the probabilities of
+    token selection during text generation in a manner that can be reliably recovered later for verification.
+
+    Key Features:
+    * **State Management:** Maintains internal state to track token sequences and generate watermarking keys
+      dynamically.
+    * **Key Generation:** Computes hashes based on token sequences and watermarking parameters to create unique keys
+      for each position.
+    * **G-Value Sampling:** Employs a pre-computed sampling table to sample watermarking values (g-values) based on
+      the generated keys.
+    * **Score Adjustment:** Applies calculated g-values to modify token probabilities during generation, embedding the
+      watermark.
+    * **Context Repetition Handling:** Incorporates logic to avoid watermarking tokens in repeated contexts,
+      preserving naturalness.
+    * **EOS Token Masking:** Supports masking end-of-sentence tokens to prevent their inclusion in watermarking
+      calculations.
+    * **Utility Functions:** Provides functions to compute g-values directly, check for context repetition, create
+      EOS token masks, and estimate expected mean g-values.
+
+    Refer to the paper at https://www.nature.com/articles/s41586-024-08025-4 for more details on this.
+
+    Args:
+        ngram_len (`int`):
+            Ngram length.
+        keys (`List[int]`):
+            A sequence of watermarking keys, one for each depth.
+        sampling_table_size (`int`):
+            Size of the sampling table.
+        sampling_table_seed (`int`):
+            Random seed to generate the sampling table.
+        context_history_size (`int`):
+            Size of the tensor to keep track of seen contexts.
+        device (`torch.device`):
+            Device to use.
+        skip_first_ngram_calls (`bool`, *optional*, defaults to `False`):
+            Whether to skip the first ngram calls.
+        debug_mode (`bool`, *optional*, defaults to `False`):
+            Logits are modified to be uniform before the watermarking modification is applied. This is to test the
+            implementation.
+
+    Examples:
+    ```python
+    >>> from transformers import AutoModelForCausalLM, AutoTokenizer, SynthIDTextWatermarkingConfig
+
+    >>> tokenizer = AutoTokenizer.from_pretrained('google/gemma-2-2b-it')
+    >>> model = AutoModelForCausalLM.from_pretrained('google/gemma-2-2b-it')
+
+    >>> # SynthID Text configuration
+    >>> watermarking_config = SynthIDTextWatermarkingConfig(
+    ...     keys=[654, 400, 836, 123, 340, 443, 597, 160, 57],
+    ...     ngram_len=5,
+    ... )
+
+    >>> # Generation with watermarking
+    >>> tokenized_prompts = tokenizer(["your prompts here"])
+    >>> output_sequences = model.generate(
+    ...     **tokenized_prompts, watermarking_config=watermarking_config, do_sample=True,
+    ... )
+    >>> watermarked_text = tokenizer.batch_decode(output_sequences)
+    ```
+    """
+
+    def __init__(
+        self,
+        ngram_len: int,
+        keys: List[int],
+        sampling_table_size: int,
+        sampling_table_seed: int,
+        context_history_size: int,
+        device: torch.device,
+        skip_first_ngram_calls: bool = False,
+        debug_mode: bool = False,
+    ):
+        self.ngram_len = ngram_len
+        self.keys = torch.tensor(keys, device=device)
+
+        generator = torch.Generator(device=device).manual_seed(sampling_table_seed)
+        # A random sampling table is pre-computed, and modulo table size is applied to map from a hash of ngram keys
+        # to g values; this is similar to the hashtable implementation used in
+        # https://github.com/facebookresearch/three_bricks. We note that the hashing employed in this repository is
+        # different from that used to watermark the Gemini App, and hence the detectors trained based on the
+        # hashing in this repository will not transfer to text generated by the Gemini App.
+        self.sampling_table = torch.randint(
+            low=0,
+            high=2,
+            size=(sampling_table_size,),
+            generator=generator,
+            device=device,
+        )
+        self.context_history_size = context_history_size
+        self.device = device
+        self.state = None
+        self.skip_first_ngram_calls = skip_first_ngram_calls
+        self.debug_mode = debug_mode
+
+    def _init_state(self, batch_size: int):
+        """Initializes the state."""
+        self.state = SynthIDTextWatermarkState(
+            batch_size=batch_size,
+            ngram_len=self.ngram_len,
+            context_history_size=self.context_history_size,
+            device=self.device,
+        )
+
+    def update_scores(self, scores: torch.FloatTensor, g_values: torch.FloatTensor) -> torch.FloatTensor:
+        """Updates scores using the g values.
+
+        We assume that the scores are in the log space.
+
+        Args:
+            scores (`torch.FloatTensor`): Scores (batch_size, vocab_size).
+            g_values (`torch.FloatTensor`): G values (batch_size, vocab_size, depth).
+
+        Returns:
+            Updated scores (batch_size, vocab_size).
+        """
+        _, _, depth = g_values.shape
+
+        probs = torch.softmax(scores, dim=1)
+
+        for i in range(depth):
+            g_values_at_depth = g_values[:, :, i]
+            g_mass_at_depth = (g_values_at_depth * probs).sum(axis=1, keepdims=True)
+            probs = probs * (1 + g_values_at_depth - g_mass_at_depth)
+
+        log_probs = torch.log(probs)
+        log_probs = torch.where(torch.isfinite(log_probs), log_probs, torch.finfo(log_probs.dtype).min)
+        return log_probs
+
+    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        self._check_input_ids_shape(input_ids)
+        batch_size, vocab_size = scores.shape
+
+        if self.debug_mode:
+            scores = torch.ones_like(scores)
+
+        # Currently, the indices are just an arange, to compute watermarking on the dense logits.
+        all_indices = torch.stack([torch.arange(vocab_size, device=self.device) for _ in range(batch_size)])
+
+        if self.state is None:
+            # 1. Initialize the watermarking state if it does not exist.
+            self._init_state(batch_size)
+        else:
+            # Append the last input id (which is the input id added in the last call) to the
+            # previous context, so we have the context to be used for the current
+            # watermarking step.
+            self.state.context = torch.concat(
+                (self.state.context, input_ids[:, -1:]),
+                dim=1,
+            )
+            self.state.context = self.state.context[:, 1:]
+
+        if self.state is None:
+            raise ValueError("self.state can't be None! Call `self._init_state` to initialize the state.")
+
+        self.state.num_calls += 1
+
+        # Don't watermark the first ngram_len - 1 tokens if set.
+        if self.skip_first_ngram_calls and self.state.num_calls < self.ngram_len:
+            return scores
+
+        # 2. Generate random keys for each ngram key combination.
+        ngram_keys, hash_result_with_just_context = self._compute_keys(self.state.context, all_indices)
+        # ngram_keys shape [batch_size, top_k, depth]
+
+        # 3. Sample g values.
+        g_values = self.sample_g_values(ngram_keys)
+        # g_values shape [batch_size, top_k, depth]
+
+        # 4. Modify scores.
+        updated_scores = self.update_scores(scores, g_values)
+        # updated scores shape [batch_size, top_k]
+
+        # 5. Check if the current watermarking context was previously used; if yes, skip watermarking.
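+        # Re-watermarking a context that was already seen would repeatedly bias the
+        # same g-values and distort the text distribution, so repeated contexts keep
+        # their original scores (see "Context Repetition Handling" in the class docstring).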
+        hash_result_with_just_context = hash_result_with_just_context[:, None]
+        is_repeated_context = (self.state.context_history == hash_result_with_just_context).any(
+            dim=1,
+            keepdim=True,
+        )
+        self.state.context_history = torch.concat(
+            (hash_result_with_just_context, self.state.context_history),
+            dim=1,
+        )[:, :-1]
+
+        updated_watermarked_scores = torch.where(
+            is_repeated_context,
+            input=scores,
+            other=updated_scores,
+        )
+        return updated_watermarked_scores
+
+    def accumulate_hash(
+        self,
+        current_hash: torch.LongTensor,
+        data: torch.LongTensor,
+        multiplier: int = 6364136223846793005,
+        increment: int = 1,
+    ) -> torch.LongTensor:
+        """
+        Accumulate hash of data on current hash.
+
+        Method uses adapted linear congruential generator with newlib/musl parameters.
+
+        This function has the following property:
+        f(x, data[:T]) = f(f(x, data[:T - 1]), data[T - 1:T])
+
+        This function expects current_hash.shape and data.shape[:-1] to
+        match/be broadcastable.
+
+        Args:
+            current_hash (`torch.LongTensor`):
+                (shape,)
+            data (`torch.LongTensor`):
+                (shape, tensor_len)
+            multiplier (`int`, *optional*, defaults to 6364136223846793005):
+                multiplier of linear congruential generator
+            increment (`int`, *optional*, defaults to 1):
+                increment of linear congruential generator
+
+        Returns:
+            updated hash (shape,)
+        """
+        for i in range(data.shape[-1]):
+            current_hash = torch.add(current_hash, data[..., i])
+            current_hash = torch.mul(current_hash, multiplier)
+            current_hash = torch.add(current_hash, increment)
+        return current_hash
+
+    def compute_ngram_keys(self, ngrams: torch.LongTensor) -> torch.LongTensor:
+        """Computes random keys for each ngram and depth.
+
+        Args:
+            ngrams (`torch.LongTensor`):
+                Ngrams (batch_size, num_ngrams, ngram_len).
+
+        Returns:
+            ngram keys (batch_size, num_ngrams, depth).
+        """
+        if len(ngrams.shape) != 3:
+            raise ValueError(
+                "Ngrams should be of shape (batch_size, num_ngrams, ngram_len), but" f" is {ngrams.shape}"
+            )
+        if ngrams.shape[2] != self.ngram_len:
+            raise ValueError(
+                "Ngrams should be of shape (batch_size, num_ngrams, ngram_len),"
+                f" where ngram_len is {self.ngram_len}, but is {ngrams.shape}"
+            )
+        batch_size, _, _ = ngrams.shape
+
+        hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
+        # hash_result shape [batch_size,]
+        # ngrams shape [batch_size, num_ngrams, ngram_len]
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(hash_result, ngrams)
+        # hash_result shape [batch_size, num_ngrams]
+
+        keys = self.keys[None, None, :, None]
+        # hash_result shape [batch_size, num_ngrams]
+        # keys shape [1, 1, depth, 1]
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
+        # hash_result shape [batch_size, num_ngrams, depth]
+
+        return hash_result
+
+    def _compute_keys(
+        self, n_minus_1_grams: torch.LongTensor, indices: torch.LongTensor
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
+        """Computes random keys for each ngram and depth.
+
+        Args:
+            n_minus_1_grams (`torch.LongTensor`):
+                Ngrams (batch_size, ngram_len - 1).
+            indices (`torch.LongTensor`):
+                indices of the continuations (batch_size, num_indices)
+
+        Returns:
+            Ngram keys (batch_size, num_indices, depth).
+        """
+        batch_size, _ = n_minus_1_grams.shape
+
+        hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
+        # First hash the n_minus_1 gram; for each batch entry we have a single
+        # n_minus_1 gram context.
+        # hash_result shape [batch_size]
+        # n_minus_1_grams shape [batch_size, ngram_len - 1]
+        hash_result_with_just_context = self.accumulate_hash(hash_result, n_minus_1_grams)
+        # hash_result_with_just_context shape [batch_size,]
+        # `indices` is of shape [batch_size, num_indices], so we make it
+        # [batch_size, num_indices, 1] so we can vmap over the num_indices dim.
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 1), out_dims=1)(
+            hash_result_with_just_context, indices[:, :, None]
+        )
+        # hash_result shape [batch_size, num_indices]
+        # At this point we have a hash for each batch entry and each index.
+        # Now we add the watermarking keys to this hash.
+        # keys are of shape [depth,].
+        # We add batch, num_indices and data dimensions to this, making it
+        # [1, 1, depth, 1],
+        # so we can vmap over the depth dimension for accumulate_hash.
+        keys = self.keys[None, None, :, None]
+        hash_result = torch.vmap(self.accumulate_hash, in_dims=(None, 2), out_dims=2)(hash_result, keys)
+        # hash_result shape should be [batch_size, num_indices, depth]
+        return hash_result, hash_result_with_just_context
+
+    def sample_g_values(self, ngram_keys: torch.LongTensor) -> torch.LongTensor:
+        """
+        Samples g values from a Bernoulli distribution.
+
+        It is not possible to pass random keys in a vectorized way in torch. Instead
+        we pre-compute a random sampling table and apply modulo table size to
+        map from ngram keys (int64) to g values.
+
+        Args:
+            ngram_keys (`torch.LongTensor`):
+                Random keys (batch_size, num_ngrams, depth).
+
+        Returns:
+            G values (batch_size, num_ngrams, depth).
+        """
+        (sampling_table_size,) = self.sampling_table.shape
+        sampling_table = self.sampling_table.reshape((1, 1, sampling_table_size))
+        ngram_keys = ngram_keys % sampling_table_size
+        return torch.take_along_dim(sampling_table, indices=ngram_keys, dim=2)
+
+    def _check_input_ids_shape(self, input_ids: torch.LongTensor):
+        """Checks the shape of input ids."""
+        if len(input_ids.shape) != 2:
+            raise ValueError("Input ids should be of shape (batch_size, input_len), but is" f" {input_ids.shape}")
+
+    def compute_g_values(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+        """
+        Computes g values for each ngram from the given sequence of tokens.
+
+        Args:
+            input_ids (`torch.LongTensor`):
+                Input token ids (batch_size, input_len).
+
+        Returns:
+            G values (batch_size, input_len - (ngram_len - 1), depth).
+        """
+        self._check_input_ids_shape(input_ids)
+        ngrams = input_ids.unfold(dimension=1, size=self.ngram_len, step=1)
+        ngram_keys = self.compute_ngram_keys(ngrams)
+        return self.sample_g_values(ngram_keys)
+
+    def compute_context_repetition_mask(self, input_ids: torch.LongTensor) -> torch.LongTensor:
+        """
+        Computes the context repetition mask.
+
+        0 and 1 stand for repeated and non-repeated context (n-1)-grams respectively.
+
+        Args:
+            input_ids (`torch.LongTensor`):
+                Input token ids (batch_size, input_len).
+
+        Returns:
+            Repetition mask (batch_size, input_len - (ngram_len - 1)).
+        """
+        self._check_input_ids_shape(input_ids)
+        batch_size, _ = input_ids.shape
+        state = SynthIDTextWatermarkState(
+            batch_size=batch_size,
+            ngram_len=self.ngram_len,
+            context_history_size=self.context_history_size,
+            device=self.device,
+        )
+        contexts = input_ids[:, :-1].unfold(
+            dimension=1,
+            size=self.ngram_len - 1,
+            step=1,
+        )
+        _, num_contexts, _ = contexts.shape
+
+        are_repeated_contexts = []
+        for i in range(num_contexts):
+            context = contexts[:, i, :]
+            hash_result = torch.ones(batch_size, device=self.device, dtype=torch.long)
+            context_hash = self.accumulate_hash(hash_result, context)[:, None]
+            is_repeated_context = (state.context_history == context_hash).any(
+                dim=1,
+                keepdim=True,
+            )
+            are_repeated_contexts.append(is_repeated_context)
+            state.context_history = torch.concat(
+                (context_hash, state.context_history),
+                dim=1,
+            )[:, :-1]
+        are_repeated_contexts = torch.concat(are_repeated_contexts, dim=1)
+
+        return torch.logical_not(are_repeated_contexts)
+
+    def compute_eos_token_mask(self, input_ids: torch.LongTensor, eos_token_id: int) -> torch.LongTensor:
+        """
+        Computes the EOS token mask.
+
+        1 stands for positions before the first EOS token, 0 for the first EOS token and everything after it.
+
+        Args:
+            input_ids (`torch.LongTensor`):
+                Input token ids (batch_size, input_len).
+            eos_token_id (`int`):
+                EOS token ID.
+
+        Returns:
+            EOS token mask (batch_size, input_len).
+        """
+        self._check_input_ids_shape(input_ids)
+        noneos_masks = []
+        all_eos_equated = input_ids == eos_token_id
+        for eos_equated in all_eos_equated:
+            nonzero_idx = torch.nonzero(eos_equated)
+            noneos_mask = torch.ones_like(eos_equated)
+            if nonzero_idx.shape[0] != 0:
+                noneos_mask[nonzero_idx[0][0] :] = 0
+            noneos_masks.append(noneos_mask)
+        return torch.stack(noneos_masks, dim=0)
+
+    def expected_mean_g_value(self, vocab_size: int, coinflip_prob: float = 0.5) -> float:
+        """
+        Computes the expected mean g-value after watermarking, assuming a uniform LM distribution.
+
+        This is the theoretical expected value for single-layer watermarking.
+
+        Args:
+            vocab_size (`int`):
+                The size of the vocabulary.
+            coinflip_prob (`float`, *optional*, defaults to 0.5):
+                Probability of the boolean PRF returning 1.
+
+        Returns:
+            The expected mean g-value for watermarked text.
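+
+            With p = coinflip_prob and K = vocab_size, this is p + p * (1 - p) * (1 - 1 / K),
+            mirroring the return expression below; e.g. roughly 0.75 for p = 0.5 and a large vocabulary.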
+        """
+        return coinflip_prob + coinflip_prob * (1 - coinflip_prob) * (1 - (1 / vocab_size))
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index c399a8a2c829c7..700ea0443f4dbd 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -92,7 +92,6 @@
     TopPLogitsWarper,
     TypicalLogitsWarper,
     UnbatchedClassifierFreeGuidanceLogitsProcessor,
-    WatermarkLogitsProcessor,
 )
 from .stopping_criteria import (
     ConfidenceCriteria,
@@ -1011,15 +1010,7 @@ def _get_logits_processor(
         )
         if generation_config.watermarking_config is not None:
             processors.append(
-                WatermarkLogitsProcessor(
-                    vocab_size=self.config.vocab_size,
-                    device=device,
-                    greenlist_ratio=generation_config.watermarking_config.greenlist_ratio,
-                    bias=generation_config.watermarking_config.bias,
-                    hashing_key=generation_config.watermarking_config.hashing_key,
-                    seeding_scheme=generation_config.watermarking_config.seeding_scheme,
-                    context_width=generation_config.watermarking_config.context_width,
-                )
+                generation_config.watermarking_config.construct_processor(self.config.vocab_size, device)
             )

         # TODO (joao): find a strategy to specify the order of the processors
diff --git a/src/transformers/generation/watermarking.py b/src/transformers/generation/watermarking.py
index e998d996ec4159..da90c03dd0da89 100644
--- a/src/transformers/generation/watermarking.py
+++ b/src/transformers/generation/watermarking.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2024 The HuggingFace Inc. team
+# Copyright 2024 The HuggingFace Inc. team and Google DeepMind.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,19 +16,22 @@
 import collections
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import numpy as np
+import torch
+from torch import nn
+from torch.nn import BCELoss

-from ..configuration_utils import PretrainedConfig
-from ..utils import is_torch_available, logging
-from .configuration_utils import WatermarkingConfig
+from ..modeling_utils import PreTrainedModel
+from ..utils import ModelOutput, is_torch_available, logging
+from .configuration_utils import PretrainedConfig, WatermarkingConfig


 if is_torch_available():
     import torch

-    from .logits_process import WatermarkLogitsProcessor
+    from .logits_process import SynthIDTextWatermarkLogitsProcessor, WatermarkLogitsProcessor


 logger = logging.get_logger(__name__)
@@ -237,3 +240,310 @@ def __call__(
             confidence=confidence,
         )
         return prediction
+
+
+class BayesianDetectorConfig(PretrainedConfig):
+    """
+    This is the configuration class to store the configuration of a [`BayesianDetectorModel`]. It is used to
+    instantiate a Bayesian Detector model according to the specified arguments.
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        watermarking_depth (`int`, *optional*):
+            The number of tournament layers.
+        base_rate (`float`, *optional*, defaults to 0.5):
+            Prior probability P(w) that a text is watermarked.
+    """
+
+    def __init__(self, watermarking_depth: Optional[int] = None, base_rate: float = 0.5, **kwargs):
+        self.watermarking_depth = watermarking_depth
+        self.base_rate = base_rate
+        # These can be set later to store information about this detector.
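+        # They are filled in via `set_detector_information` once the detector has been trained.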
+        self.model_name = None
+        self.watermarking_config = None
+
+        super().__init__(**kwargs)
+
+    def set_detector_information(self, model_name, watermarking_config):
+        self.model_name = model_name
+        self.watermarking_config = watermarking_config
+
+
+@dataclass
+class BayesianWatermarkDetectorModelOutput(ModelOutput):
+    """
+    Base class for outputs of models predicting if the text is watermarked.
+
+    Args:
+        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+            Binary cross-entropy loss of the watermark detection, including the L2 regularization term.
+        posterior_probabilities (`torch.FloatTensor` of shape `(1,)`):
+            Posterior probabilities that the text is watermarked.
+    """
+
+    loss: Optional[torch.FloatTensor] = None
+    posterior_probabilities: Optional[torch.FloatTensor] = None
+
+
+class BayesianDetectorWatermarkedLikelihood(nn.Module):
+    """Watermarked likelihood model for binary-valued g-values.
+
+    This takes in g-values and returns p(g_values|watermarked).
+    """
+
+    def __init__(self, watermarking_depth: int):
+        """Initializes the model parameters."""
+        super().__init__()
+        self.watermarking_depth = watermarking_depth
+        self.beta = torch.nn.Parameter(-2.5 + 0.001 * torch.randn(1, 1, watermarking_depth))
+        self.delta = torch.nn.Parameter(0.001 * torch.randn(1, 1, self.watermarking_depth, watermarking_depth))
+
+    def _compute_latents(self, g_values: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Computes the unique token probability distribution given g-values.
+
+        Args:
+            g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
+                PRF values.
+
+        Returns:
+            p_one_unique_token and p_two_unique_tokens, both of shape
+            [batch_size, seq_len, watermarking_depth]. p_one_unique_token[i,t,l]
+            gives the probability of there being one unique token in a tournament
+            match on layer l, on timestep t, for batch item i.
+            p_one_unique_token[i,t,l] + p_two_unique_tokens[i,t,l] = 1.
+        """
+        # Tile g-values to produce feature vectors for predicting the latents
+        # for each layer in the tournament; our model for the latents psi is a
+        # logistic regression model psi = sigmoid(delta * x + beta).
+
+        # [batch_size, seq_len, watermarking_depth, watermarking_depth]
+        x = torch.repeat_interleave(torch.unsqueeze(g_values, dim=-2), self.watermarking_depth, dim=-2)
+
+        # Mask all elements above the -1 diagonal for the autoregressive factorization.
+        x = torch.tril(x, diagonal=-1)
+
+        # [batch_size, seq_len, watermarking_depth]
+        # (i, j, k, l) x (i, j, k, l) -> (i, j, k) einsum equivalent
+        logits = (self.delta[..., None, :] @ x.type(self.delta.dtype)[..., None]).squeeze() + self.beta
+
+        p_two_unique_tokens = torch.sigmoid(logits)
+        p_one_unique_token = 1 - p_two_unique_tokens
+        return p_one_unique_token, p_two_unique_tokens
+
+    def forward(self, g_values: torch.Tensor) -> torch.Tensor:
+        """Computes the likelihoods P(g_values|watermarked).
+
+        Args:
+            g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth)`):
+                g-values (values 0 or 1).
+
+        Returns:
+            p(g_values|watermarked) of shape [batch_size, seq_len, watermarking_depth].
+        """
+        p_one_unique_token, p_two_unique_tokens = self._compute_latents(g_values)
+
+        # P(g_tl | watermarked) is equal to
+        # 0.5 * [ (g_tl + 0.5) * p_two_unique_tokens + p_one_unique_token].
+        return 0.5 * ((g_values + 0.5) * p_two_unique_tokens + p_one_unique_token)
+
+
+class BayesianDetectorModel(PreTrainedModel):
+    r"""
+    Bayesian classifier for watermark detection.
+
+    This detector uses Bayes' rule to compute a watermarking score, which is the sigmoid of the log of the ratio of
+    the posterior probabilities P(watermarked|g_values) and P(unwatermarked|g_values). Please see the section on
+    BayesianScore in the paper for further details.
+    Paper URL: https://www.nature.com/articles/s41586-024-08025-4
+
+    Note that this detector only works with non-distortionary Tournament-based watermarking using the Bernoulli(0.5)
+    g-value distribution.
+
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
+    heads, etc.).
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`BayesianDetectorConfig`]): Model configuration class with all the parameters of the model.
+            Initializing with a config file does not load the weights associated with the model, only the
+            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+    """
+
+    config_class = BayesianDetectorConfig
+    base_model_prefix = "model"
+
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.watermarking_depth = config.watermarking_depth
+        self.base_rate = config.base_rate
+        self.likelihood_model_watermarked = BayesianDetectorWatermarkedLikelihood(
+            watermarking_depth=self.watermarking_depth
+        )
+        self.prior = torch.nn.Parameter(torch.tensor([self.base_rate]))
+
+    def _init_weights(self, module):
+        """Initialize the weights."""
+        if isinstance(module, nn.Parameter):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+
+    def _compute_posterior(
+        self,
+        likelihoods_watermarked: torch.Tensor,
+        likelihoods_unwatermarked: torch.Tensor,
+        mask: torch.Tensor,
+        prior: float,
+    ) -> torch.Tensor:
+        """
+        Compute the posterior P(w|g) given likelihoods, mask and prior.
+
+        Args:
+            likelihoods_watermarked (`torch.Tensor` of shape `(batch, length, depth)`):
+                Likelihoods P(g_values|watermarked) of g-values under the watermarked model.
+            likelihoods_unwatermarked (`torch.Tensor` of shape `(batch, length, depth)`):
+                Likelihoods P(g_values|unwatermarked) of g-values under the unwatermarked model.
+            mask (`torch.Tensor` of shape `(batch, length)`):
+                A binary array indicating which g-values should be used. g-values with mask value 0 are discarded.
+            prior (`float`):
+                The prior probability P(w) that the text is watermarked.
+
+        Returns:
+            Posterior probability P(watermarked|g_values), shape [batch].
+        """
+        mask = torch.unsqueeze(mask, dim=-1)
+        prior = torch.clamp(prior, min=1e-5, max=1 - 1e-5)
+        log_likelihoods_watermarked = torch.log(torch.clamp(likelihoods_watermarked, min=1e-30, max=float("inf")))
+        log_likelihoods_unwatermarked = torch.log(torch.clamp(likelihoods_unwatermarked, min=1e-30, max=float("inf")))
+        log_odds = log_likelihoods_watermarked - log_likelihoods_unwatermarked
+
+        # Sum relative surprisals (log odds) across all token positions and layers.
+        relative_surprisal_likelihood = torch.einsum("i...->i", log_odds * mask)
+
+        # Compute the relative surprisal prior.
+        relative_surprisal_prior = torch.log(prior) - torch.log(1 - prior)
+
+        # Combine prior and likelihood.
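+        # Bayes' rule in log-odds form: log[P(w|g) / P(not w|g)] =
+        # log[P(w) / P(not w)] + sum over positions and layers of log[P(g|w) / P(g|not w)].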
+        # [batch_size]
+        relative_surprisal = relative_surprisal_prior + relative_surprisal_likelihood
+
+        # Compute the posterior probability P(w|g) = sigmoid(relative_surprisal).
+        return torch.sigmoid(relative_surprisal)
+
+    def forward(
+        self,
+        g_values: torch.Tensor,
+        mask: torch.Tensor,
+        labels: Optional[torch.Tensor] = None,
+        loss_batch_weight=1,
+        return_dict=False,
+    ) -> BayesianWatermarkDetectorModelOutput:
+        """
+        Computes the watermarked posterior P(watermarked|g_values).
+
+        Args:
+            g_values (`torch.Tensor` of shape `(batch_size, seq_len, watermarking_depth, ...)`):
+                g-values (with values 0 or 1).
+            mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
+                A binary array indicating which g-values should be used. g-values with mask
+                value 0 are discarded.
+
+        Returns:
+            p(watermarked | g_values), of shape [batch_size].
+        """
+
+        likelihoods_watermarked = self.likelihood_model_watermarked(g_values)
+        likelihoods_unwatermarked = 0.5 * torch.ones_like(g_values)
+        out = self._compute_posterior(
+            likelihoods_watermarked=likelihoods_watermarked,
+            likelihoods_unwatermarked=likelihoods_unwatermarked,
+            mask=mask,
+            prior=self.prior,
+        )
+
+        loss = None
+        if labels is not None:
+            loss_fct = BCELoss()
+            loss_l2 = torch.sum(self.likelihood_model_watermarked.delta**2)
+            loss_weight = loss_l2 * loss_batch_weight
+            loss = loss_fct(torch.clamp(out, 1e-5, 1 - 1e-5), labels) + loss_weight
+
+        if not return_dict:
+            return (out,) if loss is None else (out, loss)
+
+        return BayesianWatermarkDetectorModelOutput(loss=loss, posterior_probabilities=out)
+
+
+class SynthIDTextWatermarkDetector:
+    r"""
+    SynthID text watermark detector class.
+
+    This class must be initialized with a trained Bayesian detector module. See the script in
+    examples/research_projects/synthid_text/detector_training.py for an example of training, saving and loading
+    such a module. That folder also showcases an example use case of this detector.
+
+    Parameters:
+        detector_module ([`BayesianDetectorModel`]):
+            Bayesian detector module object initialized with parameters.
+            Check examples/research_projects/synthid_text/detector_training.py for usage.
+        logits_processor (`SynthIDTextWatermarkLogitsProcessor`):
+            The logits processor used for watermarking.
+        tokenizer (`Any`):
+            The tokenizer used for the model.
+
+    Examples:
+    ```python
+    >>> from transformers import (
+    ...     AutoTokenizer, BayesianDetectorModel, SynthIDTextWatermarkLogitsProcessor, SynthIDTextWatermarkDetector
+    ... )
+
+    >>> # Load the detector. See examples/research_projects/synthid_text for training a detector.
+    >>> detector_model = BayesianDetectorModel.from_pretrained("joaogante/dummy_synthid_detector")
+    >>> logits_processor = SynthIDTextWatermarkLogitsProcessor(
+    ...     **detector_model.config.watermarking_config, device="cpu"
+    ... )
+    >>> tokenizer = AutoTokenizer.from_pretrained(detector_model.config.model_name)
+    >>> detector = SynthIDTextWatermarkDetector(detector_model, logits_processor, tokenizer)
+
+    >>> # Test whether a certain string is watermarked
+    >>> test_input = tokenizer(["This is a test input"], return_tensors="pt")
+    >>> is_watermarked = detector(test_input.input_ids)
+    ```
+    """
+
+    def __init__(
+        self,
+        detector_module: BayesianDetectorModel,
+        logits_processor: SynthIDTextWatermarkLogitsProcessor,
+        tokenizer: Any,
+    ):
+        self.detector_module = detector_module
+        self.logits_processor = logits_processor
+        self.tokenizer = tokenizer
+
+    def __call__(self, tokenized_outputs: torch.Tensor):
+        # Compute the EOS token mask and skip the first ngram_len - 1 tokens.
+        # eos_token_mask will be of shape [batch_size, output_len].
+        eos_token_mask = self.logits_processor.compute_eos_token_mask(
+            input_ids=tokenized_outputs,
+            eos_token_id=self.tokenizer.eos_token_id,
+        )[:, self.logits_processor.ngram_len - 1 :]
+
+        # Compute the context repetition mask.
+        context_repetition_mask = self.logits_processor.compute_context_repetition_mask(
+            input_ids=tokenized_outputs,
+        )
+        # context repetition mask shape [batch_size, output_len - (ngram_len - 1)]
+
+        combined_mask = context_repetition_mask * eos_token_mask
+
+        g_values = self.logits_processor.compute_g_values(
+            input_ids=tokenized_outputs,
+        )
+        # g values shape [batch_size, output_len - (ngram_len - 1), depth]
+        return self.detector_module(g_values, combined_mask)
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index e109ea659c74e0..36e1ff2cfe65c4 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -191,6 +191,20 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])


+class BayesianDetectorConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class BayesianDetectorModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class BeamScorer(metaclass=DummyObject):
     _backends = ["torch"]

@@ -457,6 +471,27 @@ def __init__(self, *args, **kwargs):
         requires_backends(self, ["torch"])


+class SynthIDTextWatermarkDetector(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class SynthIDTextWatermarkingConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class SynthIDTextWatermarkLogitsProcessor(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 class TemperatureLogitsWarper(metaclass=DummyObject):
     _backends = ["torch"]

diff --git a/tests/generation/test_logits_process.py b/tests/generation/test_logits_process.py
index a5d3ab37efa51e..aeebb5c4c53d24 100644
--- a/tests/generation/test_logits_process.py
+++ b/tests/generation/test_logits_process.py
@@ -16,6 +16,7 @@
 import unittest
 from typing import List, Union

+import numpy as np
 from parameterized import parameterized

 from transformers import is_torch_available
@@ -48,6 +49,7 @@
     PrefixConstrainedLogitsProcessor,
     RepetitionPenaltyLogitsProcessor,
     SequenceBiasLogitsProcessor,
+    SynthIDTextWatermarkLogitsProcessor,
     TemperatureLogitsWarper,
     TopKLogitsWarper,
     TopPLogitsWarper,
@@ -975,3 +977,187 @@ def test_watermarking_processor(self):
         scores_wo_bias = scores[:, -1].clone()
         out = watermark(input_ids=input_ids, scores=scores)
         self.assertTrue((out[:, 1] == scores_wo_bias + watermark.bias).all())
+
+    @parameterized.expand([(5, 3, 10000), (10, 5, 1000)])
+    def test_synthidtext_watermarking_processor_bias_uniformity(self, ngram_len, num_layers, vocab_size):
+        """Test SynthID watermarked distribution bias uniformity over iterations."""
+        torch.manual_seed(0)
+        np.random.seed(0)
+        watermarking_config = {
+            "ngram_len": ngram_len,
+            "keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
+            "sampling_table_size": 2**16,
+            "sampling_table_seed": 0,
+            "context_history_size": 512,
+            "device": torch_device,
+        }
+        batch_size = 100000
+        ngrams = torch.randint(
+            low=0,
+            high=vocab_size,
+            size=(batch_size, ngram_len),
+            device=torch_device,
+        )
+
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
+        g_values = logits_processor.compute_g_values(ngrams)
+        g_values_mean = torch.mean(torch.mean(g_values.float(), dim=0))
+        self.assertAlmostEqual(g_values_mean, 0.5, delta=0.01)
+
+    @parameterized.expand([(10000, 3), (1000, 20)])
+    def test_synthidtext_watermark_processor_bias_uniformity_across_vocab(self, vocab_size, num_layers):
+        """Test SynthID watermarked distribution bias uniformity over the vocab of the model."""
+        batch_size = 1000
+        ngram_len = 5
+        torch.manual_seed(0)
+        np.random.seed(0)
+        watermarking_config = {
+            "ngram_len": ngram_len,
+            "keys": np.random.randint(low=0, high=2**16, size=(num_layers,)),
+            "sampling_table_size": 2**16,
+            "sampling_table_seed": 0,
+            "context_history_size": 512,
+            "device": torch_device,
+        }
+        n_minus_1_grams = torch.randint(
+            low=0,
+            high=vocab_size,
+            size=(batch_size, watermarking_config["ngram_len"] - 1),
+            device=torch_device,
+        )
+
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
+        ngram_keys, _ = logits_processor._compute_keys(
+            n_minus_1_grams,
+            torch.stack([torch.arange(vocab_size, device=torch_device) for _ in range(batch_size)]),
+        )
+
+        g_values = logits_processor.sample_g_values(ngram_keys)
+        # g_values shape should be [batch_size, vocab_size, num_layers]
+        g_values_mean = torch.mean(torch.mean(g_values.float(), dim=1))
+        self.assertAlmostEqual(g_values_mean, 0.5, delta=0.001)
+
+    @parameterized.expand([(2, "uniform"), (10, "uniform"), (2, "random"), (10, "random")])
+    def test_synthidtext_watermark_processor_distributional_convergence(self, vocab_size, logits_type):
+        """Check if the watermarked distribution converges to the unwatermarked logits distribution."""
+        batch_size = 1500
+        num_keys = 1000
+
+        updated_softmaxes = 0
+        np.random.seed(0)
+        torch.manual_seed(0)
+        if logits_type == "uniform":
+            fixed_logits = torch.ones((batch_size, vocab_size), device=torch_device)
+        elif logits_type == "random":
+            fixed_logits = torch.rand(
+                (
+                    1,
+                    vocab_size,
+                ),
+                device=torch_device,
+            )
+            fixed_logits = fixed_logits.repeat(batch_size, 1)
+        else:
+            raise ValueError(f"Unrecognized logits_type {logits_type}")
+        for _ in range(num_keys):
+            watermarking_config = {
+                "ngram_len": 5,
+                "keys": np.random.randint(0, 10**9, size=(1,), dtype=np.int64),
+                "sampling_table_size": 2**16,
+                "sampling_table_seed": 0,
+                "context_history_size": 1024,
+                "device": torch_device,
+            }
+
+            logits_processor = SynthIDTextWatermarkLogitsProcessor(**watermarking_config)
+
+            ngrams = torch.randint(
+                low=0,
+                high=vocab_size,
+                size=(batch_size, watermarking_config["ngram_len"]),
+                device=torch_device,
+            )
+
+            # Insert the first ngram_len - 1 tokens into the logits processor state.
+            for idx in range(watermarking_config["ngram_len"] - 1):
+                _ = logits_processor(ngrams[:, :idx], fixed_logits)
+
+            updated_scores = logits_processor(ngrams, fixed_logits)
+            updated_softmaxes += torch.nn.functional.softmax(updated_scores, dim=1).cpu().numpy()
+
+        updated_softmaxes = np.mean(updated_softmaxes, axis=0) / num_keys
+        is_close = torch.all(
+            torch.isclose(
+                torch.tensor(updated_softmaxes, device=torch_device),
+                torch.nn.Softmax(dim=-1)(fixed_logits[0]),  # Take any batch entry, they are all the same.
+                atol=1e-3,
+                rtol=0,
+            )
+        )
+        self.assertTrue(is_close)
+
+    @parameterized.expand([(2, 10, 1, 0.01), (100, 5, 1, 0.01), (100, 10, 2, 0.02)])
+    def test_synthidtext_watermark_processor_bias_test(self, vocab_size, ngram_len, num_layers, atol):
+        """Test SynthID watermarking bias matches the theoretical value."""
+        batch_size = 20000
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        np.random.seed(0)
+
+        keys = [np.random.randint(0, 10**9) for _ in range(num_layers)]
+        # Use 10**9 rather than vocab_size to ensure variety in (n-1)-grams.
+        context = torch.randint(
+            low=0,
+            high=10**9,
+            size=(batch_size, ngram_len - 1),
+            dtype=torch.int64,
+            generator=generator,
+            device=torch_device,
+        )
+
+        context_history_size = 1024
+        logits_processor = SynthIDTextWatermarkLogitsProcessor(
+            ngram_len=ngram_len,
+            keys=keys,
+            sampling_table_size=2**16,
+            sampling_table_seed=0,
+            context_history_size=context_history_size,
+            device=torch_device,
+        )
+
+        scores = torch.ones(
+            (batch_size, vocab_size),
+            dtype=torch.float64,
+            device=torch_device,
+        )
+        # Init the state of the logits processor.
+        logits_processor(context, scores)
+        # Insert the context into the state, one prefix at a time.
+        for idx in range(1, ngram_len - 1):
+            _ = logits_processor(context[:, :idx], scores)
+
+        updated_scores = logits_processor(context, scores)
+
+        probs = torch.nn.functional.softmax(updated_scores, dim=1)
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        next_tokens = torch.multinomial(
+            probs,
+            num_samples=1,
+            generator=generator,
+        )
+
+        ngrams = torch.concat((context, next_tokens), dim=1)
+        g_values = logits_processor.compute_g_values(ngrams)
+        mean_g_values = g_values.mean(dtype=torch.float64, dim=(0, 1))
+
+        expected_mean_g_value = logits_processor.expected_mean_g_value(
+            vocab_size=vocab_size,
+        )
+        is_close = torch.all(
+            torch.isclose(
+                mean_g_values,
+                torch.tensor(expected_mean_g_value, dtype=torch.float64, device=torch_device),
+                atol=atol,
+                rtol=0,
+            )
+        )
+        self.assertTrue(is_close)
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 996d95eb80ff9b..4e5d8f30265995 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -84,6 +84,7 @@
     SampleEncoderDecoderOutput,
     StoppingCriteria,
     StoppingCriteriaList,
+    SynthIDTextWatermarkingConfig,
     WatermarkDetector,
     WatermarkingConfig,
 )
@@ -2517,9 +2518,9 @@ def test_beam_search_low_memory(self):
         self.assertListEqual(low_output.tolist(), high_output.tolist())

     @slow
-    def test_watermark_generation(self):
-        tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
-        model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(torch_device)
+    def test_green_red_watermark_generation(self):
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
         tokenizer.pad_token_id = tokenizer.eos_token_id
         model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
         input_len = model_inputs["input_ids"].shape[-1]
@@ -2548,6 +2549,61 @@ def test_watermark_generation(self):
         self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
         self.assertListEqual(detection_out.prediction.tolist(), [False])

+    @slow
+    def test_synthid_text_watermark_generation_mean_expected_bias(self):
+        """Check the mean bias inserted by the watermarking algorithm."""
+
+        model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+        model_inputs = tokenizer("I will be", return_tensors="pt").to(torch_device)
+        input_len = 5
+        batch_size = 200
+
+        # Build the SynthID watermarking config and the matching logits processor used to measure g-values.
+        watermark_config = SynthIDTextWatermarkingConfig(keys=[10, 20], ngram_len=5, debug_mode=True)
+        logits_processor = watermark_config.construct_processor(model.config.vocab_size, torch_device)
+        mean_g_values_repeats = []
+        for _ in range(40):
+            input_ids = torch.zeros(
+                (batch_size, input_len),
+                dtype=torch.int64,
+                device=torch_device,
+            )
+            model_inputs = {
+                "input_ids": input_ids,
+                "attention_mask": torch.ones_like(input_ids, device=torch_device),
+            }
+            output = model.generate(
+                **model_inputs, watermarking_config=watermark_config, do_sample=True, max_length=500, top_k=1000
+            )
+            g_values = logits_processor.compute_g_values(input_ids=output[:, input_len:])
+            context_repetition_mask = logits_processor.compute_context_repetition_mask(
+                input_ids=output[:, input_len:],
+            ).unsqueeze(dim=2)
+
+            mean_g_values = torch.masked.mean(
+                g_values,
+                mask=context_repetition_mask,
+                dim=0,
+                keepdim=True,
+                dtype=torch.float64,
+            )
+            mean_g_values_repeats.append(mean_g_values)
+
+        mean_g_values = torch.concat(mean_g_values_repeats, dim=0).mean(dim=0)
+        expected_mean_g_value = logits_processor.expected_mean_g_value(
+            vocab_size=model.config.vocab_size,
+        )
+        atol = 0.03
+        is_close = torch.isclose(
+            mean_g_values,
+            torch.tensor(expected_mean_g_value, dtype=torch.float64),
+            atol=atol,
+            rtol=0,
+        )
+        self.assertTrue(torch.all(is_close))
+
     @slow
     def test_beam_search_example_integration(self):
         # PT-only test: TF doesn't have a BeamSearchScorer