MXNorm: Reusing MXFP block scales for efficient tensor normalisation

TL;DR

MXNorm reuses MXFP8 block scales for efficient tensor normalization, reducing reduction size by 32x.

cs.LG 2026-03-14
Callum McLean, Luke Y. Prince, Alexandre Payot, Paul Balança, Carlo Luschi
efficiency pretraining quantization tensor normalization large models

Key Findings

Methodology

This paper introduces MXNorm, a tensor normalization method that estimates the RMS by reusing the block scales already calculated during MXFP8 quantization. As a drop-in replacement for RMSNorm, MXNorm shrinks the reduction required for normalization by a factor of 32. The method was validated on pre-training Llama 3 models with 125M, 1B, and 8B parameters, showing minimal loss in training accuracy compared to a baseline using RMSNorm with MXFP8 matrix multiplications. In practical kernels built with torch.compile, MXNorm was up to 2.4x faster than RMSNorm.
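To make the mechanism concrete, here is a minimal pure-Python sketch of both halves: computing one power-of-two (E8M0-style) scale per 32-element block, as an MXFP8 cast does, and then estimating the tensor's RMS from those scales alone. The block size (32) and the FP8 E4M3 range (max normal 448) follow the OCP Microscaling formats; the function names and the exact estimator in `rms_from_scales` are illustrative assumptions, not the paper's kernel.

```python
import math

BLOCK = 32          # MX block size: one shared scale per 32 elements
FP8_E4M3_MAX = 448  # largest normal value in FP8 E4M3

def mx_block_scales(x):
    """One power-of-two (E8M0-style) scale per 32-element block.

    Sketch of the scale step of an MXFP8 cast: each block's scale is
    2**floor(log2(amax / FP8_E4M3_MAX)), so after dividing by it the
    block's largest element fits the FP8 range.
    """
    scales = []
    for i in range(0, len(x), BLOCK):
        amax = max(abs(v) for v in x[i:i + BLOCK])
        exp = math.floor(math.log2(amax / FP8_E4M3_MAX)) if amax > 0 else -127
        scales.append(2.0 ** exp)
    return scales

def rms_from_scales(scales, p=2):
    """Estimate RMS(x) from the block scales alone: a p-norm-style
    mean over the scales (constant factors can be folded into the
    learned gain, so none is applied here). Illustrative estimator
    shape only; the paper's exact formula may differ.
    """
    n = len(scales)
    return (sum(s ** p for s in scales) / n) ** (1.0 / p)

x = [math.sin(0.1 * i) for i in range(4096)]
scales = mx_block_scales(x)        # 4096 / 32 = 128 values
approx = rms_from_scales(scales)   # reduction over 128, not 4096, elements
exact = math.sqrt(sum(v * v for v in x) / len(x))
```

The key point is the last two lines: the exact RMS requires a 4096-element reduction, while the estimate touches only the 128 block scales that the MXFP8 cast had to compute anyway.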

Key Results

  • In pre-training Llama 3 models, MXNorm showed minimal training accuracy loss on 125M, 1B, and 8B parameter models compared to RMSNorm, indicating its effectiveness as a replacement.
  • Using torch.compile for kernel acceleration, MXNorm was up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
  • On the 8B parameter model, MXNorm(p=2) achieved a training loss comparable to RMSNorm, with losses of 2.126 and 2.132 respectively, while MXNorm(p=1) had a worse final loss of 2.175.

Significance

The introduction of MXNorm is significant for both academia and industry as it addresses the performance bottleneck in reduction and elementwise computations, which lag behind the improvements in low-precision matrix multiplication. By reducing the reduction size requirement, MXNorm decreases computational overhead, enhancing training efficiency and speed. This method is particularly beneficial for pre-training large language models, allowing for significant speed improvements without substantial accuracy loss.

Technical Contribution

MXNorm's technical contribution is the fusion of normalization with the MX quantization process, eliminating redundant computation. Unlike existing RMSNorm implementations, MXNorm estimates the RMS from the MXFP8 block scales, shrinking the reduction by a factor of 32. It also delivers measurable kernel speedups in practice, demonstrating its value in engineering settings.

Novelty

The novelty of MXNorm lies in being the first method to fuse normalization into the MX quantization process, reusing MXFP8 block scales to estimate the RMS. This substantially reduces computational overhead compared to a standalone RMSNorm, enhancing training efficiency.

Limitations

  • MXNorm shows slightly higher training loss than RMSNorm on larger models (e.g., 8B parameters), indicating potential accuracy loss in certain scenarios.
  • MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes.
  • MXNorm's performance depends on specific hardware and compiler optimizations, which may vary across platforms.

Future Work

Future research directions include further optimizing MXNorm's stability, especially at higher learning rates. Additionally, exploring MXNorm's application in other types of neural network models and its performance across different hardware platforms is crucial. Research on enhancing MXNorm's generality and adaptability without relying on specific hardware optimizations is also a promising direction.

AI Executive Summary

In the rapid development of deep learning, the improvement of matrix multiplication performance has been a key driver for scaling large models. However, while low-precision matrix multiplication has seen significant acceleration, the performance of reductions and elementwise computations has lagged, becoming a new bottleneck. To address this issue, this paper introduces a novel tensor normalization method—MXNorm. MXNorm achieves efficient normalization by reusing block scales calculated during MXFP8 quantization, reducing the reduction size requirement by 32 times.

MXNorm serves as a replacement for RMSNorm and was validated on pre-training Llama 3 models with 125M, 1B, and 8B parameters. Experimental results showed minimal training accuracy loss compared to a baseline using RMSNorm with MXFP8 matrix multiplications. Additionally, practical kernel speedups using torch.compile demonstrated MXNorm to be up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.

The core technical principle of MXNorm lies in its integration of normalization with the MX quantization process, reusing MXFP8 block scales to estimate RMS. This innovative approach significantly reduces redundant computation overhead, enhancing training efficiency and speed. By reducing the reduction size requirement, MXNorm decreases computational overhead, making it particularly suitable for pre-training large language models, allowing for significant speed improvements without substantial accuracy loss.

However, MXNorm also has some limitations. It shows slightly higher training loss than RMSNorm on larger models (e.g., 8B parameters), indicating potential accuracy loss in certain scenarios. Additionally, MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes. These issues need to be addressed in future research.

Overall, the introduction of MXNorm is significant for both academia and industry as it addresses the performance bottleneck in reduction and elementwise computations, which lag behind the improvements in low-precision matrix multiplication. Future research directions include further optimizing MXNorm's stability, exploring its application in other types of neural network models, and its performance across different hardware platforms.

Deep Analysis

Background

In recent years, deep learning has made significant progress in natural language processing, computer vision, and scientific fields such as molecular biology. This progress has been enabled by leaps in AI accelerator capabilities, particularly in the acceleration of low-precision matrix multiplications. Over the past eight years, GPU acceleration of low-precision matrix multiplications has improved by 80 times, allowing researchers and practitioners to scale pre-training of transformer-style neural networks and learn from vast pools of unlabelled data. However, as matrix multiplication becomes less of a bottleneck to throughput, other components of model architectures emerge as new bottlenecks. Indeed, other aspects of AI accelerators have not kept pace with improvements in matrix multiplication throughput. For example, elementwise operations and reductions are limited by memory bandwidth and CUDA core throughput in GPUs; however, these have only improved by factors of 8.9x and 5.1x, respectively, over the past eight years. Furthermore, this gap is set to widen with upcoming GPU architecture releases.

Core Problem

While the improvement in matrix multiplication performance has addressed the bottleneck in scaling deep learning workloads, the performance of reductions and elementwise computations has lagged, becoming a new bottleneck. These operations are limited by memory bandwidth and CUDA core throughput, which have only improved by factors of 8.9x and 5.1x, respectively, over the past eight years. Furthermore, this gap is set to widen with upcoming GPU architecture releases. In some cases, these operations can be hidden by overlapping them with matrix multiplications, but others require too much memory to pipeline in practice. Therefore, the community needs to consider new building blocks that contribute a smaller overhead.

Innovation

The core innovation of MXNorm lies in its integration of normalization with the MX quantization process, reusing MXFP8 block scales to estimate RMS. This method significantly reduces redundant computation overhead, enhancing training efficiency and speed. Specifically:


  • MXNorm estimates the RMS by reusing block scales calculated during MXFP8 quantization, reducing the reduction size requirement by 32 times.

  • MXNorm serves as a drop-in replacement for RMSNorm and was validated on pre-training Llama 3 models, showing minimal training accuracy loss compared to a baseline using RMSNorm with MXFP8 matrix multiplications.

  • Practical kernel speedups using torch.compile showed MXNorm to be up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.

Methodology

The implementation of MXNorm involves the following steps:


  • First, MXNorm estimates the RMS by reusing block scales calculated during MXFP8 quantization, reducing the reduction size requirement by 32 times.

  • Then, MXNorm was validated on pre-training Llama 3 models with 125M, 1B, and 8B parameters, showing minimal training accuracy loss compared to a baseline using RMSNorm with MXFP8 matrix multiplications.

  • Finally, practical kernel benchmarks using torch.compile showed MXNorm to be up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
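The steps above can be sketched end to end as a drop-in normalizer. The version below is a hypothetical pure-Python stand-in: it uses raw per-block maxima in place of the quantized E8M0 scales, so the reduction is 32x smaller than RMSNorm's but the arithmetic is simplified relative to the paper's fused kernel. In practice the resulting constant bias would be absorbed by the learned gain, which is omitted here.

```python
import math

BLOCK = 32

def rmsnorm(x, eps=1e-6):
    """Reference RMSNorm: a full reduction over all len(x) elements."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def mxnorm_like(x, eps=1e-6):
    """MXNorm-style normalizer: the reduction runs over len(x)//32
    block maxima instead of all elements (a 32x smaller reduction).
    Raw per-block amax stands in for the E8M0 block scales.
    """
    amaxes = [max(abs(v) for v in x[i:i + BLOCK])
              for i in range(0, len(x), BLOCK)]
    est = math.sqrt(sum(a * a for a in amaxes) / len(amaxes) + eps)
    return [v / est for v in x]

x = [math.cos(0.05 * i) for i in range(128)]
out_ref = rmsnorm(x)
out_mx = mxnorm_like(x)  # same direction per element; only the divisor differs
```

Because both functions divide by a single scalar, their outputs differ only by a uniform factor, which is exactly the kind of discrepancy a learned gain can compensate for during training.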

Experiments

The experimental design involved validating the performance of MXNorm on pre-training Llama 3 models with 125M, 1B, and 8B parameters. The baseline used was RMSNorm with MXFP8 matrix multiplications. The experiments utilized torch.compile to implement kernel acceleration for MXNorm and were tested on actual hardware. Additionally, a learning rate sensitivity test was conducted to evaluate the stability and performance of MXNorm under different learning rates. The results showed minimal training accuracy loss and significant speed improvements in kernel acceleration.

Results

The experimental results demonstrated that MXNorm performed excellently in pre-training Llama 3 models. On 125M, 1B, and 8B parameter models, MXNorm showed minimal training accuracy loss compared to RMSNorm. Additionally, using torch.compile for kernel acceleration, MXNorm was up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4. On the 8B parameter model, MXNorm(p=2) achieved a training loss comparable to RMSNorm, with losses of 2.126 and 2.132 respectively, while MXNorm(p=1) had a worse final loss of 2.175.

Applications

MXNorm's application scenarios include pre-training large language models, particularly in situations requiring efficient normalization and reduced computational overhead. MXNorm allows for significant speed improvements without substantial accuracy loss. Additionally, MXNorm demonstrated excellent performance in practical kernel acceleration, making it suitable for high-performance computing scenarios such as natural language processing and computer vision.

Limitations & Outlook

While MXNorm performed excellently in experiments, it also has some limitations. On larger models (e.g., 8B parameters), MXNorm showed slightly higher training loss than RMSNorm, indicating potential accuracy loss in certain scenarios. Additionally, MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes. These issues need to be addressed in future research.

Plain Language (accessible to non-experts)

Imagine you're cooking in a kitchen. RMSNorm is like a chef who measures every single ingredient from scratch for each dish, even though the portions were already weighed out earlier. MXNorm is like a cleverer chef who reuses the measurements already taken while portioning the ingredients: the block scales computed during MXFP8 quantization. Only a small summary step remains, so the dish comes out nearly the same but noticeably faster. This matters most when cooking at enormous scale, such as training large language models, where small savings per step add up.

ELI14 (explained like you're 14)

Hey there, young pals! Today, I'm going to tell you a story about supercomputers. Imagine you're playing a huge puzzle game where each piece is tiny, and you need to put them together to see the full picture. The traditional way is like rearranging all the puzzle pieces every time, which is slow, right?

Now, there's a super-smart helper called MXNorm with a brilliant idea! It found a way to reuse some already arranged puzzle pieces, so you don't have to start from scratch every time. It's like using your organized notes from last time to finish your homework faster.

This MXNorm helper is especially good at handling large amounts of data, just like supercomputers when dealing with complex language models, allowing them to complete tasks faster without recalculating all the data each time.

So, next time you're playing a puzzle game or doing homework, remember to try this smart method! It'll help you get things done faster and easier!

Glossary

MXNorm

MXNorm is a novel tensor normalization method that estimates RMS by reusing block scales calculated during MXFP8 quantization, achieving efficient normalization.

In this paper, MXNorm is used as a replacement for traditional RMSNorm to reduce computational overhead.

RMSNorm

RMSNorm is a normalization method that normalizes a tensor by dividing it by its root mean square (RMS), typically followed by a learned elementwise gain.

In the pre-training of Llama 3 models, RMSNorm is replaced by MXNorm.
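For reference, RMSNorm computes, for a vector $x \in \mathbb{R}^n$ with learned gain $g$ and a small $\epsilon$ for stability:

```latex
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{n}\sum_{j=1}^{n} x_j^2 + \epsilon}}\, g_i
```

The denominator is the full-length reduction over all $n$ elements that MXNorm replaces with a far smaller reduction over the 32-element block scales.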

MXFP8

MXFP8 is a low-precision quantization format used to accelerate matrix multiplication.

In this paper, MXFP8 is used to quantize tensors to improve computational efficiency.

Quantization

Quantization is the process of converting high-precision data into a low-precision format to reduce computational and storage overhead.

In this paper, MXFP8 quantization is used to accelerate matrix multiplication.
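As a generic illustration of the idea (symmetric integer quantization with a single shared scale, not the MXFP8 format itself, which uses FP8 elements with one scale per 32-element block), the round trip and its error look like:

```python
def quantize_dequantize(x, bits=8):
    """Symmetric round-to-nearest quantization with one shared scale.

    Generic sketch of the precision/overhead trade-off; MXFP8 instead
    uses an FP8 element format with per-32-block power-of-two scales.
    """
    qmax = 2 ** (bits - 1) - 1           # e.g. 127 for 8 bits
    amax = max(abs(v) for v in x) or 1.0
    scale = amax / qmax
    q = [round(v / scale) for v in x]    # integers in [-qmax, qmax]
    return [qi * scale for qi in q]

x = [0.013, -0.4, 0.92, -1.5]
x8 = quantize_dequantize(x, bits=8)
err = max(abs(a - b) for a, b in zip(x, x8))  # bounded by scale / 2
```

Round-to-nearest keeps the per-element error within half a quantization step, which is why low-precision formats remain usable for training when scales are chosen well.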

Normalization

Normalization is the process of adjusting data scales to ensure they fall within a certain range, improving model training stability and efficiency.

In this paper, MXNorm is used for efficient tensor normalization.

Llama 3 Model

Llama 3 is a large-scale language model used for natural language processing tasks.

In this paper, the Llama 3 model is used to validate the performance of MXNorm.

torch.compile

torch.compile is a compilation tool in PyTorch used to optimize model computational efficiency.

In this paper, torch.compile is used to implement kernel acceleration for MXNorm.

Kernel Acceleration

Kernel acceleration is a method of optimizing computational processes to improve speed and efficiency.

In this paper, MXNorm demonstrates faster computational speed than RMSNorm through kernel acceleration.

Large-scale Language Model

A large-scale language model is a deep learning model used for natural language tasks, typically with a large number of parameters and complex structures.

In this paper, Llama 3 is a large-scale language model.

Root Mean Square (RMS)

Root mean square is a statistical measure of data, representing the square root of the average of squared values.

In this paper, RMS is used for tensor normalization.
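In symbols, for a vector $x \in \mathbb{R}^n$:

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}
```

Computing this exactly requires a reduction over all $n$ elements; approximating it from the 32-element block scales is the cost MXNorm eliminates.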

Open Questions (unanswered questions from this research)

  1. MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes. Future research needs to explore ways to improve MXNorm's stability at higher learning rates.
  2. On larger models, MXNorm shows slightly higher training loss than RMSNorm, indicating potential accuracy loss in certain scenarios. Further research is needed to improve MXNorm's performance without accuracy loss.
  3. MXNorm's performance depends on specific hardware and compiler optimizations, which may vary across platforms. Future research needs to explore ways to enhance MXNorm's generality and adaptability.
  4. Current research focuses primarily on pre-training large language models, and MXNorm's application in other types of neural network models has yet to be fully validated.
  5. The effectiveness of MXNorm in practical applications, especially across different hardware platforms and application scenarios, requires further experimental validation.

Applications

Immediate Applications

Pre-training Large Language Models

MXNorm can significantly improve the training speed of large language models without substantial accuracy loss, suitable for scenarios requiring efficient normalization.

Natural Language Processing

In natural language processing tasks, MXNorm can be used to accelerate model training and inference, improving computational efficiency.

Computer Vision

In computer vision tasks, MXNorm can reduce computational overhead, improving model real-time performance and response speed.

Long-term Vision

General AI Accelerators

The successful application of MXNorm may drive the development of general AI accelerators, supporting a wider range of deep learning models and application scenarios.

Cross-platform Optimization

In the future, MXNorm may be optimized across different hardware platforms, promoting the development of cross-platform deep learning applications.

Abstract

Matrix multiplication performance has long been the major bottleneck to scaling deep learning workloads, which has stimulated the design of new accelerators that use increasingly low-precision number formats. However, improvements in matrix multiplication performance have far outstripped improvements in performance on reductions and elementwise computations, which are still being performed in higher precision. In this work, we propose MXNorm, a drop-in replacement for RMSNorm that estimates the RMS using only the block scales calculated as part of the MXFP8 cast and enables a 32x decrease in the size of reduction needed for normalization. We validate our approximation method on pre-training of Llama 3 models of 125M, 1B and 8B parameters, finding minimal loss of training accuracy compared to a baseline using RMSNorm with MXFP8 matmuls. We also show practical kernel speedups using only torch.compile of up to 2.4x for MXNorm over RMSNorm, corresponding to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.

cs.LG cs.AI cs.NE
