Skip to content

Commit

Permalink
[F] Fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
hykilpikonna committed Oct 11, 2023
1 parent b41935f commit bb3bfb4
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/audio_captioning.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""
This is an example using CLAPCAP for audio captioning.
"""
from CLAPWrapper import CLAPWrapper
from msclap import CLAP

# Load and initialize CLAP
weights_path = "weights_path"
clap_model = CLAPWrapper(weights_path, version = 'clapcap', use_cuda=False)
clap_model = CLAP(weights_path, version = 'clapcap', use_cuda=False)

#Load audio files
audio_files = ['audio_file']
Expand Down
4 changes: 2 additions & 2 deletions examples/zero_shot_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
classification on ESC50 (https://github.com/karolpiczak/ESC-50).
"""

from CLAPWrapper import CLAPWrapper
from msclap import CLAP
from esc50_dataset import ESC50
import torch.nn.functional as F
import numpy as np
Expand All @@ -18,7 +18,7 @@

# Load and initialize CLAP
weights_path = "weights_path"
clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
clap_model = CLAP(weights_path, version = '2023', use_cuda=False)

# Computing text embeddings
text_embeddings = clap_model.get_text_embeddings(y)
Expand Down
4 changes: 2 additions & 2 deletions examples/zero_shot_predictions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This is an example using CLAP for zero-shot inference.
"""
from CLAPWrapper import CLAPWrapper
from msclap import CLAP
import torch.nn.functional as F

# Define classes for zero-shot
Expand All @@ -17,7 +17,7 @@
# Load and initialize CLAP
weights_path = "weights_path"
# Setting use_cuda = True will load the model on a GPU using CUDA
clap_model = CLAPWrapper(weights_path, version = '2023', use_cuda=False)
clap_model = CLAP(weights_path, version = '2023', use_cuda=False)

# compute text embeddings from natural text
text_embeddings = clap_model.get_text_embeddings(class_prompts)
Expand Down

0 comments on commit bb3bfb4

Please sign in to comment.