"""Mask detected PII in a dataset."""

import argparse
import json
import logging
import random
import time
from functools import partial
from pprint import pformat

from datasets import load_dataset
from datasets.utils.logging import set_verbosity_info

from manual_sharding import save_manual_shards
from utils import get_replacements, redact_pii_batch


def parse_args():
    parser = argparse.ArgumentParser(description="PII detection and redaction")
    parser.add_argument(
        "--dataset_name",
        default="bigcode/pii-for-code",
        type=str,
        help="HF repo name/path of the dataset.",
    )
    parser.add_argument(
        "--num_load_proc",
        default=64,
        type=int,
        help="Number of processes to use for loading the dataset",
    )
    parser.add_argument(
        "--lang",
        default="ada",
        type=str,
        help="Language to redact PII in.",
    )
    parser.add_argument(
        "--text_column",
        default="content",
        type=str,
        help="Text column to use; it will be renamed to content.",
    )
    parser.add_argument(
        "--split",
        default="train",
        type=str,
        help="Dataset split to process",
    )
    parser.add_argument(
        "--batch_size",
        default=100,
        type=int,
        help="Batch size for the PII detection/redaction",
    )
    parser.add_argument(
        "--seed",
        default=0,
        type=int,
        help="Seed for random",
    )
    parser.add_argument(
        "--num_proc",
        default=96,
        type=int,
        help="Number of processes to use for the PII detection/redaction",
    )
    parser.add_argument(
        "--no_redaction",
        action="store_true",
        help="If set, we don't perform redaction",
    )
    parser.add_argument(
        "--load_replacements",
        default=True,
        help="If set, we load the replacements from the file replacements.json",
    )
    parser.add_argument(
        "--add_reference_text",
        action="store_true",
        help="If set, add the reference text with PII between delimiters "
        "in the redacted text (used for visualization).",
    )
    parser.add_argument(
        "--check_all_files",
        action="store_true",
        help="If set, we check all files, not only the ones that contain PII",
    )
    parser.add_argument(
        "--check_sampling_size",
        default=0,
        type=int,
        help="Number of samples to check for PII",
    )
    # for saving the dataset: push to the Hub, save locally with datasets, or save manual shards
    parser.add_argument(
        "--save_mode",
        default="manual_shards",
        type=str,
        choices=["hub", "local", "manual_shards"],
        help="How to save the dataset",
    )
    parser.add_argument(
        "--save_mode_checks",
        default="hub",
        type=str,
        choices=["hub", "local", "manual_shards"],
        help="How to save the checks dataset",
    )
    # name of the dataset on the hub
    parser.add_argument(
        "--target_dataset",
        default="bigcode-pii2",
        type=str,
        help="HF repo name of the target dataset in save_mode=hub.",
    )
    parser.add_argument(
        "--hub_username",
        default="loubnabnl",
        type=str,
        help="Username for the hub",
    )
    parser.add_argument(
        "--save_path_disk",
        default="bigcode-pii2-local",
        type=str,
        help="Path to save the dataset on disk in save_mode=local.",
    )
    return parser.parse_args()


def get_check_ds(ds, args):
    """Build the dataset to check manually: the modified files only
    (or all files with --check_all_files), optionally subsampled."""
    if not args.check_all_files:
        ds_checks = ds.filter(
            lambda exs: exs["modified"],
            batched=True,
            batch_size=args.batch_size,
            num_proc=args.num_proc,
        )
    else:
        ds_checks = ds
    # check_sampling_size == 0 (the default) means keep everything
    sampling_size = args.check_sampling_size or len(ds_checks)
    idx_samples = random.sample(
        range(len(ds_checks)), min(len(ds_checks), sampling_size)
    )
    ds_checks = ds_checks.select(idx_samples)

    return ds_checks


def check_uniques(example, uniques):
    """Check if the current id is still in the set of unique ids; remove it and return True if so."""
    if example["id"] in uniques:
        uniques.remove(example["id"])
        return True
    else:
        return False
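# How check_uniques drives the id-based dedup in main() (illustrative sketch):
#
#   uniques = set(ds["id"])                               # each id once
#   ds = ds.filter(check_uniques, fn_kwargs={"uniques": uniques})
#
# The first row carrying a given id pops it from `uniques` and is kept; any
# later row with the same id misses the set and is dropped. Caveat: with
# num_proc > 1 each worker process gets its own copy of `uniques`, so
# duplicates that land in different shards may both survive.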


def main():
    set_verbosity_info()
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        handlers=[logging.FileHandler("pii.log"), logging.StreamHandler()],
    )
    args = parse_args()
    logger.info(
        f"** The job is running with the following arguments: **\n{args}\n **** "
    )

    logger.info(f" ===== Loading {args.dataset_name} =====")
    ds = load_dataset(
        args.dataset_name,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_load_proc,
    )
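    # use_auth_token matches the datasets API this script was written against;
    # newer datasets releases (>= 2.14) expect token=True instead.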
    if args.text_column != "content":
        ds = ds.rename_column(args.text_column, "content")

    logger.info(" ===== Deduplicating dataset =====")
    # Deduplication based on ids
    uniques = set(ds["id"])
    frac = len(uniques) / len(ds)
    logger.info(f"Fraction of duplicates: {1 - frac:.2%}")
    logger.info(f"Dataset:\n{ds}")
    # keep only the first occurrence of each id
    t_start = time.time()
    ds_pii = ds.filter(
        check_uniques, fn_kwargs={"uniques": uniques}, num_proc=args.num_proc
    )
    logger.info(f"Time to filter dataset: {time.time() - t_start:.2f}")
    logger.info(f"Dataset after dedup:\n{ds_pii}")

    logger.info(
        f"Number of samples that contained PII: {sum(1 for x in ds_pii if x['entities'])}"
    )
    logger.info(
        f"Total number of PII entities found: {sum(len(x['entities']) for x in ds_pii)}"
    )

    # redact PII in the dataset
    logger.info(" ===== Applying PII redaction =====")
    random.seed(args.seed)

    replacements = get_replacements()
    with open("replacements.json", "w") as f:
        json.dump(replacements, f)
    logger.info(f"Using the following replacements:\n{pformat(replacements)}")
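    # With batched=True, redact_pii_batch receives and returns a dict mapping
    # column names to lists of values. Judging from the column handling below,
    # it is expected to add "new_content", "modified" and, when add_references
    # is set, "references".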
    ds_pii = ds_pii.map(
        partial(
            redact_pii_batch,
            replacements=replacements,
            add_references=args.add_reference_text,
        ),
        batched=True,
        batch_size=args.batch_size,
        num_proc=args.num_proc,
        load_from_cache_file=False,
    )
    logger.info(f"Dataset info after PII redaction:\n{ds_pii}")

    # check the dataset
    logger.info(
        f" ===== Checking {args.check_sampling_size} samples from those modified in the dataset ====="
    )
    ds_checks = get_check_ds(ds_pii, args)

    # save checks dataset
    if len(ds_checks) == 0:
        logger.info("Checks dataset is empty. Not saving anything.")
    else:
        logger.info(f"Checks dataset info {ds_checks}")
        if args.save_mode_checks == "hub":
            logger.info(
                f"Pushing the checks dataset to the Hub as {args.target_dataset}_checks"
            )
            ds_checks.push_to_hub(args.target_dataset + "_checks")

        elif args.save_mode_checks == "local":
            logger.info("Saving the checks dataset to disk")
            ds_checks.save_to_disk(args.save_path_disk + "_checks")

        elif args.save_mode_checks == "manual_shards":
            logger.info("Saving the checks dataset in manual shards")
            save_manual_shards(
                ds_checks,
                user=args.hub_username,
                remote_dataset_repo=args.target_dataset + "_checks",
            )

    logger.info("Removing columns that are not needed for the final dataset")
    columns = ["content", "modified", "entities"]
    if args.add_reference_text:
        columns.append("references")
    ds_pii = ds_pii.remove_columns(columns)
    ds_pii = ds_pii.rename_column("new_content", "content")
    logger.info(f"Dataset info after removing columns:\n{ds_pii}")

    # save the final dataset
    if args.save_mode == "hub":
        logger.info(
            f" ===== Pushing the dataset to the Hub as: {args.target_dataset} ====="
        )
        ds_pii.push_to_hub(args.target_dataset)

    elif args.save_mode == "local":
        logger.info(" ===== Saving the dataset to disk =====")
        ds_pii.save_to_disk(args.save_path_disk)

    elif args.save_mode == "manual_shards":
        logger.info(" ===== Saving the dataset in manual shards =====")
        save_manual_shards(
            ds_pii, user=args.hub_username, remote_dataset_repo=args.target_dataset
        )

    logger.info(" ===== Dataset saved successfully =====")


if __name__ == "__main__":
    main()