MLE in Practice: Training AI Models
Maximum Likelihood Estimation isn't just theory—it's the engine behind training virtually every modern AI model. This lesson shows how MLE translates into the practical algorithms used to train neural networks.
From MLE to Loss Functions
The fundamental connection:
Maximizing Log-Likelihood ↔ Minimizing Negative Log-Likelihood (NLL)
In deep learning, we minimize loss functions. The most common losses ARE negative log-likelihoods:
| Task | Loss Function | MLE Interpretation |
|---|---|---|
| Classification | Cross-entropy | -log P(correct class) |
| Regression | Mean squared error | -log P(y) under Gaussian |
| Sequence modeling | Cross-entropy per token | -log P(sequence) |
Cross-Entropy Loss as MLE
For classification with K classes:
True label: one-hot vector y = [0, 0, 1, 0, 0]
Predicted probabilities: p̂ = [0.05, 0.1, 0.7, 0.1, 0.05]
Cross-entropy loss:
L = -Σ yᵢ × log(p̂ᵢ)
= -log(0.7)
≈ 0.36
MLE interpretation: We want to maximize P(correct class) = 0.7, which means minimizing -log(0.7).
Why This Works
- When the model is confident and correct: -log(0.99) ≈ 0.01 (small loss)
- When the model is confident and wrong: -log(0.01) ≈ 4.6 (large loss)
- When the model is uncertain: -log(0.33) ≈ 1.1 (medium loss)
The model is penalized heavily for confident wrong answers.
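As a concrete check of these numbers, here is a minimal NumPy sketch; the predicted probabilities are the hypothetical values from above.

import numpy as np

def cross_entropy(probs, true_class):
    # Negative log-likelihood of the true class: -log P(correct class)
    return -np.log(probs[true_class])

probs = np.array([0.05, 0.1, 0.7, 0.1, 0.05])
print(cross_entropy(probs, true_class=2))    # ≈ 0.36, i.e. -log(0.7)
print(-np.log(0.99), -np.log(0.01))          # ≈ 0.01 (confident, right) vs ≈ 4.6 (confident, wrong)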
Mean Squared Error as MLE
For regression under Gaussian noise assumption:
Model: y = f(x) + ε, where ε ~ Normal(0, σ²)
Likelihood of one observation:
P(yᵢ | xᵢ, θ) = (1/√(2πσ²)) × exp(-(yᵢ - f(xᵢ))² / (2σ²))
Log-likelihood:
log P(yᵢ | xᵢ, θ) = -½log(2πσ²) - (yᵢ - f(xᵢ))² / (2σ²)
Maximizing over all data:
The first term is constant with respect to the parameters θ, so it can be dropped. Summing the second term over all n observations gives:
-Σ (yᵢ - f(xᵢ))² / (2σ²)
Maximizing this means minimizing:
Σ (yᵢ - f(xᵢ))²
This is exactly Mean Squared Error (MSE)!
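As a quick numerical sanity check (a NumPy sketch with made-up data and an assumed σ = 0.5), minimizing the Gaussian negative log-likelihood and minimizing the squared error pick out the same prediction:

import numpy as np

y = np.array([2.1, 1.9, 2.4, 2.0, 2.2])       # hypothetical regression targets
candidates = np.linspace(0.0, 4.0, 401)        # candidate constant predictions
sigma = 0.5                                    # assumed noise scale; it does not move the minimum

def gaussian_nll(pred):
    return np.sum(0.5 * np.log(2 * np.pi * sigma**2) + (y - pred)**2 / (2 * sigma**2))

def sse(pred):
    return np.sum((y - pred)**2)

best_nll = candidates[np.argmin([gaussian_nll(c) for c in candidates])]
best_sse = candidates[np.argmin([sse(c) for c in candidates])]
print(best_nll, best_sse)                      # both 2.12, the sample mean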
Training Neural Networks
The Training Loop
for epoch in range(num_epochs):
    for batch_x, batch_y in data_loader:
        # Forward pass: compute predictions
        predictions = model(batch_x)
        # Compute loss (negative log-likelihood)
        loss = loss_function(predictions, batch_y)
        # Backward pass: compute gradients
        loss.backward()
        # Update parameters (gradient ascent on log-likelihood)
        optimizer.step()
        optimizer.zero_grad()
Every step is gradient ascent on log-likelihood (or equivalently, gradient descent on loss).
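Here is a self-contained, runnable version of the same loop, a minimal sketch assuming PyTorch; the random toy dataset, model size, and hyperparameters are purely illustrative.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Toy 3-class classification problem (random data, for illustration only)
X = torch.randn(512, 10)
y = torch.randint(0, 3, (512,))
data_loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))
loss_function = nn.CrossEntropyLoss()                     # negative log-likelihood of the true class
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(5):
    for batch_x, batch_y in data_loader:
        predictions = model(batch_x)                      # forward pass: class logits
        loss = loss_function(predictions, batch_y)        # NLL of this mini-batch
        loss.backward()                                   # gradients of the loss
        optimizer.step()                                  # one gradient-descent step
        optimizer.zero_grad()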
Batch Size and MLE
Full batch: Compute exact gradient of total log-likelihood
∇ℓ(θ) = Σᵢ ∇log P(yᵢ | xᵢ, θ)
Mini-batch: Estimate gradient with subset
∇ℓ(θ) ≈ (n/m) × Σᵢ∈batch ∇log P(yᵢ | xᵢ, θ)
This is Stochastic Gradient Descent (SGD): the scaled mini-batch gradient is an unbiased estimate of the full-data gradient.
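The unbiasedness claim can be checked numerically. Below is a NumPy sketch for a Gaussian model with known unit variance, where the per-example gradient of the log-likelihood with respect to the mean θ is simply (yᵢ - θ); the data and the value of θ are made up.

import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=1000)                  # hypothetical observations
theta = 0.3                                # current parameter estimate
n, m = len(y), 32                          # full-data size and mini-batch size

full_grad = np.sum(y - theta)              # exact gradient of the total log-likelihood

# Average many scaled mini-batch gradients; the mean matches the full gradient
estimates = [(n / m) * np.sum(rng.choice(y, size=m, replace=False) - theta)
             for _ in range(5000)]
print(full_grad, np.mean(estimates))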
Language Model Training
Language models are trained with MLE on sequences:
Objective: Maximize P(w₁, w₂, ..., wₙ)
Using chain rule:
P(w₁, w₂, ..., wₙ) = P(w₁) × P(w₂|w₁) × P(w₃|w₁,w₂) × ...
Log-likelihood:
ℓ = Σᵢ log P(wᵢ | w₁, ..., wᵢ₋₁)
Loss per token:
L = -1/n × Σᵢ log P(wᵢ | context)
This is cross-entropy loss at each position!
Example: Training on "The cat sat"
Input: "<start> The cat"
Target: "The cat sat"
Loss = -[log P("The" | <start>) +
log P("cat" | The) +
log P("sat" | The cat)]
The model learns to predict each next word given the previous words.
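A sketch of the per-token loss for this example, assuming PyTorch; the four-word vocabulary and the logits are invented for illustration, not produced by a real model.

import torch
import torch.nn.functional as F

# Hypothetical vocabulary: 0=<start>, 1=The, 2=cat, 3=sat
targets = torch.tensor([1, 2, 3])               # next tokens: "The", "cat", "sat"

logits = torch.tensor([[2.0, 3.5, 0.1, 0.2],    # prediction after "<start>"
                       [0.3, 0.1, 4.0, 1.0],    # prediction after "<start> The"
                       [0.2, 0.5, 1.0, 3.0]])   # prediction after "<start> The cat"

# Mean over positions of -log P(target token | context)
loss = F.cross_entropy(logits, targets)
print(loss.item())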
Regularization and MLE
Pure MLE can overfit. Regularization modifies the objective:
L2 Regularization (Ridge)
Regularized loss = Loss + λ × Σ wᵢ²
Bayesian interpretation: This is MLE with a Gaussian prior on weights!
Maximum a posteriori (MAP) estimation:
P(θ | data) ∝ P(data | θ) × P(θ)
log P(θ | data) = log P(data | θ) + log P(θ)
= log-likelihood + log(Gaussian prior)
= log-likelihood - λ × ||θ||²
L1 Regularization (Lasso)
Regularized loss = Loss + λ × Σ |wᵢ|
Bayesian interpretation: Laplace prior on weights (encourages sparsity).
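A sketch of adding either penalty to the loss by hand, assuming PyTorch (L2 is more often applied through the optimizer's weight_decay argument, but the explicit form makes the MAP connection visible):

import torch
import torch.nn as nn

model = nn.Linear(10, 3)                                      # any model with parameters
x, y = torch.randn(8, 10), torch.randint(0, 3, (8,))
nll = nn.CrossEntropyLoss()(model(x), y)                      # data term: negative log-likelihood

lam = 1e-3                                                    # illustrative regularization strength
l2_penalty = sum((p ** 2).sum() for p in model.parameters())  # Gaussian prior (Ridge / MAP)
l1_penalty = sum(p.abs().sum() for p in model.parameters())   # Laplace prior (Lasso / MAP)

loss = nll + lam * l2_penalty                                 # or: nll + lam * l1_penalty
loss.backward()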
Dropout
Randomly dropping neurons during training implicitly trains an ensemble of overlapping sub-networks, which can be viewed as a rough approximation to Bayesian model averaging.
Early Stopping
Another form of regularization:
- Train while monitoring validation loss
- Stop when validation loss starts increasing
- Use the model at the stopping point
This prevents the model from maximizing the training likelihood too much (overfitting).
Training Loss ↓↓↓↓↓↓↓↓↓↓
Val Loss ↓↓↓↓↓↓↑↑↑↑ ← Stop here!
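A minimal sketch of this pattern; train_one_epoch and validation_loss are hypothetical helpers, and the PyTorch-style state_dict calls are assumptions rather than a fixed recipe.

import copy

best_val, best_state, patience, bad_epochs = float("inf"), None, 3, 0

for epoch in range(max_epochs):
    train_one_epoch(model, optimizer, train_loader)        # hypothetical helper
    val = validation_loss(model, val_loader)               # hypothetical helper
    if val < best_val:                                     # validation loss still improving
        best_val, best_state, bad_epochs = val, copy.deepcopy(model.state_dict()), 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:                         # it has started increasing: stop here
            break

model.load_state_dict(best_state)                          # use the model at the stopping point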
Optimizers as MLE Algorithms
Different optimizers approach the MLE in different ways:
SGD (Stochastic Gradient Descent)
θ ← θ - lr × ∇L(θ)
Simple but can be slow or unstable.
Momentum
v ← β×v + ∇L(θ)
θ ← θ - lr × v
Accumulates an exponentially decaying sum of past gradients, which smooths the updates and speeds up convergence.
Adam
m ← β₁×m + (1-β₁)×∇L(θ) # First moment (mean)
v ← β₂×v + (1-β₂)×(∇L(θ))² # Second moment (variance)
θ ← θ - lr × m / (√v + ε)
Adapts learning rate per parameter based on gradient statistics.
All of them are climbing the same log-likelihood surface; they just take different paths toward the maximum.
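The three update rules, written out as plain NumPy functions on a parameter vector (a sketch with illustrative hyperparameter values; Adam's bias-correction terms are omitted, matching the simplified formulas above):

import numpy as np

def sgd_step(theta, grad, lr=0.01):
    return theta - lr * grad

def momentum_step(theta, grad, v, lr=0.01, beta=0.9):
    v = beta * v + grad                            # accumulate past gradients
    return theta - lr * v, v

def adam_step(theta, grad, m, v, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    m = beta1 * m + (1 - beta1) * grad             # first moment (mean of gradients)
    v = beta2 * v + (1 - beta2) * grad ** 2        # second moment (mean of squared gradients)
    return theta - lr * m / (np.sqrt(v) + eps), m, v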
MLE for Different Architectures
Convolutional Neural Networks (CNNs)
For image classification:
Loss = -log P(correct_class | image)
The architecture (convolutions, pooling) defines the model structure, but training is still MLE.
Recurrent Neural Networks (RNNs)
For sequences:
Loss = -Σₜ log P(yₜ | y₁, ..., yₜ₋₁, x)
Each timestep contributes to the total likelihood.
Transformers
Same as RNNs conceptually:
Loss = -Σₜ log P(tokenₜ | token₁, ..., tokenₜ₋₁)
The attention mechanism just changes how context is computed.
Handling Imbalanced Data
When classes are imbalanced, MLE can be biased toward common classes.
Weighted Loss
Weight rare class examples more:
L = -Σᵢ wᵢ × log P(yᵢ | xᵢ)
where wᵢ is higher for rare classes.
Class-Balanced Sampling
Sample batches to balance classes, then use standard MLE.
Focal Loss
Down-weight easy (confident) examples:
L = -(1 - p̂)^γ × log(p̂)
This focuses training on hard examples.
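NumPy sketches of both ideas for a single example, with made-up probabilities and weights:

import numpy as np

def weighted_nll(p_true, class_weight):
    # -w * log P(true class); w is larger for rare classes
    return -class_weight * np.log(p_true)

def focal_loss(p_true, gamma=2.0):
    # (1 - p)^gamma shrinks the loss on easy, high-confidence examples
    return -(1 - p_true) ** gamma * np.log(p_true)

print(focal_loss(0.9), focal_loss(0.3))    # the hard example dominates
print(weighted_nll(0.7, class_weight=5.0))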
Numerical Considerations
Log-Sum-Exp Trick
Computing log(Σ exp(xᵢ)) directly can overflow/underflow.
Solution:
log(Σ exp(xᵢ)) = max(x) + log(Σ exp(xᵢ - max(x)))
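A NumPy sketch of the trick, with inputs chosen to make the naive version overflow:

import numpy as np

def logsumexp(x):
    # Numerically stable log(sum(exp(x)))
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

x = np.array([1000.0, 1001.0, 1002.0])
print(np.log(np.sum(np.exp(x))))   # naive version overflows: inf (with a warning)
print(logsumexp(x))                # stable: ≈ 1002.41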
Label Smoothing
Instead of hard labels [0, 0, 1, 0], mix in a small amount of the uniform distribution (here ε = 0.1 spread over K = 4 classes):
Smoothed = [0.025, 0.025, 0.925, 0.025]
Prevents over-confidence and improves generalization.
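A NumPy sketch of how the smoothed targets above are produced (ε = 0.1 is an illustrative choice):

import numpy as np

def smooth_labels(one_hot, eps=0.1):
    # Mix the one-hot target with a uniform distribution over the K classes
    k = len(one_hot)
    return one_hot * (1 - eps) + eps / k

print(smooth_labels(np.array([0.0, 0.0, 1.0, 0.0])))   # [0.025, 0.025, 0.925, 0.025]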
Summary
- Cross-entropy loss is negative log-likelihood for classification
- MSE is MLE under Gaussian noise assumption
- Training neural networks = gradient ascent on log-likelihood
- Regularization = MLE with priors (MAP estimation)
- Different optimizers (SGD, Adam) find the MLE differently
- All modern architectures (CNN, RNN, Transformer) use MLE training
Next, we'll explore loss functions and optimization in more depth.

