Enabling 70B Finetuning on Consumer GPUs

A Technical Deep Dive into FSDP+QLoRA

A detailed guide for adding FSDP and QLoRA support to quantization libraries and training frameworks.
Authors

Benjamin Warner

Johno Whitaker

Kerem Turgutlu

Published

March 14, 2024

Introduction

Answer.AI recently announced FSDP+QLoRA, an open-source project enabling the finetuning of models up to 70 billion parameters on two high-end consumer GPUs. This companion post provides a detailed look into the implementation of FSDP+QLoRA and is relevant if:

  • You’d like to add FSDP support to a new quantization library
  • You want to add FSDP+QLoRA training to your own training framework or script
  • You’re curious about the details of how FSDP+QLoRA works

We’ve split this post into two sections. The first covers how to update a new quantization library for FSDP and the second shows how to integrate FSDP+QLoRA finetuning into a training framework or script.

The code examples within this post have been simplified for legibility. For the full code please reference our FSDP & QLoRA example script and our updates to bitsandbytes and HQQ.

Integrate a New Quantization Library

The primary blocker for enabling a new quantization method to work with FSDP+QLoRA is that most libraries store quantized weights in integer data types, while FSDP only supports float data types. Ideally, the quantized weight data type will match the non-quantized layers' data type for easier and more efficient FSDP model sharding. Furthermore, the float dtype should be user-specifiable so that Mixed Precision training doesn't convert quantized weights stored in float32 to bfloat16, effectively randomizing the weights.

There are at least two methods quantization maintainers can use to store quantized weights in float datatypes for FSDP1:

  • Use torch.view to convert to and from float dtypes without modifying the underlying bytes.
  • Read quantized weights in a data-type-agnostic manner in the dequantization kernel

Of these two methods, HQQ uses the first, while bitsandbytes uses the second.
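
To make the first method concrete, here is a minimal sketch (not library code) showing how torch.view reinterprets packed integer weights as a float dtype and back without modifying the underlying bytes:

import torch

# Packed quantized weights, e.g. two 4-bit values per uint8 byte.
packed = torch.randint(0, 256, (1024, 64), dtype=torch.uint8)

# Reinterpret the same bytes as bfloat16 so FSDP can shard them like any other
# float parameter. The last dimension must be divisible by the element-size
# ratio (2 bytes per bfloat16 element), and no bytes are modified.
as_float = packed.view(torch.bfloat16)   # shape becomes (1024, 32)

# Before dequantization, view the bytes back as the original integer dtype.
restored = as_float.view(torch.uint8)
assert torch.equal(packed, restored)

Because only the dtype metadata changes, the round trip is lossless, which is what allows FSDP to shard and gather the quantized weights as if they were ordinary float parameters.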

HQQ

The example code below shows how HQQ’s HQQLinear.quantize and HQQLinear.dequantize methods have been modified to support FSDP training by viewing int dtype quantized weights as a selectable float dtype when quantizing, and then converting back to an int dtype before being passed to the HQQ dequantization kernel.

@classmethod
def quantize(
    cls,
    *args,
    compute_dtype:torch.dtype=None,
    view_as_float:bool=False
):
    # ... quantization elided: produces the packed integer weights W_q and the metadata dict meta ...
    meta['view_as_float'] = view_as_float
    W_q = Quantizer.pack[meta['packing']](W_q)
    if view_as_float:
        # store quantized weights as compute_dtype
        W_q = W_q.view(torch.float32 if compute_dtype is None else compute_dtype)

@classmethod
def dequantize(cls, W_q:torch.Tensor, meta:dict):
    compute_dtype = meta['compute_dtype'] if ('compute_dtype' in meta) else torch.float16
    if meta['view_as_float']:
        # view the float-stored weights back as their original integer dtype before dequantizing
        W_q = W_q.view(meta['unpack_view_dtype'])
    # ... dequantization kernel call elided ...

Currently, HQQ stores the quantized weights in the same dtype as the compute dtype. Splitting the compute_dtype from a separate quantized weight storage dtype quant_storage, as bitsandbytes does, would enable all of FSDP's Mixed Precision options, but the current implementation is already usable with FSDP.
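
For reference, enabling the float view when quantizing a layer with HQQ looks roughly like the sketch below. The argument names (view_as_float, compute_dtype, del_orig) reflect our understanding of the HQQ API at the time of writing and may differ in later releases; linear_layer stands in for an existing layer to be quantized.

import torch
from torch import nn
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

# Store the packed quantized weights viewed as a float dtype so FSDP can shard them.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, view_as_float=True)

# linear_layer is a placeholder for an existing layer to be quantized.
linear_layer = nn.Linear(4096, 4096, bias=False)
hqq_layer = HQQLinear(linear_layer, quant_config, compute_dtype=torch.bfloat16, del_orig=True)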

For more details and the full set of changes made to HQQ, see the pull request by Kerem and the mobius.ml folks.

bitsandbytes

bitsandbytes had already implemented its quantization and dequantization kernels to read and write raw bytes using LoadChar and StoreChar.

__global__ void kQuantizeBlockwise(unsigned char *out, ...)
{
    unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
    __shared__ typename StoreChar::TempStorage storec;
    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
    {
        // store 8-bit, FP4, or NF4 quantized qvals in storec as bytes or packed bytes
        StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
    }
}


__global__ void kDequantizeBlockwise(T *out, ...)
{
    unsigned char qvals[NUM_PER_TH];
    T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)]; // dequantized output values
    __shared__ typename LoadChar::TempStorage loadchar;
    __shared__ typename StoreT::TempStorage storet;

    for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
    {
        // load bytes or packed bytes into qvals
        LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);

        // after dequantizing qvals, store result in the PyTorch Tensor out
        StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
    }
}

This data-type-agnostic implementation allowed us to modify bitsandbytes' quantize_4bit and dequantize_4bit to have a settable quant_storage dtype, independent of the computation data type compute_dtype.

def quantize_4bit(A: Tensor, *args, quant_storage=torch.uint8) -> Tensor:
    # create an out Tensor with the quantized shape of dtype=quant_storage
    n = A.numel()
    mod = dtype2bytes[quant_storage] * 2
    out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)

    # quantize the input tensor A using the blockwise quantization CUDA kernel
    lib.cquantize_blockwise_bf16_nf4(get_ptr(A), get_ptr(out), ct.c_int(n))

def dequantize_4bit(A: Tensor, *args, out: Tensor = None) -> Tensor:
    if out is None:
        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

    n = out.numel()

    lib.cdequantize_blockwise_bf16_nf4(get_ptr(A), get_ptr(out), ct.c_int(n))
    return out

Then we added quant_storage as an argument to Linear4bit and Params4bit, defaulting to torch.uint8 for compatibility with existing codebases. The settable quant_storage supports all of FSDP's Mixed Precision options.
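
As a usage sketch (the layer sizes are arbitrary and the keyword arguments follow the bitsandbytes API described above), constructing a Linear4bit layer with a float storage dtype looks roughly like this:

import torch
import bitsandbytes as bnb

# Quantized weights are packed NF4 values, but stored (viewed) as bfloat16 so
# FSDP can shard them together with the rest of a bfloat16 model.
layer = bnb.nn.Linear4bit(
    4096, 4096,
    bias=False,
    compute_dtype=torch.bfloat16,     # dtype used for the dequantized matmul
    quant_type="nf4",
    quant_storage=torch.bfloat16,     # dtype the packed quantized weights are viewed as
)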

Additional Changes

bitsandbytes needed two additional modifications for FSDP compatibility not needed by HQQ.

First, bitsandbytes stored its quantization metadata dictionary in Params4bit, a subclass of PyTorch's Parameter. Unlike a regular PyTorch .to call, which would preserve this information, FSDP creates sharded instances of the model's parameters and buffers. The sharding process doesn't persist or copy the quantization metadata, which would prevent QLoRA training or inference.

We resolved this blocker by reusing an in-progress solution from Tim Dettmers et al.: creating a copy of the quantization metadata on the Linear4bit layer and passing it to the Params4bit during the layer's forward pass.

Second, bitsandbytes automatically quantized weights whenever a bitsandbytes layer was moved to the GPU. It didn't check whether the layer was already quantized, which meant training with FSDP risked quantizing already quantized weights, as layers are constantly moved between CPU and GPU. We resolved this potential issue by adding a quantization Boolean to prevent already quantized weights from being quantized again.
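
The sketch below illustrates the idea of that guard rather than reproducing bitsandbytes' exact implementation; already_quantized is a placeholder attribute name.

import torch
import bitsandbytes as bnb

# Illustrative sketch only: a parameter subclass that quantizes on its first
# move to the GPU and skips quantization on every subsequent move.
class Params4bitSketch(torch.nn.Parameter):
    def cuda(self, device="cuda"):
        if getattr(self, "already_quantized", False):
            # FSDP repeatedly moves layers between CPU and GPU; don't re-quantize.
            return self
        # First move to the GPU: quantize the weights and keep the metadata.
        quantized, self.quant_state = bnb.functional.quantize_4bit(self.data.to(device))
        self.data = quantized
        self.already_quantized = True
        return self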

For more details and the full set of changes made to bitsandbytes, see the pull request by Benjamin and Kerem.

Integrate a New Trainer

When integrating a supported quantization method into a training framework or script, there are two main factors to consider:

  • Loading and quantizing a model to use minimal CPU memory and create quantization metadata for all devices.
  • Correctly initializing FSDP wrapping policies, mixed precision, and sharding settings for QLoRA finetuning.

This section assumes you are familiar with the basics of setting up and training models with FSDP2.

Loading and Quantizing a Model

The default method of initializing and loading a model for FSDP requires one model per GPU (or rank). Loading a 70B model in BFloat16 on a two GPU system would require roughly 280GB of CPU or GPU RAM just to initialize the model before FSDP sharding3.

However, FSDP also has a low memory option, sync_module_states=True, which only requires the model weights to be initialized or loaded on the rank 0 process. FSDP then shards the model across all GPU ranks. This, combined with quantization, reduces the model weights to ~35GB for a single node, no matter how many ranks4.
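
The back-of-the-envelope arithmetic behind these numbers, ignoring the sharding overhead discussed in the footnotes, is simply:

params = 70e9                           # 70 billion parameters
bf16_copy_gb = params * 2 / 1e9         # ~140GB per BFloat16 copy of the model
default_two_gpu_gb = 2 * bf16_copy_gb   # ~280GB if both ranks load a full copy
nf4_copy_gb = params * 0.5 / 1e9        # ~35GB once quantized to 4 bits per weight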

This section will highlight how to load and quantize a model using FSDP’s low memory option, so the quantization metadata is placed on all devices. While we only show how to quantize unquantized models, a similar process can be used to load pre-quantized models.

To simplify the sample code, we will only show how to load bitsandbytes models using Hugging Face’s Transformers library, but the process is the same for other model libraries and quantization methods which support FSDP such as HQQ.

Initializing the Model

While FSDP only needs a model with weights on the rank 0 process, it requires an instance of the model on all processes. To prevent these other instances from using memory, we will initialize the models on all ranks using PyTorch’s meta device and then load the model weights into the rank 0 process.

Depending on the model implementation, it might be possible to use the meta device initialization

with torch.device('meta'):
    model = Model()

but Hugging Face models create non-persistent buffers on initialization5. So instead we will use Accelerate’s init_empty_weights, which by default creates the buffers and initializes the model’s parameters on the meta device6.

We initialize the model and immediately replace PyTorch's nn.Linear with bitsandbytes' Linear4bit layer, which implements bitsandbytes' four-bit quantization.

cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False

with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(cfg)
    model.model = replace_linear(model.model, Linear4bit, compute_dtype=compute_dtype,
                                 quant_type='nf4', quant_storage=storage_dtype)
model.is_loaded_in_4bit = True

One important argument to note is the new quant_storage. As mentioned in the quantization section, FSDP cannot shard integer data types, and by default bitsandbytes (and most other quantization libraries) store quantized weights in uint8. quant_storage allows users to set the storage type for the quantized weights to a float data type, ideally the same data type as the rest of the model for optimal FSDP sharding. More on this in the wrapping policy and mixed precision policy sections.

Loading the Model Weights

With the model initialized across all ranks, the next step is to load the pretrained weights and quantize them. Normally for low memory FSDP we’d only load the model on the rank 0 process, but quantization presents a wrinkle.

If the model is only quantized on rank 0, the quantization metadata will only exist on rank 0 and will not be synced, since FSDP cannot sync dictionaries. Furthermore, after quantizing the model on rank 0, the other ranks' parameters will still have their original non-quantized shape, while the rank 0 model parameters will be half-size or smaller due to quantization, preventing FSDP from sharding the model.

To resolve this issue, we will load and quantize the model across all ranks layer by layer, but on ranks 1+ we will immediately discard the newly quantized layer’s weights, leaving us with a fully loaded model on the rank 0 process and the correctly shaped meta models with quantization metadata on all other ranks.

We use safetensors to iterate through all model shards and layers per shard, then call load_and_quantize passing the model, the name of the current weight, the weight itself, the device and data type to quantize the layer on, and two low memory loading options: to_cpu and to_meta. Rank 0 sets to_cpu=True while all other ranks set to_meta=True.
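
The outer loop might look roughly like the following sketch before we turn to load_and_quantize itself. Here model_dir is a placeholder for the checkpoint directory; the full version lives in our example script.

from glob import glob
import torch
from safetensors import safe_open

for shard in sorted(glob(f"{model_dir}/*.safetensors")):
    with safe_open(shard, framework="pt", device="cpu") as f:
        for name in f.keys():
            weight = f.get_tensor(name)
            load_and_quantize(
                model, name, weight,
                device=torch.device("cuda", torch.cuda.current_device()),
                dtype=torch.bfloat16,
                to_cpu=(rank == 0),    # rank 0 keeps the full quantized model in CPU RAM
                to_meta=(rank != 0),   # other ranks keep only shapes and quantization metadata
            )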

def load_and_quantize(module:nn.Module, name:str, value:Tensor, device:torch.device,
                      dtype:torch.dtype, to_cpu:bool=False, to_meta:bool=False):

    module_key, _, value_key = name.rpartition('.')
    submodule = module.get_submodule(module_key)

    try:
        param = submodule.get_parameter(value_key)
        if isinstance(param, Params4bit):
            # quantize the layer using bitsandbytes
            value = type(param)(value.to(device, dtype).data, **param.__dict__).cuda(device)
            if to_meta:
                value = type(param)(value.data.to("meta"), **value.__dict__)
            elif to_cpu:
                value = type(param)(value.data.to("cpu"), **value.__dict__)
        else:
            value = type(param)(place_on_device(value).data)

    except AttributeError:
        # no parameter with this name: it's a buffer
        value = place_on_device(value)
    setattr(submodule, value_key, value)

After grabbing the correct parameter from the model, we quantize the parameter data using Params4bit.cuda, which also creates the quantization metadata on all ranks.

We cannot call Linear4bit.to or Params4bit.to, as this would place the quantization metadata on "meta" or "cpu", so we instead recreate Params4bit on the final device, leaving the quantization metadata on the current device. If the layer is a non-quantized parameter or buffer, it is placed on the correct device using the place_on_device method, whose definition is excluded for brevity (a plausible sketch follows), and added to the model using setattr.
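
A plausible version of that helper, written as a closure inside load_and_quantize so it can see device, dtype, to_cpu, and to_meta, might look like this (illustrative only, not our exact implementation):

def place_on_device(value):
    # route non-quantized parameters and buffers based on the low memory flags
    if to_meta:
        return value.to(device="meta", dtype=dtype)
    elif to_cpu:
        return value.to(device="cpu", dtype=dtype)
    return value.to(device=device, dtype=dtype)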

This entire process leaves us with a fully loaded and quantized model in CPU RAM for the first process, and meta models with the correct layer shapes and quantization metadata for all other ranks.

Initializing the LoRA Layers

The last thing to do before configuring FSDP is to initialize the LoRA layers. We use PEFT's LoraConfig, targeting all linear layers, and apply the config to the model using get_peft_model7.

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
)

if rank!=0:
    setup_quantized_meta_for_peft(model)

model = get_peft_model(model, peft_config)

setup_quantized_peft_meta_for_training(model)

The current release of PEFT will unhelpfully move the quantization metadata quant_state to the meta device, undoing our careful model loading8. As a temporary workaround, we remove the ability to modify the quant_state's device, apply PEFT to prepare the model for LoRA finetuning, and then restore the normal quant_state.to behavior.

def setup_quantized_meta_for_peft(model:nn.Module):
    def temp_to_method(self, *args, **kwargs):
        return self
    for param in model.parameters():
        if isinstance(param, Params4bit):
            param.quant_state._orig_to = param.quant_state.to
            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

def setup_quantized_peft_meta_for_training(model:nn.Module):
    for param in model.parameters():
        if isinstance(param, Params4bit) and hasattr(param.quant_state, '_orig_to'):
            param.quant_state.to = param.quant_state._orig_to
            param.quant_state._orig_to = None

This step isn’t needed with a custom LoRA implementation, and hopefully won’t be needed in the future as PEFT adds support for FSDP+QLoRA.

Setting Up FSDP for QLoRA Finetuning

Configuring FSDP for QLoRA finetuning requires a few modifications to a standard FSDP setup.

  • LoRA layers require a custom FSDP wrapping policy
  • Different mixed precision setups require specific quantization weight types
  • Low memory model loading requires specific FSDP initialization arguments

The next three sections will cover these modifications in detail.

Wrapping Policy

FSDP uses wrapping policies to determine how to split models for sharding. As a rule of thumb, each shard should have at least a million parameters. However, the LoRA layers will need to be wrapped separately due to wrapping restrictions. The two wrapping restrictions are:

  • Wrapped layers cannot have a mixture of requires_grad=True and requires_grad=False
  • Wrapped layers cannot have a mixture of different floating-point dtypes

To wrap our LoRA model, we’ll use a wrapping method straight from Meta’s llama-recipes. get_wrapping_policy defines lambda_policy_fn to identify any LoRA layer implementation. Then it passes that policy into the FSDP lambda_auto_wrap_policy and the Transformers LlamaDecoderLayer into FSDP transformer_auto_wrap_policy. This will wrap each decoder layer as a separate shard while splitting out the LoRA layers into their own shards.

def get_wrapping_policy(transformer_layer_name:nn.Module=LlamaDecoderLayer):
    def lambda_policy_fn(module):
        return (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)

    transformer_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(transformer_layer_name,)
    )

    return partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])

Then we put it all together by combining the two policies with FSDP's _or_policy.

The quantized weight storage type quant_storage should match the model's non-quantized weight data type so the FSDP wrapping policy will combine the quantized layers with the non-quantized layers. Should these data types differ, we'd need to adjust our wrapping policy to separate the quantized layers from the model's non-quantized, non-LoRA layers.

Mixed Precision Policy

FSDP’s mixed precision differs from PyTorch’s automatic mixed precision in that the user has complete control over the parameter, gradient reduction, and buffer data types, although the two can be combined.

For example, the following two mixed precision policies are designed to work with FSDP's mixed precision and AMP's autocast, respectively. The first policy will perform all mixed precision operations in bfloat16, while the second, when combined with amp.autocast during the model's forward pass, will dynamically cast to bfloat16 on a per-layer basis.

MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16
)

# cast to bfloat16 using amp.autocast
MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32
)

It’s important to set the quantized weight storage type quant_storage to match the param_dtype, otherwise the quantized weights can be cast to a different floating-point type, rendering them into random weights or causing a computation type mismatch. The first mixed precision example requires quant_storage=torch.bfloat16 while the second requires quant_storage=torch.float32.
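
Putting the two settings side by side, a matched configuration might look like the sketch below, reusing the replace_linear call and storage_dtype variable from the model initialization section:

import torch
from torch.distributed.fsdp import MixedPrecision

# The dtype the packed quantized weights are viewed as must match param_dtype,
# otherwise FSDP's mixed precision would reinterpret (and scramble) the weights.
storage_dtype = torch.bfloat16

model.model = replace_linear(model.model, Linear4bit, compute_dtype=torch.bfloat16,
                             quant_type='nf4', quant_storage=storage_dtype)

mixed_precision_policy = MixedPrecision(
    param_dtype=storage_dtype,
    reduce_dtype=storage_dtype,
    buffer_dtype=storage_dtype,
)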

It’s also possible to exclude layers from the mixed precision policy, which will override our wrapping policy and wrap these layers separately. This setting can be used for quantization libraries that cannot set the quantized weights to a specific floating-point type so the quantized weights are not modified by FSDP’s mixed precision policies.

FSDP for Low Memory Loading

With our LoRA-aware wrapping and optional mixed precision policies defined, we can shard our model using FSDP.

In addition to our wrapping and optional mixed precision policies, we enable FSDP's CPUOffload when training larger-than-GPU-memory models, and sync_module_states so only one copy of the model needs to be loaded into CPU or GPU RAM.

model = FSDP(
    model,
    auto_wrap_policy=wrapping_policy,
    mixed_precision=mixed_precision_policy,
    cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None,
    sync_module_states=low_memory,
    device_id=torch.cuda.current_device(),
    param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
        if (rank!=0 and low_memory) else None,
)

The last two arguments, device_id and param_init_fn, are required if we are using sync_module_states to lower the model loading memory requirements. These arguments must be initialized in this manner so FSDP can convert our meta device models into regular models while sharding and syncing the model weights from the rank 0 process.

Sharding our quantized 70B model with FSDP across two GPUs requires just over 128GB of CPU RAM9. Memory usage settles down to ~110GB during training due to CPU offloading.

Conclusion

At the time of publication, initial support for FSDP+QLoRA has been added to development builds of Axolotl and the Hugging Face ecosystem10.

We hope that this technical deep dive will accelerate adding FSDP+QLoRA support to more quantization methods and training frameworks.

If you are interested in integrating FSDP+QLoRA support into your quantization library or training framework, feel free to reach out to us. Otherwise stay tuned for our upcoming benchmarking post where we explore how to make the most of this new approach.

Footnotes

  1. If you implement a new method for your library, please let us know so we can update this list↩︎

  2. If you are not familiar with FSDP, PyTorch has two tutorials which can catch you up to speed: Getting Started with Fully Sharded Data Parallel (FSDP) & Advanced Model Training with Fully Sharded Data Parallel (FSDP).↩︎

  3. A 70 billion parameter model at half precision requires two bytes per parameter, so roughly 140GB just to load one copy of the model.↩︎

  4. FSDP has additional overhead for sharding. In practice sharding a quantized 70B model across two GPUs requires a minimum of 128GB plus 10-20GB of scratch disk space.↩︎

  5. Non-persistent buffers are not saved by default, so we need to make sure to recreate them when initializing our model.↩︎

  6. Other training frameworks, such as PyTorch Lightning and Composer also provide similar initialization methods.↩︎

  7. We cannot use PEFT’s prepare_for_kbit_training because it assumes quantized weights are stored as int8 and upcasts everything else to float32, and by default sets up gradient checkpointing which needs to be handled separately when using FSDP.↩︎

  8. The current release at the time of publication is PEFT 0.9.0.↩︎

  9. A 10-20GB scratch disk is enough to handle this initial memory spike.↩︎

  10. Initial FSDP+QLoRA support has been added to Accelerate, Transformers, TRL, and PEFT.↩︎