From 6288c96443b1ae430fa61e9861ade8f1714e192f Mon Sep 17 00:00:00 2001
From: BBC-Esq
Date: Sat, 13 Apr 2024 16:42:47 -0400
Subject: [PATCH] fix class names ASAP!!!

---
 InstructorEmbedding/instructor.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/InstructorEmbedding/instructor.py b/InstructorEmbedding/instructor.py
index 9d4f90a..b0a7674 100644
--- a/InstructorEmbedding/instructor.py
+++ b/InstructorEmbedding/instructor.py
@@ -23,7 +23,7 @@ def batch_to_device(batch, target_device: str):
     return batch
 
 
-class InstructorPooling(nn.Module):
+class INSTRUCTORPooling(nn.Module):
     """Performs pooling (max or mean) on the token embeddings.
 
     Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
@@ -245,7 +245,7 @@ def load(input_path):
         ) as config_file:
             config = json.load(config_file)
 
-        return InstructorPooling(**config)
+        return INSTRUCTORPooling(**config)
 
 
 def import_from_string(dotted_path):
@@ -271,7 +271,7 @@ def import_from_string(dotted_path):
         raise ImportError(msg)
 
 
-class InstructorTransformer(Transformer):
+class INSTRUCTORTransformer(Transformer):
     def __init__(
         self,
         model_name_or_path: str,
@@ -378,7 +378,7 @@ def load(input_path: str):
         with open(sbert_config_path, encoding="UTF-8") as config_file:
             config = json.load(config_file)
 
-        return InstructorTransformer(model_name_or_path=input_path, **config)
+        return INSTRUCTORTransformer(model_name_or_path=input_path, **config)
 
     def tokenize(self, texts):
         """
@@ -420,7 +420,7 @@ def tokenize(self, texts):
             input_features = self.tokenize(instruction_prepended_input_texts)
             instruction_features = self.tokenize(instructions)
 
-            input_features = Instructor.prepare_input_features(
+            input_features = INSTRUCTOR.prepare_input_features(
                 input_features, instruction_features
             )
         else:
@@ -430,7 +430,7 @@ def tokenize(self, texts):
         return output
 
 
-class Instructor(SentenceTransformer):
+class INSTRUCTOR(SentenceTransformer):
     @staticmethod
     def prepare_input_features(
         input_features, instruction_features, return_data_type: str = "pt"
@@ -510,7 +510,7 @@ def smart_batching_collate(self, batch):
             input_features = self.tokenize(instruction_prepended_input_texts)
             instruction_features = self.tokenize(instructions)
 
-            input_features = Instructor.prepare_input_features(
+            input_features = INSTRUCTOR.prepare_input_features(
                 input_features, instruction_features
             )
             batched_input_features.append(input_features)
@@ -559,9 +559,9 @@ def _load_sbert_model(self, model_path, token = None, cache_folder = None, revis
         modules = OrderedDict()
         for module_config in modules_config:
             if module_config["idx"] == 0:
-                module_class = InstructorTransformer
+                module_class = INSTRUCTORTransformer
             elif module_config["idx"] == 1:
-                module_class = InstructorPooling
+                module_class = INSTRUCTORPooling
             else:
                 module_class = import_from_string(module_config["type"])
             module = module_class.load(os.path.join(model_path, module_config["path"]))
@@ -686,4 +686,4 @@ def encode(
         if input_was_string:
             all_embeddings = all_embeddings[0]
 
-        return all_embeddings
+        return all_embeddings
\ No newline at end of file
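
A minimal usage sketch against the renamed classes, not part of the patch itself; it assumes the
hkunlp/instructor-large checkpoint and the [instruction, text] pair calling convention shown in
the upstream README:

    from InstructorEmbedding import INSTRUCTOR

    # Load the renamed model class (a SentenceTransformer subclass); this is assumed
    # to resolve INSTRUCTORTransformer and INSTRUCTORPooling via _load_sbert_model().
    model = INSTRUCTOR("hkunlp/instructor-large")

    # encode() accepts [instruction, text] pairs; per tokenize() above, the instruction
    # is prepended to the text before tokenization and handled by prepare_input_features().
    embeddings = model.encode(
        [["Represent the Science title:", "3D ActionSLAM: wearable person tracking"]]
    )
    print(embeddings.shape)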