Skip to content

Commit

Permalink
Update app.py
Browse files Browse the repository at this point in the history
  • Loading branch information
agasheaditya authored Aug 30, 2024
1 parent 84894d1 commit e0946e3
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
""", unsafe_allow_html=True)


device = "cpu"#"cuda" if torch.cuda.is_available() else "cpu"
device = "cpu" #"cuda" if torch.cuda.is_available() else "cpu"
f = open('Data/vocab.json')
vocab = json.load(f)
vocab_size = len(vocab) # 53529
Expand All @@ -72,8 +72,8 @@ def load_model(path:str):
loaded_model = TransformerModel(vocab_size, embed_size, num_heads, num_encoder_layers, num_decoder_layers, forward_expansion, dropout, max_len)

# Load the saved state dictionary into the model
# loaded_model.load_state_dict(torch.load(model_load_path, map_location=torch.device('cpu'), weights_only=False)) # ,pickle_module=pickle
loaded_model = torch.load(model_load_path, map_location=torch.device('cpu'), weights_only=False)
loaded_model.load_state_dict(torch.load(model_load_path, map_location=torch.device('cpu'), weights_only=False)) # ,pickle_module=pickle
# loaded_model = torch.load(model_load_path, map_location=torch.device('cpu'), weights_only=False)


# # Set the model to evaluation mode
Expand Down Expand Up @@ -156,4 +156,4 @@ def main():


if __name__ == '__main__':
main()
main()

0 comments on commit e0946e3

Please sign in to comment.