Commit

fix class names ASAP!!!
BBC-Esq authored Apr 13, 2024
1 parent 5cca65e commit 6288c96
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions InstructorEmbedding/instructor.py
@@ -23,7 +23,7 @@ def batch_to_device(batch, target_device: str):
return batch


-class InstructorPooling(nn.Module):
+class INSTRUCTORPooling(nn.Module):
"""Performs pooling (max or mean) on the token embeddings.
Using pooling, it generates from a variable sized sentence a fixed sized sentence embedding.
@@ -245,7 +245,7 @@ def load(input_path):
) as config_file:
config = json.load(config_file)

-return InstructorPooling(**config)
+return INSTRUCTORPooling(**config)


def import_from_string(dotted_path):
@@ -271,7 +271,7 @@ def import_from_string(dotted_path):
raise ImportError(msg)


-class InstructorTransformer(Transformer):
+class INSTRUCTORTransformer(Transformer):
def __init__(
self,
model_name_or_path: str,
@@ -378,7 +378,7 @@ def load(input_path: str):

with open(sbert_config_path, encoding="UTF-8") as config_file:
config = json.load(config_file)
-return InstructorTransformer(model_name_or_path=input_path, **config)
+return INSTRUCTORTransformer(model_name_or_path=input_path, **config)

def tokenize(self, texts):
"""
@@ -420,7 +420,7 @@ def tokenize(self, texts):

input_features = self.tokenize(instruction_prepended_input_texts)
instruction_features = self.tokenize(instructions)
-input_features = Instructor.prepare_input_features(
+input_features = INSTRUCTOR.prepare_input_features(
input_features, instruction_features
)
else:
@@ -430,7 +430,7 @@ def tokenize(self, texts):
return output


-class Instructor(SentenceTransformer):
+class INSTRUCTOR(SentenceTransformer):
@staticmethod
def prepare_input_features(
input_features, instruction_features, return_data_type: str = "pt"
@@ -510,7 +510,7 @@ def smart_batching_collate(self, batch):

input_features = self.tokenize(instruction_prepended_input_texts)
instruction_features = self.tokenize(instructions)
-input_features = Instructor.prepare_input_features(
+input_features = INSTRUCTOR.prepare_input_features(
input_features, instruction_features
)
batched_input_features.append(input_features)
@@ -559,9 +559,9 @@ def _load_sbert_model(self, model_path, token = None, cache_folder = None, revis
modules = OrderedDict()
for module_config in modules_config:
if module_config["idx"] == 0:
-module_class = InstructorTransformer
+module_class = INSTRUCTORTransformer
elif module_config["idx"] == 1:
-module_class = InstructorPooling
+module_class = INSTRUCTORPooling
else:
module_class = import_from_string(module_config["type"])
module = module_class.load(os.path.join(model_path, module_config["path"]))
@@ -686,4 +686,4 @@ def encode(
if input_was_string:
all_embeddings = all_embeddings[0]

-return all_embeddings
+return all_embeddings
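
With this commit the module again exposes the upper-cased class names (INSTRUCTOR, INSTRUCTORTransformer, INSTRUCTORPooling), so downstream code imports the top-level class in upper case. A minimal usage sketch, assuming the hkunlp/instructor-large checkpoint commonly used with this package:

from InstructorEmbedding import INSTRUCTOR

# Load a pretrained checkpoint (model name assumed here for illustration).
model = INSTRUCTOR("hkunlp/instructor-large")

# Each input is an [instruction, text] pair; encode() returns the embeddings as an array.
instruction = "Represent the Science title:"
sentence = "3D ActionSLAM: wearable person tracking in multi-floor environments"
embeddings = model.encode([[instruction, sentence]])
print(embeddings.shape)  # e.g. (1, 768) for instructor-large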
