MXNorm: Reusing MXFP block scales for efficient tensor normalisation
MXNorm reuses MXFP8 block scales for efficient tensor normalization, reducing reduction size by 32x.
Key Findings
Methodology
This paper introduces MXNorm, a tensor normalization method that estimates the RMS by reusing the block scales calculated during MXFP8 quantization. MXNorm serves as a drop-in replacement for RMSNorm, reducing the reduction size required for normalization by 32x. The method was validated on pre-training of Llama 3 models with 125M, 1B, and 8B parameters, showing minimal loss in training accuracy compared to a baseline using RMSNorm with MXFP8 matrix multiplications. Additionally, practical kernel speedups using only torch.compile showed MXNorm to be up to 2.4x faster than RMSNorm.
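The core idea above can be sketched numerically. The snippet below is an illustrative reading, not the paper's implementation: it computes simplified power-of-two per-block scales (as in the OCP MX spec) and shows that a normalization statistic can then be estimated from the 32x-smaller array of scales; the names `mx_block_scales` and `rms_from_scales` and the exact estimator are assumptions, and a real estimator would include a calibration constant.

```python
import numpy as np

BLOCK = 32  # MXFP8 shares one scale per 32 elements


def mx_block_scales(x):
    """Simplified per-block power-of-two scales: each 32-element
    block gets the scale 2^floor(log2(amax))."""
    amax = np.abs(x.reshape(-1, BLOCK)).max(axis=1)
    return 2.0 ** np.floor(np.log2(amax))


def rms_exact(x):
    # RMSNorm's reduction runs over all len(x) elements.
    return np.sqrt(np.mean(x * x))


def rms_from_scales(scales, p=2):
    # Hypothetical MXNorm-style estimator: a p-norm-style mean of
    # the block scales, up to a distribution-dependent constant.
    return np.mean(scales ** p) ** (1.0 / p)


rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)
scales = mx_block_scales(x)
# The normalization reduction now runs over 4096 / 32 = 128 scales.
```

The point of the sketch is the reduction size: `rms_from_scales` reduces over 128 values where `rms_exact` reduces over 4096.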
Key Results
- In pre-training Llama 3 models, MXNorm showed minimal training accuracy loss on 125M, 1B, and 8B parameter models compared to RMSNorm, indicating its effectiveness as a replacement.
- Using torch.compile for kernel acceleration, MXNorm was up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
- On the 8B parameter model, MXNorm(p=2) achieved a training loss comparable to RMSNorm, with losses of 2.126 and 2.132 respectively, while MXNorm(p=1) had a worse final loss of 2.175.
Significance
The introduction of MXNorm is significant for both academia and industry as it addresses the performance bottleneck in reduction and elementwise computations, which lag behind the improvements in low-precision matrix multiplication. By reducing the reduction size requirement, MXNorm decreases computational overhead, enhancing training efficiency and speed. This method is particularly beneficial for pre-training large language models, allowing for significant speed improvements without substantial accuracy loss.
Technical Contribution
The technical contribution of MXNorm lies in its innovative fusion of normalization with the MX quantization process, reducing redundant computation overhead. Compared to existing RMSNorm methods, MXNorm estimates RMS by reusing MXFP8 block scales, reducing the reduction size requirement by 32 times. Furthermore, MXNorm demonstrates significant speed improvements in practical kernel acceleration, showcasing its potential in engineering applications.
Novelty
The novelty of MXNorm is in its first-time integration of normalization with the MX quantization process, reusing MXFP8 block scales to estimate RMS. This method significantly reduces computational overhead compared to traditional RMSNorm methods, enhancing training efficiency.
Limitations
- MXNorm shows slightly higher training loss than RMSNorm on larger models (e.g., 8B parameters), indicating potential accuracy loss in certain scenarios.
- MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes.
- MXNorm's performance depends on specific hardware and compiler optimizations, which may vary across platforms.
Future Work
Future research directions include further optimizing MXNorm's stability, especially at higher learning rates. Additionally, exploring MXNorm's application in other types of neural network models and its performance across different hardware platforms is crucial. Research on enhancing MXNorm's generality and adaptability without relying on specific hardware optimizations is also a promising direction.
AI Executive Summary
In the rapid development of deep learning, the improvement of matrix multiplication performance has been a key driver for scaling large models. However, while low-precision matrix multiplication has seen significant acceleration, the performance of reductions and elementwise computations has lagged, becoming a new bottleneck. To address this issue, this paper introduces a novel tensor normalization method—MXNorm. MXNorm achieves efficient normalization by reusing block scales calculated during MXFP8 quantization, reducing the reduction size requirement by 32 times.
MXNorm serves as a replacement for RMSNorm and was validated on pre-training Llama 3 models with 125M, 1B, and 8B parameters. Experimental results showed minimal training accuracy loss compared to a baseline using RMSNorm with MXFP8 matrix multiplications. Additionally, practical kernel speedups using torch.compile demonstrated MXNorm to be up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
The core technical principle of MXNorm lies in its integration of normalization with the MX quantization process, reusing MXFP8 block scales to estimate RMS. This innovative approach significantly reduces redundant computation overhead, enhancing training efficiency and speed. By reducing the reduction size requirement, MXNorm decreases computational overhead, making it particularly suitable for pre-training large language models, allowing for significant speed improvements without substantial accuracy loss.
However, MXNorm also has some limitations. It shows slightly higher training loss than RMSNorm on larger models (e.g., 8B parameters), indicating potential accuracy loss in certain scenarios. Additionally, MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes. These issues need to be addressed in future research.
Overall, the introduction of MXNorm is significant for both academia and industry as it addresses the performance bottleneck in reduction and elementwise computations, which lag behind the improvements in low-precision matrix multiplication. Future research directions include further optimizing MXNorm's stability, exploring its application in other types of neural network models, and its performance across different hardware platforms.
Deep Analysis
Background
In recent years, deep learning has made significant progress in natural language processing, computer vision, and scientific fields such as molecular biology. This progress has been enabled by leaps in AI accelerator capabilities, particularly the acceleration of low-precision matrix multiplication: over the past eight years, GPU throughput for low-precision matrix multiplication has improved by 80x, allowing researchers and practitioners to scale pre-training of transformer-style neural networks and learn from vast pools of unlabelled data. However, as matrix multiplication becomes less of a bottleneck to throughput, other components of model architectures emerge as new bottlenecks, because other aspects of AI accelerators have not kept pace. For example, elementwise operations and reductions are limited by memory bandwidth and CUDA core throughput on GPUs, which have improved by only 8.9x and 5.1x, respectively, over the same period, and this gap is set to widen with upcoming GPU architecture releases.
Core Problem
While the improvement in matrix multiplication performance has addressed the bottleneck in scaling deep learning workloads, the performance of reductions and elementwise computations has lagged, becoming a new bottleneck. These operations are limited by memory bandwidth and CUDA core throughput, which have only improved by factors of 8.9x and 5.1x, respectively, over the past eight years. Furthermore, this gap is set to widen with upcoming GPU architecture releases. In some cases, these operations can be hidden by overlapping them with matrix multiplications, but others require too much memory to pipeline in practice. Therefore, the community needs to consider new building blocks that contribute a smaller overhead.
Innovation
The core innovation of MXNorm lies in its integration of normalization with the MX quantization process, reusing MXFP8 block scales to estimate RMS. This method significantly reduces redundant computation overhead, enhancing training efficiency and speed. Specifically:
- MXNorm estimates RMS by reusing block scales calculated during MXFP8 quantization, achieving efficient normalization and reducing the reduction size requirement by 32 times.
- MXNorm serves as a replacement for RMSNorm and was validated on pre-training Llama 3 models, showing minimal training accuracy loss compared to a baseline using RMSNorm with MXFP8 matrix multiplications.
- Practical kernel speedups using torch.compile demonstrated MXNorm to be up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
Methodology
The implementation of MXNorm involves the following steps:
- First, MXNorm estimates RMS by reusing block scales calculated during MXFP8 quantization, reducing the reduction size requirement by 32 times.
- Then, MXNorm was validated on pre-training Llama 3 models with 125M, 1B, and 8B parameters. Experimental results showed minimal training accuracy loss compared to a baseline using RMSNorm with MXFP8 matrix multiplications.
- Finally, practical kernel speedups using torch.compile demonstrated MXNorm to be up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
Experiments
The experimental design involved validating the performance of MXNorm on pre-training Llama 3 models with 125M, 1B, and 8B parameters. The baseline used was RMSNorm with MXFP8 matrix multiplications. The experiments utilized torch.compile to implement kernel acceleration for MXNorm and were tested on actual hardware. Additionally, a learning rate sensitivity test was conducted to evaluate the stability and performance of MXNorm under different learning rates. The results showed minimal training accuracy loss and significant speed improvements in kernel acceleration.
Results
The experimental results demonstrated that MXNorm performed excellently in pre-training Llama 3 models. On 125M, 1B, and 8B parameter models, MXNorm showed minimal training accuracy loss compared to RMSNorm. Additionally, using torch.compile for kernel acceleration, MXNorm was up to 2.4x faster than RMSNorm, translating to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4. On the 8B parameter model, MXNorm(p=2) achieved a training loss comparable to RMSNorm, with losses of 2.126 and 2.132 respectively, while MXNorm(p=1) had a worse final loss of 2.175.
Applications
MXNorm's application scenarios include pre-training large language models, particularly in situations requiring efficient normalization and reduced computational overhead. MXNorm allows for significant speed improvements without substantial accuracy loss. Additionally, MXNorm demonstrated excellent performance in practical kernel acceleration, making it suitable for high-performance computing scenarios such as natural language processing and computer vision.
Limitations & Outlook
While MXNorm performed excellently in experiments, it also has some limitations. On larger models (e.g., 8B parameters), MXNorm showed slightly higher training loss than RMSNorm, indicating potential accuracy loss in certain scenarios. Additionally, MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes. These issues need to be addressed in future research.
Plain Language: Accessible to non-experts
Imagine you're in a kitchen cooking. Traditionally, after each dish, you'd wash all the pots and pans, much like RMSNorm, which recomputes a full reduction over the whole tensor every time. MXNorm, however, is like a clever chef who reuses already cleaned tools during the cooking process, like using the same pot for different dishes. This saves time and reduces the number of washes. Similarly, MXNorm reduces computational overhead by reusing the block scales calculated during MXFP8 quantization to estimate the RMS, so normalization doesn't have to start from scratch each time. This is particularly suitable for scenarios requiring quick processing of large amounts of data, like training large language models.
ELI14: Explained like you're 14
Here's a story about supercomputers. Imagine you're working on a huge jigsaw puzzle where each piece is tiny, and you need to fit them together to see the full picture. The traditional way is like re-sorting all the puzzle pieces every single time you sit down, which is slow, right?
Now, there's a super-smart helper called MXNorm with a brilliant idea! It found a way to reuse some already arranged puzzle pieces, so you don't have to start from scratch every time. It's like using your organized notes from last time to finish your homework faster.
This MXNorm helper is especially good at handling large amounts of data, just like supercomputers when dealing with complex language models, allowing them to complete tasks faster without recalculating all the data each time.
So, next time you're playing a puzzle game or doing homework, remember to try this smart method! It'll help you get things done faster and easier!
Glossary
MXNorm
MXNorm is a novel tensor normalization method that estimates RMS by reusing block scales calculated during MXFP8 quantization, achieving efficient normalization.
In this paper, MXNorm is used as a replacement for traditional RMSNorm to reduce computational overhead.
RMSNorm
RMSNorm is a normalization method that normalizes tensors by calculating their root mean square (RMS).
In the pre-training of Llama 3 models, RMSNorm is replaced by MXNorm.
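For reference, RMSNorm over the last axis can be written in a few lines. This is a generic sketch (numpy rather than PyTorch, and not the paper's kernel); the `eps` term for numerical stability is a common convention assumed here.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm: divide by the root mean square over the
    last axis, then apply a learned elementwise gain."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight


x = np.array([[1.0, -2.0, 3.0, -4.0]])
y = rms_norm(x, weight=np.ones(4))
# After normalization (with unit gain), each row has RMS close to 1.
```

Note that the `np.mean(x * x, ...)` reduction runs over the full hidden dimension; this is exactly the reduction MXNorm shrinks by 32x.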
MXFP8
MXFP8 is a low-precision quantization format used to accelerate matrix multiplication.
In this paper, MXFP8 is used to quantize tensors to improve computational efficiency.
Quantization
Quantization is the process of converting high-precision data into a low-precision format to reduce computational and storage overhead.
In this paper, MXFP8 quantization is used to accelerate matrix multiplication.
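A minimal sketch of block quantization, under simplifying assumptions: one scalar scale per block chosen from the block's amax, with integer rounding standing in for true FP8 rounding (real MXFP8 uses power-of-two shared scales and E4M3 rounding). The function names are illustrative, not from the paper.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite value in the E4M3 format


def quantize_block(x):
    # Choose a scale so the block's amax lands at the top of the
    # FP8 range, then round. Integer rounding is a simplified
    # stand-in for FP8 rounding.
    scale = np.abs(x).max() / FP8_E4M3_MAX
    q = np.round(x / scale)
    return q, scale


def dequantize_block(q, scale):
    return q * scale


x = np.linspace(-1.0, 1.0, 32)
q, scale = quantize_block(x)
x_hat = dequantize_block(q, scale)
# Round-trip error is bounded by half a quantization step (scale/2).
```

The key property is that the quantized values stay inside the representable FP8 range while the scale carries the block's magnitude, and this per-block magnitude is exactly what MXNorm reuses.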
Normalization
Normalization is the process of adjusting data scales to ensure they fall within a certain range, improving model training stability and efficiency.
In this paper, MXNorm is used for efficient tensor normalization.
Llama 3 Model
Llama 3 is a large-scale language model used for natural language processing tasks.
In this paper, the Llama 3 model is used to validate the performance of MXNorm.
torch.compile
torch.compile is a compilation tool in PyTorch used to optimize model computational efficiency.
In this paper, torch.compile is used to implement kernel acceleration for MXNorm.
Kernel Acceleration
Kernel acceleration is a method of optimizing computational processes to improve speed and efficiency.
In this paper, MXNorm demonstrates faster computational speed than RMSNorm through kernel acceleration.
Large-scale Language Model
A large-scale language model is a deep learning model used for natural language tasks, typically with a large number of parameters and complex structures.
In this paper, Llama 3 is a large-scale language model.
Root Mean Square (RMS)
Root mean square is a statistical measure of data, representing the square root of the average of squared values.
In this paper, RMS is used for tensor normalization.
Open Questions: Unanswered questions from this research
1. MXNorm's stability is poorer at higher learning rates, with a tendency to exhibit loss spikes. Future research needs to explore ways to improve MXNorm's stability at higher learning rates.
2. On larger models, MXNorm shows slightly higher training loss than RMSNorm, indicating potential accuracy loss in certain scenarios. Further research is needed to improve MXNorm's performance without accuracy loss.
3. MXNorm's performance depends on specific hardware and compiler optimizations, which may vary across platforms. Future research needs to explore ways to enhance MXNorm's generality and adaptability.
4. Current research focuses primarily on pre-training large language models, and MXNorm's application in other types of neural network models has yet to be fully validated.
5. The effectiveness of MXNorm in practical applications, especially across different hardware platforms and application scenarios, requires further experimental validation.
Applications
Immediate Applications
Pre-training Large Language Models
MXNorm can significantly improve the training speed of large language models without substantial accuracy loss, suitable for scenarios requiring efficient normalization.
Natural Language Processing
In natural language processing tasks, MXNorm can be used to accelerate model training and inference, improving computational efficiency.
Computer Vision
In computer vision tasks, MXNorm can reduce computational overhead, improving model real-time performance and response speed.
Long-term Vision
General AI Accelerators
The successful application of MXNorm may drive the development of general AI accelerators, supporting a wider range of deep learning models and application scenarios.
Cross-platform Optimization
In the future, MXNorm may be optimized across different hardware platforms, promoting the development of cross-platform deep learning applications.
Abstract
Matrix multiplication performance has long been the major bottleneck to scaling deep learning workloads, which has stimulated the design of new accelerators that use increasingly low-precision number formats. However, improvements in matrix multiplication performance have far outstripped improvements in performance on reductions and elementwise computations, which are still being performed in higher precision. In this work, we propose MXNorm, a drop-in replacement for RMSNorm that estimates the RMS using only the block scales calculated as part of the MXFP8 cast and enables a 32x decrease in the size of reduction needed for normalization. We validate our approximation method on pre-training of Llama 3 models of 125M, 1B and 8B parameters, finding minimal loss of training accuracy compared to a baseline using RMSNorm with MXFP8 matmuls. We also show practical kernel speedups using only torch.compile of up to 2.4x for MXNorm over RMSNorm, corresponding to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.