Why Use Logit Scaling in Softmax?

Published: November 25, 2023

Logit scaling is discussed in the context of deep learning in this section.

Let’s consider a hypothetical set of logits before applying the softmax function. For simplicity, let’s take a vector with three logits: \((2, 5, 8)\). The softmax function with a temperature parameter \(\tau\) is defined as:

\[ P_i = \frac{e^{z_i / \tau}}{\sum_{j=1}^{N} e^{z_j / \tau}} \]

where:

  • \(z_i\) is the logit for class \(i\),
  • \(N\) is the number of classes,
  • \(\tau\) is the temperature.
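
As a minimal sketch of this formula (the helper name softmax_with_temperature is purely illustrative), it can be implemented directly in PyTorch:

    import torch

    def softmax_with_temperature(logits, tau=1.0):
        # Divide each logit z_i by the temperature tau, then apply the standard softmax.
        scaled = torch.as_tensor(logits, dtype=torch.float32) / tau
        scaled = scaled - scaled.max()  # subtract the max for numerical stability
        exp = torch.exp(scaled)
        return exp / exp.sum()

    print(softmax_with_temperature([2, 5, 8], tau=1.0))  # tau = 1 recovers the plain softmax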

Let’s examine how different temperatures affect the resulting softmax probabilities:

  1. High Temperature (Smoothed Distribution):

    • Temperature (\(\tau\)) is high, let’s say \(\tau = 5\).
    • Softmax Probabilities: \(P_1, P_2, P_3\) where \(P_i\) is the probability for class \(i\).
    import torch

    logits = [2, 5, 8]
    tau = 5.0

    # Divide the logits by the temperature before applying softmax.
    softmax_probs = torch.nn.functional.softmax(torch.tensor(logits) / tau, dim=-1)
    print(softmax_probs.numpy())

    The output is a set of probabilities whose distribution is “smoothed” by the high temperature (values rounded):

    [0.1628, 0.2967, 0.5405]

    Notice that the probabilities are more evenly distributed among the classes.

  2. Low Temperature (Sharp Distribution):

    • Temperature (\(\tau\)) is low, let’s say \(\tau = 0.5\).
    • Softmax Probabilities: \(P_1, P_2, P_3\)
    tau = 0.5  # reuse logits = [2, 5, 8] and the torch import from above

    softmax_probs = torch.nn.functional.softmax(torch.tensor(logits) / tau, dim=-1)
    print(softmax_probs.numpy())

    The output is a set of probabilities whose distribution is “sharpened” by the low temperature (values rounded):

    [6.1e-06, 0.0025, 0.9975]

    Notice that the probability for the class with the highest logit (class 3, with logit 8) is much higher than the others.

     # pairwise similarities between ima_feat and text_feat -> scaling -> softmax
     # scaling: self.logit_scale = self.clip_model.logit_scale.exp().detach() aka e^(logit_scale)
     # ^ makes softmax have a temperature term in its formula
     #   - high tau -> smoothed probs
     #   - low  tau -> peaked probs
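
To make the comments above concrete, here is a minimal sketch of that pattern. The feature tensors below are random stand-ins rather than real CLIP outputs, and the scale value 1/0.07 is only the usual initialization of CLIP’s learned logit_scale; the point is the multiply-then-softmax structure:

    import torch
    import torch.nn.functional as F

    # Toy stand-ins for CLIP-style features: 1 image vs. 3 text prompts (random, illustrative only).
    ima_feat = F.normalize(torch.randn(1, 512), dim=-1)   # L2-normalized image embedding
    text_feat = F.normalize(torch.randn(3, 512), dim=-1)  # L2-normalized text embeddings

    # logit_scale is stored as a log value and exponentiated at use time,
    # e.g. exp(log(1 / 0.07)), about 14.3, which corresponds to a temperature tau of about 0.07.
    logit_scale = torch.tensor(1 / 0.07).log().exp()

    similarities = ima_feat @ text_feat.t()                 # cosine similarities, shape (1, 3)
    probs = F.softmax(logit_scale * similarities, dim=-1)   # scaled logits -> softmax
    print(probs)

Multiplying the similarities by logit_scale is the same as dividing by a small temperature, which is why a learned scale around 14 yields the peaked distributions described above.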

In these examples, adjusting the temperature influences the degree of exploration (smoothness) versus exploitation (sharpness) in the distribution of probabilities. Higher temperatures encourage a more uniform distribution, while lower temperatures emphasize the most confident predictions. This temperature parameter is often used in training as a form of regularization to control the model’s uncertainty.
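
As a small follow-up sketch using the same toy logits, sweeping the temperature shows the distribution moving from nearly one-hot toward uniform as \(\tau\) grows:

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([2.0, 5.0, 8.0])

    for tau in [0.1, 0.5, 1.0, 5.0, 50.0]:
        probs = F.softmax(logits / tau, dim=-1)
        # Entropy shrinks toward 0 as tau decreases (peaked distribution) and grows
        # toward log(3), about 1.10, as tau increases (near-uniform distribution).
        entropy = -(probs * probs.log()).sum().item()
        print(f"tau={tau:>5}: probs={probs.numpy().round(4)}, entropy={entropy:.3f}")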