- Take an arbitrary small model with weight tying (Llama-3.2-1B in this case). Feed it some random text, do a forward pass, and record what each layer adds to the residual stream.
- Look at the final logits, take the few most likely next tokens, and record their (normalized) token embeddings as their directions. At least in the last layer's hidden state these directions are meaningful: the projection onto each one basically represents how much the model wants to output that token.
- Check which layers contributed most to those directions. I computed each layer's percentage contribution by dotting its output with the direction vector above and dividing by the total projection onto that direction.
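The steps above can be sketched roughly as follows. This is a minimal sketch, not the exact code behind these measurements. It assumes the Hugging Face `transformers` Llama module layout (`model.model.embed_tokens`, `model.model.layers[i].self_attn` / `.mlp`, `model.model.norm`) and the Hub ID `meta-llama/Llama-3.2-1B`, which may require accepting Meta's license. The helper names (`capture_residual_writes`, `token_direction`, `contribution_shares`, `main`) are hypothetical. It counts the input embedding as its own term so the shares add up to 100%, and by default it folds the final RMSNorm weight into the token direction; pass `fold_norm_weight=False` to project onto the plain embedding instead.

```python
import torch


def capture_residual_writes(model, input_ids):
    """One forward pass; record every vector written into the residual stream at
    the last token position: the input embedding, then each layer's attention
    output and MLP output. Assumes the Hugging Face Llama module layout."""
    writes = {}
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            # self_attn returns a tuple whose first element is its output;
            # embed_tokens and mlp return a plain tensor.
            out = output[0] if isinstance(output, tuple) else output
            writes[name] = out[0, -1].detach().float().clone()
        return hook

    hooks.append(model.model.embed_tokens.register_forward_hook(make_hook("embed")))
    for i, layer in enumerate(model.model.layers):
        hooks.append(layer.self_attn.register_forward_hook(make_hook(f"attn_{i}")))
        hooks.append(layer.mlp.register_forward_hook(make_hook(f"mlp_{i}")))
    try:
        with torch.no_grad():
            logits = model(input_ids).logits[0, -1].float()
    finally:
        for h in hooks:  # always remove hooks, even if the forward pass fails
            h.remove()
    return logits, writes


def token_direction(model, token_id, fold_norm_weight=True):
    """Tied embedding row of `token_id`. If fold_norm_weight, multiply it
    elementwise by the final RMSNorm weight so that projections of pre-norm
    vectors are proportional to their exact share of the logit."""
    e = model.get_input_embeddings().weight[token_id].detach().float()
    if fold_norm_weight:
        e = e * model.model.norm.weight.detach().float()
    return e


def contribution_shares(writes, direction):
    """Project each residual write onto `direction` and return each write's
    share of the total projection. Shares sum to 1 and can be negative."""
    proj = {name: torch.dot(vec, direction).item() for name, vec in writes.items()}
    total = sum(proj.values())
    return {name: p / total for name, p in proj.items()}


def main(model_name="meta-llama/Llama-3.2-1B", text="Steve", top_k=3):
    # Imported here so the helpers above only depend on torch.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).eval()
    input_ids = tok(text, return_tensors="pt").input_ids
    logits, writes = capture_residual_writes(model, input_ids)

    # Sanity check: rebuild each logit from the summed pre-norm residual writes.
    x = torch.stack(list(writes.values())).sum(dim=0)
    rms = torch.sqrt(x.pow(2).mean() + model.config.rms_norm_eps)

    for token_id in torch.topk(logits, top_k).indices.tolist():
        d = token_direction(model, token_id)
        rebuilt = (torch.dot(d, x) / rms).item()
        shares = contribution_shares(writes, d)
        biggest = sorted(shares.items(), key=lambda kv: -abs(kv[1]))[:4]
        print(f"{tok.decode([token_id])!r}: logit={logits[token_id].item():.2f} "
              f"rebuilt={rebuilt:.2f} | " + ", ".join(f"{n} {s:.0%}" for n, s in biggest))
```

Calling `main()` (or `main(text="...")`) prints, for each top token, its logit, the logit rebuilt from the summed residual writes (which should agree closely in float32), and the four largest contributors by absolute share.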
So, for example, suppose the input text is just "Steve"; the most likely next token is " Jobs". I take the " Jobs" token embedding as the direction (I also tried normalizing it, but that doesn't change the end result) and dot it with the final hidden state, which gives 18, exactly the value in the raw logits. Before the final hidden state there is an RMSNorm, which mostly just rescales the vector (strictly, its learned per-channel weight also reweights individual dimensions, so the direction can shift somewhat). The pre-norm dot product is about 3. So I dotted the output of each layer with the " Jobs" direction, and it turns out the final MLP contributed more than 2 of those 3, while all the other MLP and attention layers contributed only small amounts, most likely some kind of interference.
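The bookkeeping in this example can be written out as a short derivation. It is a sketch under stated assumptions, with symbols of my own choosing: $h_0$ is the input embedding of the last token, $a_\ell$ and $m_\ell$ are what layer $\ell$'s attention and MLP add to the residual stream, $L$ is the number of layers, $d$ the hidden size, $g$ the final RMSNorm weight, $\epsilon$ its epsilon, and $e_t$ the tied embedding row of the target token $t$ (here " Jobs").

```latex
% Pre-norm residual at the last position: input embedding plus every layer's writes
x = h_0 + \sum_{\ell=1}^{L} \left( a_\ell + m_\ell \right)

% Final RMSNorm with learned weight g, then the tied unembedding row e_t
z_t = e_t^{\top} \left( g \odot \frac{x}{\operatorname{rms}(x)} \right)
    = \frac{(g \odot e_t)^{\top} x}{\operatorname{rms}(x)},
\qquad
\operatorname{rms}(x) = \sqrt{\tfrac{1}{d} \textstyle\sum_{i=1}^{d} x_i^2 + \epsilon}

% rms(x) is one positive scalar shared by every term, so each term's share is
s_c = \frac{(g \odot e_t)^{\top} c}{(g \odot e_t)^{\top} x},
\qquad c \in \{h_0, a_1, m_1, \dots, a_L, m_L\},
\qquad \sum_c s_c = 1
```

Treating $g$ as all ones reduces $s_c$ to the plain dot product with $e_t$. Any rescaling of $e_t$ cancels in the ratio, which is consistent with normalizing the embedding not changing the result.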
After trying many input texts, it turns out the final MLP layer consistently contributes 60%-80% (sometimes as much as 90%) of the magnitude along the top output directions. I also checked the Frobenius norm of the down_proj matrix in every MLP layer to make sure the last layer isn't simply producing larger outputs across the board; they are all roughly the same.
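The down_proj norm check can be sketched as below, again assuming the Hugging Face Llama module layout (`model.model.layers[i].mlp.down_proj`) and the Hub ID `meta-llama/Llama-3.2-1B`. The helpers `frobenius_norms`, `down_proj_weights`, and `print_down_proj_norms` are hypothetical names for illustration, not library functions.

```python
import torch


def frobenius_norms(matrices):
    """Frobenius norm of each 2-D weight matrix, returned as plain floats."""
    return [torch.linalg.matrix_norm(m.detach().float(), ord="fro").item() for m in matrices]


def down_proj_weights(model):
    """down_proj weight of every MLP layer (Hugging Face Llama module layout)."""
    return [layer.mlp.down_proj.weight for layer in model.model.layers]


def print_down_proj_norms(model_name="meta-llama/Llama-3.2-1B"):
    # Imported here so frobenius_norms only depends on torch.
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)
    for i, n in enumerate(frobenius_norms(down_proj_weights(model))):
        print(f"layer {i:2d}: ||down_proj||_F = {n:.1f}")
```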
My conclusion is that the final MLP takes in whatever the real hidden representation of the input text is (concentrated on the last token) and outputs the next-token probability distribution more or less directly. The unembedding matrix then just acts as a format converter (much like softmax) rather than doing any meaningful computation itself. Since the tied unembedding has no parameters of its own, this isn't really wasteful and could indeed be a more efficient setup for small models. But functionally, weight tying seems to turn the last MLP into the true unembedding, so you effectively lose one MLP layer's worth of computation.
I am not a researcher and I'm not sure this is the best place for this kind of discussion. I would appreciate any opinions on whether my method and results make sense, and on good places to discuss things like this.