Chapter 13 — The forward pass with loss
After Chapter 12 we have a model that turns token ids into logits. To train the model we need the other side of Chapter 4’s gradient-descent recipe: a scalar loss to minimise. For language modelling that loss is cross-entropy between the model’s predicted next-token distribution and the actual next token — exactly the cross-entropy of Ch.4 §4.10’s promise, finally cashed in.
By the end you will have:
- understood the next-token prediction task: given tokens $0, 1, \ldots, t$, predict token $t+1$,
- met cross-entropy loss and computed it by hand on one (logits, target) pair,
- understood the target shift trick: input is the sequence; targets are the same sequence shifted left by one,
- updated `mygpt.GPT.forward(ids, targets=None)` to return both logits and (optionally) loss,
- watched a randomly-initialised model produce a loss of about `4.26` on the running example, and understood why it is higher than the ideal $\log V \approx 1.39$.
There is one small new concept (cross-entropy) and a useful new convention (the (logits, loss) return tuple). After this chapter we are exactly one Chapter 4-style training loop away from a trained model.
13.1 The next-token prediction task
Recall §1.5: a language model factors $P(x_1, \ldots, x_T)$ as
\[P(x_1) \cdot P(x_2 \mid x_1) \cdot P(x_3 \mid x_1, x_2) \cdot \ldots \cdot P(x_T \mid x_1, \ldots, x_{T-1}).\]

Training optimises every conditional simultaneously. For our model, after running `forward(ids)` we have logits of shape `(B, T, V)` — one length-$V$ logit vector per `(b, t)` position. The contract:
`logits[b, t, :]` is the model’s prediction of the next token after `ids[b, 0..t]` — that is, of `ids[b, t+1]`.
So position $t = 0$ predicts position 1, position 1 predicts position 2, …, position $T-1$ predicts position $T$. Position $T$ has no prediction target — it would predict the (non-existent) $(T+1)$-th token.
The cleanest way to feed this into a loss is to shift: input is ids[:, :-1] (length $T-1$), targets are ids[:, 1:] (length $T-1$, the same tokens shifted left by one). Both tensors then have shape (B, T-1), and every logits[b, t, :] has a matching target[b, t].
For our running example "I love AI !" = [0, 1, 2, 3] (length 4):
- input: `[0, 1, 2]` ← `"I love AI"` (length 3)
- targets: `[1, 2, 3]` ← `"love AI !"` (length 3)
Position 0 of the input is 0 ("I"); the target at position 0 is 1 ("love") — the model is learning that "love" is a likely next token after "I". And so on for the other two positions.
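The shift can be checked in a couple of lines (a quick sketch on the running example’s ids; plain tensor slicing, no model needed):

```python
import torch

# Running example: "I love AI !" -> ids [0, 1, 2, 3]
ids = torch.tensor([[0, 1, 2, 3]])  # (B=1, T=4)

inputs = ids[:, :-1]   # (1, 3): [[0, 1, 2]] = "I love AI"
targets = ids[:, 1:]   # (1, 3): [[1, 2, 3]] = "love AI !"

# Every input position t is paired with the token at position t+1.
for t in range(inputs.shape[1]):
    print(f"position {t}: sees token {inputs[0, t].item()}, "
          f"must predict token {targets[0, t].item()}")
```

Both slices have length $T-1 = 3$, and position $t$ of the input lines up with the token at $t+1$.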
13.2 Setup
This chapter assumes you finished Chapter 12 — mygpt/ exists with the GPT class.
If you skipped Chapter 12, recreate the state from a clean directory:
uv init mygpt --package
cd mygpt
mkdir -p experiments
uv add torch numpy
Then overwrite src/mygpt/__init__.py with the Chapter 12 ending state from docs/_state_after_ch12.md.
You are ready.
13.3 Cross-entropy loss
Suppose the model produces logit vector $\mathbf{z} \in \mathbb{R}^V$ for one position, and the true next token is index $y \in \{0, \ldots, V-1\}$. The cross-entropy loss is
\[\mathcal{L}(\mathbf{z}, y) \;=\; -\log\!\Big(\frac{e^{z_y}}{\sum_{j=0}^{V-1} e^{z_j}}\Big) \;=\; -\log \text{softmax}(\mathbf{z})_y \;=\; -z_y + \log \sum_{j=0}^{V-1} e^{z_j}.\]

Three properties to read off:
- Non-negative. Softmax outputs are in $(0, 1]$, so $-\log \text{softmax}$ is in $[0, \infty)$.
- Zero only when $\text{softmax}(\mathbf{z})_y = 1$. That requires the model to put all its mass on the correct token — perfect prediction.
- Equals $\log V$ when the logits are all equal. Uniform softmax gives $\frac{1}{V}$ at every index, so $-\log \frac{1}{V} = \log V$. For our $V = 4$ that is $\log 4 \approx 1.386$ — the random-guessing baseline.
For a batch of predictions, we average the per-position losses:
\[\mathcal{L}_\text{batch} \;=\; \frac{1}{N} \sum_{i=1}^{N} \mathcal{L}(\mathbf{z}_i, y_i),\]

where $N = B \cdot (T-1)$ is the number of (input, target) pairs in the batch.
PyTorch ships this as `F.cross_entropy(logits, targets)`, which applies log-softmax internally using the numerically stable log-sum-exp form on the right — it never explicitly forms the softmax. The expected input shapes are:

- `logits`: shape `(N, V)` — each row is the logit vector for one prediction.
- `targets`: shape `(N,)` — each entry is the true class index.
Our (B, T, V) logits and (B, T) targets need to be flattened: logits.view(B*T, V) and targets.view(B*T).
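That flattening is easy to sanity-check: the flattened batch loss equals the mean of the per-position losses (a sketch with made-up shapes, using `reduction="none"` to expose the per-position values):

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 3, 4
torch.manual_seed(0)
logits = torch.randn(B, T, V)           # (B, T, V)
targets = torch.randint(0, V, (B, T))   # (B, T)

# Flatten to the (N, V) / (N,) shapes F.cross_entropy expects.
loss = F.cross_entropy(logits.view(B * T, V), targets.view(B * T))

# Same thing, averaged by hand over the B*T per-position losses.
per_position = F.cross_entropy(
    logits.view(B * T, V), targets.view(B * T), reduction="none"
)  # (B*T,)
print(torch.allclose(loss, per_position.mean()))  # True
```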
A worked example. Take logits $\mathbf{z} = (1.0, 2.0, 0.5, -1.0)$ and target $y = 1$:
\[\begin{aligned} e^{1.0} &\approx 2.7183, \;\; e^{2.0} \approx 7.3891, \;\; e^{0.5} \approx 1.6487, \;\; e^{-1.0} \approx 0.3679 \\ \sum e^{z_j} &\approx 12.1240, \qquad \text{softmax}(\mathbf{z})_1 \approx 7.3891 / 12.1240 \approx 0.6095 \\ \mathcal{L} &\;=\; -\log(0.6095) \;\approx\; 0.4952. \end{aligned}\]

We will reproduce that with `F.cross_entropy` next.
Save the following to 📄 experiments/26_cross_entropy_by_hand.py:
"""Experiment 26 — Cross-entropy by hand on (logits=(1.0, 2.0, 0.5, -1.0), target=1)."""
import math
import torch
import torch.nn.functional as F
def main() -> None:
logits = torch.tensor([1.0, 2.0, 0.5, -1.0])
target = torch.tensor(1)
# By hand
exp_logits = torch.exp(logits)
sum_exp = exp_logits.sum().item()
softmax_target = exp_logits[target].item() / sum_exp
loss_by_hand = -math.log(softmax_target)
print(f"logits: {logits}")
print(f"target: {target.item()}")
print(f"exp(logits): {exp_logits}")
print(f"sum exp: {sum_exp:.6f}")
print(f"softmax[target]: {softmax_target:.6f}")
print(f"-log(softmax[target]): {loss_by_hand:.6f}")
print()
# By F.cross_entropy
loss_torch = F.cross_entropy(logits.unsqueeze(0), target.unsqueeze(0))
print(f"F.cross_entropy: {loss_torch.item():.6f}")
print(f"matches by-hand: {abs(loss_by_hand - loss_torch.item()) < 1e-6}")
if __name__ == "__main__":
main()
Run it:
uv run python experiments/26_cross_entropy_by_hand.py
Expected output:
logits: tensor([ 1.0000, 2.0000, 0.5000, -1.0000])
target: 1
exp(logits): tensor([2.7183, 7.3891, 1.6487, 0.3679])
sum exp: 12.123940
softmax[target]: 0.609460
-log(softmax[target]): 0.495182
F.cross_entropy: 0.495182
matches by-hand: True
Two things to read off:
- The loss is `0.4952`, not zero. The model assigns 60.95% probability to the correct token. Since the softmax of `2.0` (the biggest logit) does not equal `1.0`, the loss is non-zero. Perfect prediction would require the logit at the target to be infinitely larger than the others.
- PyTorch’s loss equals our by-hand computation. `F.cross_entropy` applies log-softmax internally; we did the same arithmetic explicitly.
13.4 Updating GPT.forward to return loss
The standard convention in nanoGPT-style code is for forward to take an optional targets argument. If targets is None, return only the logits (this is how you call it at generation time, when there are no labels). If targets is not None, also return the cross-entropy loss.
Replace the forward method of GPT with:
```python
def forward(
    self, ids: torch.Tensor, targets: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Compute logits and (optionally) cross-entropy loss.

    Inputs:
        ids: long tensor of shape (B, T) with values in [0, vocab_size).
        targets: optional long tensor of shape (B, T); if given, every
            logits[b, t, :] is scored against targets[b, t] via
            cross-entropy.
    Outputs:
        (logits, loss). `loss` is None if `targets` is None.
    """
    B, T = ids.shape
    if T > self.max_seq_len:
        raise ValueError(
            f"input length T={T} exceeds max_seq_len={self.max_seq_len}"
        )
    positions = torch.arange(T, device=ids.device)
    x = self.token_embedding(ids) + self.position_embedding(positions)
    x = self.embed_drop(x)
    x = self.blocks(x)
    x = self.ln_f(x)
    logits = x @ self.token_embedding.embedding.weight.T  # (B, T, V)
    if targets is None:
        return logits, None
    # F.cross_entropy expects (N, V) logits and (N,) targets.
    loss = F.cross_entropy(
        logits.view(B * T, -1),
        targets.view(B * T),
    )
    return logits, loss
```
(This uses `F.cross_entropy`, so `torch.nn.functional` must be imported as `F` at the top of `__init__.py` if it’s not already — it is, from earlier chapters.)
Then update main to demonstrate the (input, targets) → (logits, loss) flow:
```python
def main() -> None:
    print("Vocabulary:", VOCAB)
    print(f"Vocabulary size V = {len(VOCAB)}")

    set_seed(0)
    V, C, h, N = len(VOCAB), 4, 2, 2
    gpt = GPT(vocab_size=V, embed_dim=C, num_heads=h, num_layers=N,
              max_seq_len=64, dropout=0.0)
    gpt.eval()

    # Next-token prediction setup: input is the first T-1 tokens, targets
    # are the same sequence shifted left by 1.
    full = to_ids(["I", "love", "AI", "!"])
    input_ids = full[:-1].unsqueeze(0)  # (1, 3) = [[0, 1, 2]] = "I love AI"
    targets = full[1:].unsqueeze(0)     # (1, 3) = [[1, 2, 3]] = "love AI !"

    logits, loss = gpt(input_ids, targets)
    print(f"\nInput ids (B, T): {tuple(input_ids.shape)} {input_ids.tolist()}")
    print(f"Targets (B, T): {tuple(targets.shape)} {targets.tolist()}")
    print(f"Logits (B, T, V): {tuple(logits.shape)}")
    print(f"Loss: {loss.item():.4f}")

    import math as _math
    print(f"\nReference: log(V) = log({V}) = {_math.log(V):.4f}")
    print("(Random-init loss is typically a small multiple of log(V); training drives it down.)")
```
(The `import math as _math` inside `main` is purely defensive — it gives `main` its own reference to the standard-library `math` module regardless of what names exist at module level. It is not strictly needed here.)
Run it:
uv run mygpt
Expected output:
Vocabulary: ('I', 'love', 'AI', '!')
Vocabulary size V = 4
Input ids (B, T): (1, 3) [[0, 1, 2]]
Targets (B, T): (1, 3) [[1, 2, 3]]
Logits (B, T, V): (1, 3, 4)
Loss: 4.2588
Reference: log(V) = log(4) = 1.3863
(Random-init loss is typically a small multiple of log(V); training drives it down.)
Three things to read off:
- Loss = 4.2588 at random init. That is higher than the random-guess baseline of `log(V) ≈ 1.39`. Why? A randomly-initialised network produces logits of moderate magnitude, and softmax on moderate logits puts disproportionate mass on whichever logit happens to be biggest. Since training has not occurred, the biggest logit usually sits at the wrong token — the model is confidently wrong, and confident wrongness gives a loss above `log(V)`. With weight initialisation tuned so the initial logits are near-uniform (which real implementations do), the random-init loss starts much closer to `log(V)`.
- `logits.shape = (1, 3, 4)`. Three prediction positions (the input length), each producing a length-4 logit vector. After softmax each row would sum to 1; after argmax, each row would give the model’s most-likely next token at that position.
- `loss` is a scalar. `F.cross_entropy` returns a single number — the average over the $B \cdot T = 3$ (prediction, target) pairs. This is the scalar gradient descent will minimise in Chapter 14.
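You can see the confidently-wrong effect without any model at all — scaled-up random logits scored against random targets already land well above `log(V)` (a sketch; the scale of 5.0 and the 10,000 positions are arbitrary choices, not values from the chapter):

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 4
# 10,000 positions of random logits with moderate magnitude,
# scored against uniformly random targets.
logits = 5.0 * torch.randn(10_000, V)
targets = torch.randint(0, V, (10_000,))

loss = F.cross_entropy(logits, targets)
print(f"random-logit loss: {loss.item():.3f}")  # well above the baseline of
print(f"log(V) baseline:   {math.log(V):.3f}")  # log(4) = 1.386
```

Shrinking the scale factor towards 0 makes the logits near-uniform and pushes the loss back down towards `log(V)` — exactly the effect a careful initialisation exploits.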
13.5 The forward contract: when loss is and isn’t returned
Two regimes:
| Use case | Call | Returns |
|---|---|---|
| Inference / generation | `logits, _ = gpt(input_ids)` | `(logits, None)` |
| Training | `logits, loss = gpt(input_ids, targets)` | `(logits, loss)` |
This pattern appears in every nanoGPT-style codebase. The key reason for the optional targets: at generation time, you don’t have a target — you are creating the next token, not comparing against a known one. Returning None for loss lets the same forward function serve both regimes.
A subtle nuance: because Python’s tuple unpacking is positional, you can write logits, loss = gpt(ids) even when targets is None — the second element is None, which assigns fine. This is the canonical idiom.
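A minimal stand-in shows the unpacking behaviour (a hypothetical stub, not the real `GPT.forward`):

```python
def fake_forward(ids, targets=None):
    """Stub mimicking the (logits, loss) return convention."""
    logits = [0.0] * 4                           # stand-in for a logits tensor
    loss = 0.5 if targets is not None else None  # stand-in for a scalar loss
    return logits, loss

logits, loss = fake_forward([0, 1, 2])             # loss is None
print(loss)                                        # None
logits, loss = fake_forward([0, 1, 2], [1, 2, 3])  # loss is a number
print(loss)                                        # 0.5
```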
13.6 Experiments
- Confirm uniform logits give `log(V)`. In a Python session, set `logits = torch.zeros(1, 3, 4)` and `targets = torch.tensor([[1, 2, 3]])`. Compute `loss = F.cross_entropy(logits.view(3, 4), targets.view(3))`. The result should be exactly `log(4) ≈ 1.3863` — uniform logits are exactly the random-guess baseline.
- A fully-confident correct prediction has loss 0. Set `logits = torch.tensor([[[0.0, 100.0, 0.0, 0.0]]])` and `targets = torch.tensor([[1]])`. Compute the loss. The answer is essentially 0, because the softmax of a +100 logit is essentially 1.0.
- A fully-confident wrong prediction has very high loss. Same logits as above, but `targets = torch.tensor([[0]])` (the target is a position with logit 0, not the position with logit 100). The loss is approximately 100 — the negative log of the tiny softmax probability assigned to the target. Cross-entropy punishes confident wrong predictions severely.
- The `targets=None` regime. Call `gpt(input_ids)` without targets. The return is `(logits, None)`. Extract just the logits with `logits, _ = gpt(input_ids)`. This is what generation will look like in Chapter 15.
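The first three experiments can be bundled into one quick script (a sketch reproducing the values quoted above):

```python
import math

import torch
import torch.nn.functional as F

# 1. Uniform logits -> loss == log(V).
logits = torch.zeros(1, 3, 4)
targets = torch.tensor([[1, 2, 3]])
loss = F.cross_entropy(logits.view(3, 4), targets.view(3))
print(abs(loss.item() - math.log(4)) < 1e-6)  # True

# 2. Confident and correct -> loss essentially 0.
logits = torch.tensor([[[0.0, 100.0, 0.0, 0.0]]])
loss = F.cross_entropy(logits.view(1, 4), torch.tensor([1]))
print(loss.item() < 1e-6)  # True

# 3. Confident and wrong -> loss approximately 100.
loss = F.cross_entropy(logits.view(1, 4), torch.tensor([0]))
print(round(loss.item()))  # 100
```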
All four experiments run in an interactive Python session; none of them modifies the files in `mygpt/`, so there is nothing to restore before moving on.
13.7 Exercises
- The mean over what? When we call `F.cross_entropy(logits.view(B*T, V), targets.view(B*T))`, PyTorch averages over all $B \cdot T$ predictions — both the batch and the time axes. Why does that matter for batch-size invariance? (Hint: if PyTorch summed instead, doubling the batch size would double the loss, and gradients would be twice as big — effectively doubling the learning rate.)
- Why log-sum-exp instead of explicit softmax? PyTorch’s `F.cross_entropy` does not actually compute `softmax(logits)` and then take its log — it computes `-z_y + log(sum exp(z))` directly via the log-sum-exp identity. Argue why this is more numerically stable, by considering what happens for $z_j = 1000$. (Hint: `exp(1000)` overflows; `log(sum exp(z))` doesn’t, when computed via the max-subtraction trick.)
- What is the highest possible loss? For a $V$-class cross-entropy where the model puts probability $\epsilon$ on the true class, the loss is $-\log \epsilon$. As $\epsilon \to 0$, the loss $\to \infty$ — there is no upper bound. Argue informally why this property is useful for training (the gradient signal is strongest where the model is most wrong).
- The shift trick formalised. Given a sequence `tokens` of length $T$, write the input/target slicing in Python (`tokens[:-1]` and `tokens[1:]`). Confirm both have length $T-1$. What is the relationship between the input at position $t$ and the target at position $t$? (Answer: the target at position $t$ is `tokens[t+1]`; the model is predicting the token at position $t+1$ from the prefix `tokens[0..t]`.)
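For the second exercise, the overflow is easy to demonstrate (a sketch; `torch.logsumexp` implements the max-subtraction trick):

```python
import torch

z = torch.tensor([1000.0, 999.0, 998.0])

# Naive: exp(1000) overflows to inf, so the explicit sum-then-log is garbage.
naive = torch.log(torch.exp(z).sum())
print(naive)  # tensor(inf)

# Stable: logsumexp(z) = m + log(sum exp(z - m)) with m = max(z),
# so every exponent is <= 0 and nothing overflows.
stable = torch.logsumexp(z, dim=0)
print(stable)  # finite: 1000 + log(1 + e^-1 + e^-2) ≈ 1000.41
```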
13.8 What’s next
We have a model and a loss. Chapter 14 closes the loop: a Chapter-4-style training loop applied to a real text file. We will
- tokenise the file into a `(N,)` long tensor of ids,
- sample random `(input, target)` pairs of shape `(B, T)` from it,
- run `logits, loss = gpt(input, target)`, then `loss.backward()` and `optimizer.step()`,
- watch the loss decrease over many iterations.
After Chapter 14 we have a trained mygpt.GPT — and Chapter 15 will sample text from it.
Looking ahead — what to remember from this chapter
- Cross-entropy on logits and an integer target is `-z_y + log(sum exp(z))`. PyTorch ships it as `F.cross_entropy`, which applies the log-sum-exp trick for numerical stability.
- The uniform-logits baseline is `log(V)`. A random-init network typically produces a loss several times higher because it is confidently wrong at most positions.
- The next-token shift: input is `tokens[:-1]`, targets are `tokens[1:]`. Position $t$ predicts position $t+1$.
- `mygpt.GPT.forward(ids, targets=None)` returns `(logits, loss)`, where `loss` is `None` for inference and a scalar for training.
On to Chapter 14 — Training loop: gradient descent in practice (coming soon).