diff --git a/clip/model.py b/clip/model.py index 1ddd90862..2a0924599 100644 --- a/clip/model.py +++ b/clip/model.py @@ -287,7 +287,7 @@ def __init__(self, self.ln_final = LayerNorm(transformer_width) self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([])) + self.logit_scale = nn.Parameter(torch.FloatTensor([np.log(1/0.07)])) self.initialize_parameters()