Surprising Truths About GPU Memory for Large Models: How Much Do You Really Need?
Learn how to accurately estimate the GPU memory needed to train large models like Llama-6B, and how the requirement changes with fp32, fp16, and int8 precision.
This article explains how to estimate the required GPU memory from the model's parameter count, the training precision and optimizer, and the batch size.
Suppose you want to fully train a Llama-6B model. How much GPU memory do you need?
We'll also explore how memory requirements change under fp32, fp16, and int8 modes.
Memory Composition for Large Models
GPU memory usage for large models has three components: the model itself, the CUDA kernels, and the batch (the intermediate variables, which grow with batch size).
The Model Itself
The model's memory needs can be divided into three areas: model parameters, gradients, and optimizer parameters.
Model Parameters
Memory required = number of parameters * memory per parameter.
Consider the impact of precision on memory (a short sketch after this list turns these byte sizes into numbers):
fp32 precision: 32 bits per parameter, 4 bytes.
fp16 precision: 16 bits per parameter, 2 bytes.
int8 precision: 8 bits per parameter, 1 byte.
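As a quick sanity check, here is a minimal Python sketch of this rule; the byte sizes are the ones listed above, and the 6B parameter count matches the running example:

# Bytes per parameter for each precision listed above
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "int8": 1}

def parameter_memory_gb(num_params, precision):
    # number of parameters * memory per parameter, reported in GB (1 GB = 1e9 bytes)
    return num_params * BYTES_PER_PARAM[precision] / 1e9

print(parameter_memory_gb(6e9, "fp32"))  # 24.0 GB
print(parameter_memory_gb(6e9, "fp16"))  # 12.0 GB
print(parameter_memory_gb(6e9, "int8"))  # 6.0 GB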
Gradients
Memory required = number of parameters * memory per gradient.
One gradient is stored per trainable parameter, typically at the same precision as the parameters, so this matches the parameter memory.
Optimizer Parameters
The amount of memory depends on the optimizer. AdamW needs twice the parameter memory, because it stores two states (the first and second moments) for every parameter.
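Putting the three pieces together, the model-side memory can be sketched as below. This keeps the article's simplifying assumption that gradients and AdamW states use the same number of bytes per value as the parameters (in practice they are often kept in higher precision):

def model_state_memory_gb(num_params, bytes_per_value):
    params = num_params * bytes_per_value            # model parameters
    grads = num_params * bytes_per_value             # one gradient per parameter
    adamw_states = 2 * num_params * bytes_per_value  # first and second moments
    return (params + grads + adamw_states) / 1e9

print(model_state_memory_gb(6e9, 1))  # 24.0 GB for a 6B model at int8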
CUDA Kernel
Loading the CUDA kernels (the CUDA context) takes around 1.3GB of GPU memory, as shown below:
import torch
torch.ones((1, 1)).to("cuda")
print_gpu_utilization()
>>> GPU memory occupied: 1343 MB
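print_gpu_utilization is not a built-in; a minimal version (similar to the helper used in the Hugging Face performance docs) can be written with the pynvml package, assuming a single GPU at index 0:

# pip install nvidia-ml-py3
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

def print_gpu_utilization():
    # Report the memory currently used on GPU 0, as seen by the NVIDIA driver
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used // 1024**2} MB")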
Batch
First, estimate the memory used by the intermediate variables (activations) of a single instance, then scale by the batch size:
Memory = number of intermediate values per instance * memory per value * batch size.
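A minimal sketch of this rule, with the per-instance count of intermediate values left as an input (how to estimate that count for Llama is shown in the next section):

def activation_memory_gb(values_per_instance, bytes_per_value, batch_size):
    # intermediate values per instance * memory per value * batch size
    return values_per_instance * bytes_per_value * batch_size / 1e9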
GPU Memory Calculation Example
Let's calculate the memory needed for a Llama-6B model with a batch size of 50 and int8 precision.
The Model Itself
Model parameters: LLaMA-6B with int8 requires 6B * 1 byte = 6GB.
Gradients: Also 6GB.
Optimizer parameters: AdamW for int8 LLaMA-6B requires 6B * 1 byte * 2 = 12GB.
CUDA kernel: 1.3GB.
Total for the model: 6GB + 6GB + 12GB + 1.3GB = 25.3GB (reproduced in the snippet below).
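Continuing from the model_state_memory_gb sketch above and adding the measured CUDA-context overhead:

model_side_gb = model_state_memory_gb(6e9, 1) + 1.3  # params + grads + AdamW states + CUDA context
print(model_side_gb)  # 25.3 GB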
Batch
LLaMA architecture:
hidden_size = 4096
intermediate_size = 11008
num_hidden_layers = 32
context_length = 2048
For each instance:
Memory = (hidden_size + intermediate_size) * context_length * num_hidden_layers * 1 byte = (4096 + 11008) * 2048 * 32 bytes ≈ 990MB (about 0.99GB).
For a batch size of 50:
Memory ≈ 0.99GB * 50 ≈ 49.5GB (reproduced in the sketch below).
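Continuing from the activation_memory_gb sketch in the batch section, plugging in the Llama-6B architecture numbers:

hidden_size, intermediate_size = 4096, 11008
num_hidden_layers, context_length = 32, 2048

values_per_instance = (hidden_size + intermediate_size) * context_length * num_hidden_layers
print(activation_memory_gb(values_per_instance, 1, 1))   # ≈ 0.99 GB per instance
print(activation_memory_gb(values_per_instance, 1, 50))  # ≈ 49.5 GB for a batch of 50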
Total memory required for Llama-6B with int8 precision and a batch size of 50:
25.3GB + 49.5GB ≈ 74.8GB.
This just fits within an A100 GPU with 80GB of memory, allowing full-parameter fine-tuning of Llama-6B with a batch size of 50 at int8 precision.
You can apply the same calculation to other scenarios by changing the precision, model size, intermediate-variable count, and batch size. The sketch below wraps the whole estimate into a single function you can re-run for other configurations.
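As a closing sketch (same simplifying assumptions as above, with the 1.3GB CUDA context treated as a constant), here is a self-contained estimator:

def estimate_total_gb(num_params, bytes_per_value, hidden_size, intermediate_size,
                      num_layers, context_length, batch_size, cuda_context_gb=1.3):
    model_states = 4 * num_params * bytes_per_value  # params + grads + 2 AdamW moments
    activations = ((hidden_size + intermediate_size) * context_length * num_layers
                   * bytes_per_value * batch_size)
    return (model_states + activations) / 1e9 + cuda_context_gb

# Llama-6B, int8, batch size 50
print(estimate_total_gb(6e9, 1, 4096, 11008, 32, 2048, 50))  # ≈ 74.8 GB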