Skip to content

Commit

Permalink
add load_model method
Browse files Browse the repository at this point in the history
  • Loading branch information
Rubing Shen committed Oct 4, 2023
1 parent b900e77 commit 012c0b2
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions AugmentedSocialScientist/bert_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,8 @@ def encode(
label_names = np.unique(labels) # unique label names in alphabetical order
self.dict_labels = dict(zip(label_names, range(len(label_names))))

print(f"label ids: {self.dict_labels}")
if progress_bar:
print(f"label ids: {self.dict_labels}")

# Convert to pytorch tensors
inputs_tensors = torch.tensor(input_ids)
Expand Down Expand Up @@ -466,10 +467,31 @@ def predict(
torch.cuda.empty_cache() # release GPU memory

pred = np.concatenate(logits_complete) # flatten batches
print(f"label ids: {self.dict_labels}")
if progress_bar:
print(f"label ids: {self.dict_labels}")

return softmax(pred, axis=1) if proba else pred

def load_model(
self,
model_path: str
):
"""
Load a saved model
Parameters
----------
model_path: str
path to the saved model
Return
------
model : huggingface model
loaded model
"""
return self.model_sequence_classifier.from_pretrained(model_path)


def predict_with_model(
self,
dataloader: DataLoader,
Expand Down Expand Up @@ -499,7 +521,7 @@ def predict_with_model(
pred: ndarray of shape (n_samples, n_labels)
probabilities for each sequence (row) of belonging to each category (column)
"""
model = self.model_sequence_classifier.from_pretrained(model_path)
model = self.load_model(model_path)
if torch.cuda.is_available():
model.cuda()
return self.predict(dataloader, model, proba, progress_bar)
Expand Down

0 comments on commit 012c0b2

Please sign in to comment.