From 2632489e02a70f2fde462783bcd6a3a78667eb55 Mon Sep 17 00:00:00 2001
From: Jon Durbin
Date: Fri, 11 Aug 2023 11:53:49 -0400
Subject: [PATCH] Dataset merging/culling/LIMAfication

---
 airoboros/embeddings.py    |  56 +++++++
 airoboros/self_instruct.py | 300 +++++++++++++++++++++++++++----------
 example-config.yaml        |   7 +
 setup.py                   |   6 +-
 4 files changed, 287 insertions(+), 82 deletions(-)
 create mode 100644 airoboros/embeddings.py

diff --git a/airoboros/embeddings.py b/airoboros/embeddings.py
new file mode 100644
index 0000000..8651119
--- /dev/null
+++ b/airoboros/embeddings.py
@@ -0,0 +1,56 @@
+import numpy as np
+import torch
+from typing import Any, List
+
+
+# Max tokens for our embedding model. This code is really designed for the gte-*
+# series, e.g.: https://huggingface.co/thenlper/gte-small
+# but could in theory be generalized to work with other models, I suspect.
+MAX_LENGTH = 512
+
+
+def average_pool(
+    last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
+) -> torch.Tensor:
+    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+
+def calculate_fragment_embeddings(model: Any, fragment: str) -> List[float]:
+    """Calculate vector embeddings for a single input fragment, which is smaller
+    than the max model length.
+    """
+    with torch.no_grad():
+        return model.encode(fragment, normalize_embeddings=True)
+
+
+def calculate_embeddings(
+    input_text: str, model: Any, tokenizer: Any, truncate=True
+) -> List[float]:
+    """Calculate the vector embeddings for the specified input text.
+
+    1. split the text based on the model's max sequence length
+    2. calculate the embeddings for each chunk
+    3. calculate the average embedding across all chunks
+    """
+    if truncate:
+        return calculate_fragment_embeddings(model, input_text)
+
+    # Tokenize the input, and convert tokens into chunks based on max model size.
+    inputs = tokenizer(input_text, padding=False, truncation=False, return_tensors="pt")
+    chunks = [
+        torch.Tensor(inputs["input_ids"][0][i : i + MAX_LENGTH].tolist()).int()
+        for i in range(0, len(inputs["input_ids"][0]), MAX_LENGTH)
+    ]
+    fragments = [tokenizer.decode(chunk) for chunk in chunks]
+
+    # Now, calculate embeddings for each fragment.
+    all_embeddings = []
+    lengths = []
+    for fragment in fragments:
+        lengths.append(len(fragment))
+        all_embeddings.append(calculate_fragment_embeddings(model, fragment))
+
+    # Finally, calculate the length-weighted average across all fragments and re-normalize.
+    embeddings = np.average(all_embeddings, axis=0, weights=lengths)
+    return embeddings / np.linalg.norm(embeddings)
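A minimal usage sketch for the new helper (illustrative only, not part of the patch): the model and tokenizer names mirror the gte-small default referenced in the comments above, and the 384-dimensional output shape is an assumption tied to that particular model.

    from sentence_transformers import SentenceTransformer
    from transformers import AutoTokenizer

    from airoboros.embeddings import calculate_embeddings

    model_name = "thenlper/gte-small"
    model = SentenceTransformer(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Short inputs are embedded directly; with truncate=False, long inputs are split
    # into 512-token chunks, each chunk is embedded, and the chunk embeddings are
    # combined via a length-weighted average and re-normalized to unit length.
    text = "Describe the process of photosynthesis in detail. " * 200
    embedding = calculate_embeddings(text, model, tokenizer, truncate=False)
    print(embedding.shape)  # expected (384,) for gte-small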
diff --git a/airoboros/self_instruct.py b/airoboros/self_instruct.py
index 7978707..001d66f 100644
--- a/airoboros/self_instruct.py
+++ b/airoboros/self_instruct.py
@@ -3,6 +3,7 @@
 import asyncio
 import backoff
 import datetime
+import faiss
 import os
 import json
 import math
@@ -16,8 +17,10 @@
 from collections import defaultdict
 from loguru import logger
 from time import sleep
+from tqdm import tqdm
 from typing import List, Dict, Any
 from uuid import uuid4
+from airoboros.embeddings import calculate_embeddings
 from airoboros.exceptions import (
     RateLimitError,
     TooManyRequestsError,
@@ -27,8 +30,9 @@
     ContextLengthExceededError,
     BadResponseError,
 )
-from langchain.vectorstores import Chroma
-from langchain.embeddings import HuggingFaceEmbeddings
+from fast_sentence_transformers import FastSentenceTransformer
+from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer
 
 # Defaults and constants.
 MAX_DOCSTORE_SIZE = 15000
@@ -55,7 +59,6 @@ def __init__(self, *, config_path: str = "config.yaml"):
         self.config_path = config_path
         self.load_config()
         self.instructor_counts = defaultdict(int)
-        self.docstore_lock = asyncio.Semaphore(1)
 
     def load_config(self):
         """Load an advanced configuration from a YAML file."""
@@ -101,6 +104,20 @@ def load_config(self):
         self.language = raw_config.get("language") or "English"
         self.default_flesch = raw_config.get("default_flesch") or READABILITY_HINT
 
+        # Embedding model.
+        model_name = raw_config.get("embedding_model") or "thenlper/gte-small"
+
+        # Hacky, but we'll load this twice, the first time to get the dimension, since
+        # it's not accessible in the Fast (cpu) version.
+        model = SentenceTransformer(model_name)
+        self.embedding_dimension = model.get_sentence_embedding_dimension()
+        model = None
+        if raw_config.get("embedding_device") == "cuda":
+            self.embedding_model = SentenceTransformer(model_name, device="cuda")
+        else:
+            self.embedding_model = FastSentenceTransformer(model_name, device="cpu")
+        self.embedding_tokenizer = AutoTokenizer.from_pretrained(model_name)
+
         # Validate the model for each generator.
         self.instructors = raw_config.get("instructors")
         self.validate_model(self.model)
@@ -110,8 +127,8 @@ def load_config(self):
                 self.validate_model(config["model"])
                 valid_models[config["model"]] = True
 
-    def initialize_docstores(self):
-        """Initialize the in-memory vector databases used to check prompt uniqueness."""
+    def initialize_index(self):
+        """Initialize the in-memory faiss index used to check prompt uniqueness."""
         docs = []
         if os.path.exists(self.output_path):
             if self.overwrite:
@@ -124,7 +141,9 @@
                 with open(self.output_path, "r") as infile:
                     for line in infile.readlines():
                         task = json.loads(line)
-                        self.instructor_counts[task.get("category", "general")] += 1
+                        category = task.get("category", "general")
+                        if category != "chat" or "chat" in category:
+                            self.instructor_counts[category] += 1
                         if task["category"] != "chat":
                             docs.append(task["instruction"])
                 logger.info(
@@ -136,26 +155,17 @@
                 raise RuntimeError(
                     f"{self.output_path} already exists, but overwrite and append are false!"
                 )
-        logger.info(
-            "Initializing in-memory document store for similarity comparison..."
-        )
+        logger.info("Initializing faiss index for similarity comparison...")
         if not docs:
             docs = ["__initialize__"]
-        self.embeddings = HuggingFaceEmbeddings()
-        batches = [
-            docs[i * MAX_DOCSTORE_SIZE : (i + 1) * MAX_DOCSTORE_SIZE]
-            for i in range((len(docs) + MAX_DOCSTORE_SIZE - 1) // MAX_DOCSTORE_SIZE)
-        ]
-        logger.info(f"Need to create {len(batches)} unique docstores...")
-        self.docstores = [
-            Chroma.from_texts(batch, self.embeddings) for batch in batches
-        ]
-        self.docstore_size = len(batches[-1])
-        if self.docstore_size >= MAX_DOCSTORE_SIZE:
-            logger.info("Initializing fresh docstore due to doc count...")
-            self.docstore_size = 0
-            self.docstores.append(
-                Chroma.from_texts(["__initialize__"], self.embeddings)
+
+        # This is a bit slow.
+        self.index = faiss.IndexFlatL2(self.embedding_dimension)
+        for doc in docs:
+            self.index.add(
+                calculate_embeddings(
+                    doc, self.embedding_model, self.embedding_tokenizer
+                )
             )
 
     def validate_model(self, model):
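The similarity checks in this patch compare raw faiss distances against min_docsearch_score. A self-contained sketch of that mechanic (random unit vectors stand in for real embeddings; faiss's IndexFlatL2 reports squared L2 distance, which for unit-length vectors equals 2 - 2 * cosine similarity, and the 0.35 threshold is only a placeholder for the configured value):

    import faiss
    import numpy as np

    dimension = 384  # gte-small embedding size
    index = faiss.IndexFlatL2(dimension)

    # Stand-ins for unit-normalized document embeddings.
    docs = np.random.rand(10, dimension).astype("float32")
    docs /= np.linalg.norm(docs, axis=1, keepdims=True)
    index.add(docs)

    # A slightly perturbed copy of the first document should come back as a near match.
    query = docs[0] + 0.01 * np.random.rand(dimension).astype("float32")
    query /= np.linalg.norm(query)
    distances, indices = index.search(query.reshape(1, -1), 1)

    min_docsearch_score = 0.35  # placeholder; the real threshold comes from the config
    print(indices[0][0], distances[0][0], distances[0][0] <= min_docsearch_score)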
@@ -396,93 +406,224 @@ async def is_decent_response(self, item):
             )
             return True
         if "GOOD" in result:
-            logger.success(f"Good response [{item['category']}]: {preview}")
+            logger.info(f"Judge: good [{item['category']}]: {preview}")
             return True
-        logger.warning(f"Bad response [{item['category']}]: {preview}")
+        logger.info(f"Judge: bad [{item['category']}]: {preview}")
         return False
 
-    async def cull(self, input_path: str, output_path: str) -> None:
+    async def judge(self, instructions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Filter only the "good" instructions, as determined by an LLM."""
+        batch_size = (
+            self.raw_config.get("judge", {}).get("batch_size")
+            or self.default_batch_size
+        )
+        batches = np.array_split(
+            instructions, math.ceil(len(instructions) / batch_size)
+        )
+        quality = []
+        for batch in batches:
+            results = await asyncio.gather(
+                *[self.is_decent_response(item) for item in batch]
+            )
+            for idx in range(len(batch)):
+                if results[idx]:
+                    quality.append(batch[idx])
+        return quality
+
+    async def cull(self, input_paths: List[str], output_path: str) -> None:
         """Use the LLM to filter bad responses based on a set of rules.
 
-        :param input_path: Path to the input JSONL file to filter.
-        :type input_path: str
+        :param input_paths: List of paths to the input JSONL file(s) to filter.
+        :type input_paths: List[str]
 
         :param output_path: Path to save the "good" instructions to.
         :type output_path: str
         """
-        with open(input_path) as infile:
-            original = [json.loads(line) for line in infile.readlines()]
+        original = []
+        categories = defaultdict(list)
+        for path in input_paths:
+            with open(path) as infile:
+                for line in infile.readlines():
+                    item = json.loads(line)
+                    original.append(item)
+                    category = item.get("category", "general")
+                    if category == "reasoning_or_math":
+                        category = "orca"
+                    categories[category].append(item)
+
+        # Deduplicate and select best items.
         output_file = open(output_path, "w")
-        instructions = []
-        for instruction in original:
-            if instruction.get("category") in [
+        max_k = self.raw_config.get("cull_max_k")
+        if max_k is None:
+            max_k = 100
+        for category, items in categories.items():
+            # Skip categories that are too weird/cumbersome to score properly.
+            if category in [
+                "orca",
                 "chat",
                 "detailed_writing",
                 "contextual",
                 "counterfactual_contextual",
             ]:
-                output_file.write(json.dumps(instruction) + "\n")
-            else:
-                instructions.append(instruction)
-        try:
-            batch_size = (
-                self.raw_config.get("scoring", {}).get("batch_size") or 5
-            )  # self.default_batch_size
-            batches = np.array_split(
-                instructions, math.ceil(len(instructions) / batch_size)
+                for item in items:
+                    output_file.write(json.dumps(item) + "\n")
+                    output_file.flush()
+                continue
+
+            # Add all of the items in this category to a faiss index.
+            logger.info(
+                f"Initializing faiss index for {category} with {len(items)} documents..."
             )
-            for batch in batches:
-                results = await asyncio.gather(
-                    *[self.is_decent_response(item) for item in batch]
+            index = faiss.IndexFlatL2(self.embedding_dimension)
+            all_embeddings = []
+            for item in tqdm(items):
+                all_embeddings.append(
+                    np.array(
+                        [
+                            calculate_embeddings(
+                                "\n".join([item["instruction"], item["response"]]),
+                                self.embedding_model,
+                                self.embedding_tokenizer,
+                                truncate=False,
+                            )
+                        ]
+                    )
                 )
-                for idx in range(len(batch)):
-                    if results[idx]:
-                        output_file.write(json.dumps(batch[idx]) + "\n")
-                        output_file.flush()
-        finally:
-            output_file.close()
+                index.add(all_embeddings[-1])
+
+            # Here's where it's tricky...
+            #
+            # We need to iterate through the objects, finding all matches that are under our
+            # specified similarity score for this category.
+            #
+            # Once we've found all of the matches, we can select the "best" by first using
+            # the LLM to judge whether the response is high quality or not, but only if it's
+            # a category that we can score well.
+            #
+            # If multiple instructions remain that are high quality, we can use other metrics,
+            # such as length and complexity of speech to select the best.
+            #
+            # If none of the matching instructions are high quality, I guess we just remove
+            # all of them?
+            purged = set([])
+            saved = set([])
+            min_score = (
+                self.instructors.get(category, {}).get("min_docsearch_score")
+                or self.min_docsearch_score
+            )
+            for idx in range(len(items)):
+                if idx in purged or idx in saved:
+                    continue
+                distances, indices = index.search(
+                    all_embeddings[idx], k=min(len(items), max_k)
+                )
+                distances = distances[0].tolist()[1:]
+                indices = indices[0].tolist()[1:]
+                batch = [items[idx]]
+                batch_idx = [idx]
+                for check_idx in range(len(distances)):
+                    # Don't check items before this one (since they would have already been checked).
+                    if indices[check_idx] < idx:
+                        continue
 
-    async def is_too_similar(self, instruction: str, min_score: float = None):
-        """Check the similarity of a new instruction to the existing set.
+                    # Don't check items we've judged as duplicate or low-quality.
+                    if indices[check_idx] in purged:
+                        continue
 
-        :param instruction: The instruction string to compare.
-        :type instruction: str
+                    # Ignore and stop checking if we exceed the min score.
+                    if distances[check_idx] > min_score:
+                        break
+                    batch.append(items[indices[check_idx]])
+                    batch_idx.append(indices[check_idx])
+
+                # Score the batch.
+                quality = await self.judge(batch)
+                if not quality:
+                    for purge_idx in range(len(batch)):
+                        purged.add(batch_idx[purge_idx])
+                        preview = items[batch_idx[purge_idx]][
+                            "instruction"
+                        ].splitlines()[0][0:100]
+                        logger.warning(f"Removing low-quality instruction: {preview}")
+                    continue
+
+                # Only one high-quality result, keep it.
+                if len(quality) == 1:
+                    preview = quality[0]["instruction"].splitlines()[0][0:100]
+                    logger.success(f"Saving high-quality instruction: {preview}")
+                    output_file.write(json.dumps(quality[0]) + "\n")
+                    output_file.flush()
+                    found = False
+                    for save_idx in range(len(batch)):
+                        if batch[save_idx] == quality[0]:
+                            if not found:
+                                saved.add(batch_idx[save_idx])
+                                found = True
+                        else:
+                            purged.add(batch_idx[save_idx])
+                    continue
+
+                # This is kind of a hacky fallback, but it's fast and easy.
+                longest = sorted(
+                    quality, key=lambda x: len(x["instruction"] + x["response"])
+                )[-1]
+                found = False
+                for purge_idx in range(len(batch)):
+                    if batch[purge_idx] == longest and not found:
+                        found = True
+                        saved.add(batch_idx[purge_idx])
+                    if batch[purge_idx] != longest or found:
+                        purged.add(batch_idx[purge_idx])
+                        found = True
+                preview = longest["instruction"].splitlines()[0][0:100]
+                logger.success(f"Saving high-quality, longest instruction: {preview}")
+                output_file.write(json.dumps(longest) + "\n")
+                output_file.flush()
+        output_file.close()
+
+    async def is_too_similar(
+        self, input_text: str, min_score: float = None, index=None
+    ):
+        """Check the similarity of an input string against an index.
+
+        :param input_text: The input string to calculate similarity of.
+        :type input_text: str
 
         :param min_score: Minimum document similarity score to consider unique.
         :type min_score: float
 
-        :return: Boolean indicating if the instruction is too similar or not.
+        :param index: Optional faiss index to query against, defaults to the main index.
+        :type index: faiss index
+
+        :return: Boolean indicating if the text is too similar or not.
         :rtype: bool
         """
-        async with self.docstore_lock:
-            min_ = 1.0
-            for docstore in self.docstores:
-                similar = docstore.similarity_search_with_score(instruction, k=1)
-                for _, score in similar:
-                    if score < min_:
-                        min_ = score
-            if min_ <= min_score:
-                logger.warning(
-                    f"Skipping instruction, too similar [{min_}]: {instruction}"
-                )
-                return True
-            return False
+        index = index or self.index
+        distance, _ = index.search(
+            calculate_embeddings(
+                input_text, self.embedding_model, self.embedding_tokenizer
+            ),
+            1,
+        )
+        if min_score is None:
+            min_score = self.min_docsearch_score
+        if distance <= min_score:
+            logger.warning(f"Too similar [{distance}]: {input_text}")
+            return True
+        return False
 
     def persist(self, item):
-        """Persist a single item to the output file and docstore."""
+        """Persist a single item to the output file and add it to the index."""
         skip_counting = item.pop("skip_counting", False)
         self.outfile.write(json.dumps(item) + "\n")
         self.outfile.flush()
         if item["category"] != "chat":
-            self.docstores[-1].add_texts([item["instruction"]])
-            self.docstore_size += 1
-            if self.docstore_size >= MAX_DOCSTORE_SIZE:
-                logger.info("Initializing new docstore...")
-                self.docstores.append(
-                    Chroma.from_texts(["__initialize__"], self.embeddings)
+            self.index.add(
+                calculate_embeddings(
+                    item["instruction"], self.embedding_model, self.embedding_tokenizer
                 )
-                self.docstore_size = 0
+            )
         if not skip_counting:
             self.instructor_counts[item["category"]] += 1
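The cull loop above interleaves nearest-neighbor grouping with LLM judging. A much-simplified sketch of just the grouping step (illustrative only, no LLM involved: it greedily clusters near-duplicates by squared L2 distance and keeps the longest member of each group, assuming the embeddings are already unit-normalized float32 vectors):

    import faiss
    import numpy as np

    def keep_longest_per_cluster(texts, embeddings, min_score=0.1, max_k=100):
        """Greedy near-duplicate grouping; keeps the longest text per group."""
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        handled, kept = set(), []
        for idx in range(len(texts)):
            if idx in handled:
                continue
            distances, indices = index.search(
                embeddings[idx : idx + 1], min(len(texts), max_k)
            )
            # The query itself comes back at distance 0, so every group is non-empty.
            group = [
                int(i)
                for d, i in zip(distances[0], indices[0])
                if d <= min_score and int(i) not in handled
            ]
            handled.update(group)
            kept.append(max(group, key=lambda i: len(texts[i])))
        return [texts[i] for i in kept]

    texts = ["make a list of colors", "list some colors", "write a haiku about rain"]
    vectors = np.random.rand(len(texts), 8).astype("float32")
    vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)
    print(keep_longest_per_cluster(texts, vectors))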
@@ -568,7 +709,7 @@ async def run(self):
         }
         await self.initialize_topics()
 
-        self.initialize_docstores()
+        self.initialize_index()
 
         # Generate instructions for each category.
         self.outfile = open(self.output_path, "a+")
@@ -635,6 +776,7 @@ def cull_instructions(args):
         **{
             "type": str,
             "help": "path to the file containing instructions to cull",
+            "nargs": "+",
         },
     )
     parser.add_argument(
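With nargs set to "+", the culling entry point accepts several JSONL files and merges them before filtering. A hedged sketch of driving it programmatically (the SelfInstructor class name and the file paths are assumptions for illustration; confirm the actual entry point in airoboros.self_instruct):

    import asyncio

    # Assumed import; check the class name exposed by airoboros.self_instruct.
    from airoboros.self_instruct import SelfInstructor

    async def main():
        instructor = SelfInstructor(config_path="example-config.yaml")
        await instructor.cull(
            ["instructions-part1.jsonl", "instructions-part2.jsonl"],  # merged inputs
            "instructions-culled.jsonl",
        )

    asyncio.run(main())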
diff --git a/example-config.yaml b/example-config.yaml
index dc491a4..4cce677 100644
--- a/example-config.yaml
+++ b/example-config.yaml
@@ -19,6 +19,12 @@ overwrite: false
 # Append to the output file.
 append: true
 
+# Embedding model, for calculating similarity between documents; probably best left as-is since the code is fairly specific to this one.
+embedding_model: thenlper/gte-small
+embedding_device: cpu
+# If you have a GPU, set this to "cuda", e.g.:
+# embedding_device: cuda
+
 # Topic avoidance prompt string.
 topic_avoidance: Avoid any tasks that would be related to climate change, green tech, renewable energy, DEI (diversity, equity, inclusion), sex and/or gender, religion, politics, social issues, race, ethnicity, artificial intelligence, baking/cooking, urban development, or any topic that you would likely not respond to, or any task which a language model would not be able to respond to, e.g. tasks about emotions, feelings, physical senses, etc.
@@ -379,6 +385,7 @@ instructors:
     count: 25
     batch_size: 1
     min_docsearch_score: 0.1
+    seed_path: chat_card_seeds
     output_dir: chat_cards
 
 ##################################################################################
diff --git a/setup.py b/setup.py
index a694402..70887b3 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
     name="airoboros",
-    version="2.0.21",
+    version="2.0.22",
     description="Updated and improved implementation of the self-instruct system.",
     long_description=long_description,
     long_description_content_type="text/markdown",
@@ -21,8 +21,8 @@
         "backoff>=2.2",
         "requests>=2.28",
         "loguru>=0.7",
-        "chromadb>=0.3.21",
-        "langchain>=0.0.162",
+        "faiss-cpu==1.7.4",
+        "fast-sentence-transformers==0.4.1",
         "sentence-transformers>=2.2.2",
     ],
     extras_require={