# Enabling 70B Finetuning on Consumer GPUs
Benjamin Warner, Johno Whitaker, Kerem Turgutlu
2024-03-14

# Introduction

Answer.AI recently [announced
FSDP+QLoRA](https://www.answer.ai/posts/2024-03-06-fsdp-qlora.html), an
open-source project enabling the finetuning of models up to 70 billion
parameters on two high-end consumer GPUs. This companion post provides a
detailed look into the implementation of FSDP+QLoRA and is relevant if:

- You’d like to add FSDP support to a new quantization library
- You want to add FSDP+QLoRA training to your own training framework or
  script
- You’re curious about the details of how FSDP+QLoRA works

We’ve split this post into two sections. The first covers how to update
a new quantization library for FSDP and the second shows how to
integrate FSDP+QLoRA finetuning into a training framework or script.

The code examples within this post have been simplified for legibility.
For the full code please reference our [FSDP & QLoRA example
script](https://github.com/AnswerDotAI/fsdp_qlora) and our updates to
[bitsandbytes](https://github.com/TimDettmers/bitsandbytes/pull/970) and
[HQQ](https://github.com/mobiusml/hqq/pull/17).

# Integrate a New Quantization Library

The primary blocker for enabling a new quantization method to work with
FSDP+QLoRA is that most libraries store quantized weights in integer
datatypes while FSDP only supports float datatypes. Ideally the
quantized weight data type will match the non-quantized layers’ data type
for easier and more efficient FSDP model sharding. Furthermore, the
float dtype should be user specifiable so that Mixed Precision training
doesn’t convert quantized weights stored in float32 to bfloat16,
effectively randomizing the weights.

There are at least two methods quantization maintainers can use to store
quantized weights in float datatypes for FSDP[1]:

- Use `Tensor.view(dtype)` to convert to and from float dtypes without
  modifying the underlying bytes.
- Read quantized weights in a datatype-agnostic manner in the
  dequantization kernel.

Of these two methods, HQQ uses the first, while bitsandbytes uses the
second.
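
As a minimal sketch of the first approach, the snippet below reinterprets a packed `uint8` tensor as `bfloat16` and back with `Tensor.view(dtype)`. The byte values are arbitrary placeholders rather than real quantized weights, and the variable names are illustrative only.

``` python
import torch

# Eight arbitrary bytes standing in for sixteen packed 4-bit values.
packed = torch.tensor([0x1F, 0xA3, 0x00, 0x7F, 0x7C, 0x42, 0x9E, 0x05], dtype=torch.uint8)

# Reinterpret the same memory as bfloat16: 8 bytes become 4 two-byte elements.
as_float = packed.view(torch.bfloat16)

# Viewing back as uint8 recovers the original bytes bit for bit.
restored = as_float.view(torch.uint8)

print(as_float.shape, as_float.dtype)
print(torch.equal(restored, packed))
```

Because `view` only changes how the bytes are interpreted, no copy is made and the float values themselves are meaningless; they exist only so FSDP sees a float tensor.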

## HQQ

The example code below shows how HQQ’s `HQQLinear.quantize` and
`HQQLinear.dequantize` methods have been modified to support FSDP
training by viewing int dtype quantized weights as a selectable float
dtype when quantizing, and then converting back to an int dtype before
being passed to the HQQ dequantization kernel.

``` python
@classmethod
def quantize(
    cls,
    *args,
    compute_dtype:torch.dtype=None,
    view_as_float:bool=False
):
    meta['view_as_float'] = view_as_float
    W_q = Quantizer.pack[meta['packing']](W_q)
    if view_as_float:
        # store quantized weights as compute_dtype
        W_q = W_q.view(torch.float32 if compute_dtype is None else compute_dtype)

@classmethod
def dequantize(cls, W_q:torch.Tensor, meta:dict):
    compute_dtype = meta['compute_dtype'] if ('compute_dtype' in meta) else torch.float16
    if meta['view_as_float']:
        W_q = W_q.view(meta['unpack_view_dtype'])
```

Currently, HQQ stores the quantized weights in the same dtype as the
compute dtype. Splitting `compute_dtype` from the quantized weight
storage dtype, as bitsandbytes does with `quant_storage`, would enable
all of FSDP’s Mixed Precision options, but the current implementation is
already usable with FSDP.

For more details and the full set of changes made to HQQ, see the [pull
request](https://github.com/mobiusml/hqq/pull/17) by Kerem and the
[mobius.ml folks](https://www.mobiuslabs.com).

## bitsandbytes

bitsandbytes had already implemented its quantization and
dequantization kernels to read and write raw bytes using `LoadChar` and
`StoreChar`.

``` cpp
__global__ void kQuantizeBlockwise(unsigned char *out, ...)
{
    unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
    __shared__ typename StoreChar::TempStorage storec;
    for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
    {
        // store 8-bit, FP4, or NF4 quantized qvals in storec as bytes or packed bytes
        StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
    }
}


__global__ void kDequantizeBlockwise(T *out, ...)
{
    unsigned char qvals[NUM_PER_TH];
    __shared__ typename LoadChar::TempStorage loadchar;
    __shared__ typename StoreT::TempStorage storet;

    for (unsigned int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
    {
        // load bytes or packed bytes into qvals
        LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);

        // after dequantizing qvals, store result in the PyTorch Tensor out
        StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
    }
}
```

This data type agnostic implementation allowed us to modify
bitsandbytes’ `quantize_4bit` and `dequantize_4bit` to have a settable
`quant_storage` dtype, independent of the computation data type
`compute_dtype`.

``` python
def quantize_4bit(A: Tensor, *args, quant_storage=torch.uint8) -> Tensor:
    # create an out Tensor with the quantized shape of dtype=quant_storage
    n = A.numel()
    mod = dtype2bytes[quant_storage] * 2
    out = torch.zeros(((n+1)//mod, 1), dtype=quant_storage, device=A.device)

    # quantize the input tensor A using the blockwise quantization CUDA kernel
    lib.cquantize_blockwise_bf16_nf4(get_ptr(A), get_ptr(out), ct.c_int(n))
    return out

def dequantize_4bit(A: Tensor, *args, out: Tensor = None) -> Tensor:
    if out is None:
        out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device)

    n = out.numel()

    lib.cdequantize_blockwise_bf16_nf4(get_ptr(A), get_ptr(out), ct.c_int(n))
    return out
```

Then we added `quant_storage` as an argument to `Linear4bit` and
`Params4bit`, defaulting to `torch.uint8` for compatibility with
existing codebases. The settable `quant_storage` supports all of FSDP’s
Mixed Precision options.
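
To make the storage arithmetic concrete, here is a small pure-Python sketch of how the number of storage elements shrinks as `quant_storage` gets wider. The `BYTES_PER_ELEMENT` table and `packed_numel` helper are hypothetical stand-ins written for this illustration; they mirror the `(n+1)//mod` expression above, not the library’s internals.

``` python
# Hypothetical stand-in for a dtype-to-byte-size lookup.
BYTES_PER_ELEMENT = {"uint8": 1, "bfloat16": 2, "float16": 2, "float32": 4}

def packed_numel(n_weights: int, quant_storage: str) -> int:
    """Number of storage elements needed to hold n_weights 4-bit values."""
    # Every byte holds two 4-bit values, so wider dtypes hold more per element.
    values_per_element = BYTES_PER_ELEMENT[quant_storage] * 2
    return (n_weights + 1) // values_per_element

n = 4096 * 4096  # weight count of an illustrative 4096x4096 linear layer
for storage in ("uint8", "bfloat16", "float32"):
    elements = packed_numel(n, storage)
    total_bytes = elements * BYTES_PER_ELEMENT[storage]
    print(storage, elements, total_bytes)
```

The total byte count is the same for every storage dtype; only the element count and the dtype FSDP sees change. This sketch assumes the weight count divides evenly by the number of values per element.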

### Additional Changes

bitsandbytes needed two additional modifications for FSDP compatibility
not needed by HQQ.

First, bitsandbytes stored its quantization metadata dictionary in its
`Params4bit`, a subclass of PyTorch’s `Parameter`. Unlike a regular
PyTorch `to`, which would preserve this information, FSDP creates
sharded instances of the model’s parameters and buffers. The sharding
process doesn’t persist or copy quantization metadata, which would
prevent QLoRA training or inference.

We resolved this blocker by reusing an in-progress solution from Tim
Dettmers et al.: creating a copy of the quantization metadata on the
`Linear4bit` layer and passing it to the `Params4bit` during the layer’s
forward pass.

Second, bitsandbytes automatically quantized whenever a bitsandbytes
layer was moved to the GPU. It didn’t check whether the layer was already
quantized, which meant training with FSDP ran the risk of quantizing
already quantized weights, as layers are constantly being moved from CPU
to GPU. We resolved this potential issue by adding a quantization
Boolean to prevent already quantized weights from being quantized again.
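
The guard can be pictured with a toy example. The class below is entirely hypothetical (it is not the bitsandbytes implementation, and the `already_quantized` name is made up for this sketch); it only shows how a Boolean flag keeps repeated device moves from quantizing twice and overwriting the stored scale.

``` python
class ToyQuantizedParam:
    """Hypothetical parameter that quantizes itself on its first move to the GPU."""

    def __init__(self, values):
        self.values = list(values)
        self.scale = None
        self.already_quantized = False  # the guard flag
        self.device = "cpu"

    def _quantize(self):
        # Toy quantization: scale by the largest magnitude, round into [-7, 7].
        self.scale = max(abs(v) for v in self.values) or 1.0
        self.values = [round(v / self.scale * 7) for v in self.values]
        self.already_quantized = True

    def to(self, device):
        # Only quantize on the first move to the GPU.
        if device == "cuda" and not self.already_quantized:
            self._quantize()
        self.device = device
        return self

param = ToyQuantizedParam([1.0, -2.0, 0.5])
first = list(param.to("cuda").values)
# Simulate FSDP shuttling the layer between CPU and GPU.
second = list(param.to("cpu").to("cuda").values)
print(first == second, param.scale)
```

Without the flag, the second move would recompute the scale from the already quantized integers and the original scale needed for dequantization would be lost.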

For more details and the full set of changes made to bitsandbytes, see
the [pull request](https://github.com/TimDettmers/bitsandbytes/pull/970)
by Benjamin and Kerem.

# Integrate a New Trainer

When integrating a supported quantization method into a training
framework or script, there are two main factors to consider:

- Loading and quantizing a model to use minimal CPU memory and create
  quantization metadata for all devices.
- Correctly initializing FSDP wrapping policies, mixed precision, and
  sharding settings for QLoRA finetuning.

This section assumes you are familiar with the basics of setting up and
training models with FSDP[2].

## Loading and Quantizing a Model

The default method of initializing and loading a model for FSDP requires
one model per GPU (or rank). Loading a 70B model in BFloat16 on a two
GPU system would require roughly 280GB of CPU or GPU RAM just to
initialize the model before FSDP sharding[3].

However, FSDP also has a low memory option `sync_module_states=True`
which only requires the model weights to be initialized or loaded on the
rank 0 process. Then FSDP will shard the model across all GPU ranks.
This combined with quantization reduces the model weights to ~35GB for a
single node, no matter how many ranks[4].
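
The memory figures above follow from simple arithmetic, sketched below. The helper name is illustrative, and the calculation deliberately ignores quantization metadata and FSDP’s sharding overhead, so treat the results as lower bounds.

``` python
PARAMS = 70_000_000_000  # 70B parameters
GB = 10**9

def weights_gb(n_params: int, bits_per_param: int, copies: int = 1) -> float:
    """Approximate size of the raw weights in gigabytes."""
    return n_params * bits_per_param / 8 * copies / GB

bf16_one_copy = weights_gb(PARAMS, 16)              # one BFloat16 copy
bf16_two_ranks = weights_gb(PARAMS, 16, copies=2)   # default loading, one copy per rank
four_bit_rank0 = weights_gb(PARAMS, 4)              # 4-bit weights loaded on rank 0 only

print(bf16_one_copy, bf16_two_ranks, four_bit_rank0)
```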

This section will highlight how to load and quantize a model using
FSDP’s low memory option, so the quantization metadata is placed on all
devices. While we only show how to quantize unquantized models, a
similar process can be used to load pre-quantized models.

To simplify the sample code, we will only show how to load bitsandbytes
models using Hugging Face’s Transformers library, but the process is the
same for other model libraries and quantization methods which support
FSDP such as HQQ.

### Initializing the Model

While FSDP only needs a model with weights on the rank 0 process, it
requires an instance of the model on all processes. To prevent these
other instances from using memory, we will initialize the models on all
ranks using PyTorch’s meta device and then load the model weights into
the rank 0 process.

Depending on the model implementation, it might be possible to use the
meta device initialization

``` python
with torch.device('meta'):
    model = Model()
```

but Hugging Face models create non-persistent buffers on
initialization[5]. So instead we will use Accelerate’s
`init_empty_weights`, which by default creates the buffers and
initializes the model’s parameters on the meta device[6].
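
The buffer problem can be reproduced with a small stand-in module. `TinyRotary` below is a hypothetical class written for this sketch, not a Hugging Face model; it registers a non-persistent buffer in `__init__`, the same pattern that makes plain meta device initialization insufficient.

``` python
import torch
from torch import nn

class TinyRotary(nn.Module):
    """Hypothetical module with a non-persistent buffer created in __init__."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        # persistent=False keeps this buffer out of the state dict.
        self.register_buffer("inv_freq", torch.ones(dim // 2), persistent=False)

with torch.device("meta"):
    meta_model = TinyRotary()

print(meta_model.proj.weight.is_meta)       # parameters have a shape but no data
print(meta_model.inv_freq.is_meta)          # the buffer's values are gone too
print(list(meta_model.state_dict().keys())) # and no checkpoint entry can restore it
```

Parameters can later be filled from a checkpoint, but a non-persistent buffer has no checkpoint entry, so it has to be created with real values when the model is initialized.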

We initialize the model and immediately replace PyTorch’s `nn.Linear`
with bitsandbytes’ `Linear4bit` layer, which implements bitsandbytes’
four-bit quantization.

``` python
cfg = AutoConfig.from_pretrained(model_name)
cfg.use_cache = False

with accelerate.init_empty_weights():
    model = AutoModelForCausalLM.from_config(cfg)
    model.model = replace_linear(model.model, Linear4bit, compute_dtype=compute_dtype,
                                 quant_type='nf4', quant_storage=storage_dtype)
model.is_loaded_in_4bit = True
```

One important argument to note is the new `quant_storage`. As mentioned
in the quantization section, FSDP cannot shard integer data types, and
by default bitsandbytes (and most other quantization libraries) store
quantized weights in uint8. `quant_storage` allows users to set the
storage for the quantized weights to a float data type, ideally the same
data type as the rest of the model for optimal FSDP sharding. More on
this in the [wrapping policy](#wrapping-policy) and [mixed precision
policy](#mixed-precision-policy) sections.

### Loading the Model Weights

With the model initialized across all ranks, the next step is to load
the pretrained weights and quantize them. Normally for low memory FSDP
we’d only load the model on the rank 0 process, but quantization
presents a wrinkle.

If the model is only quantized on rank 0, the quantization metadata will
only exist on rank 0 and will not be synced since FSDP cannot sync
dictionaries. Furthermore, after quantizing the model on rank 0, the
other ranks’ parameters will keep their original non-quantized shape,
while the rank 0 model parameters will be half-size or smaller due to
quantization, preventing FSDP from sharding the model.

To resolve this issue, we will load and quantize the model across all
ranks layer by layer, but on ranks 1+ we will immediately discard the
newly quantized layer’s weights, leaving us with a fully loaded model on
the rank 0 process and the correctly shaped meta models with
quantization metadata on all other ranks.

We use safetensors to iterate through all model shards and layers per
shard, then call `load_and_quantize` passing the model, the name of the
current weight, the weight itself, the device and data type to quantize
the layer on, and two low memory loading options: `to_cpu` and
`to_meta`. Rank 0 sets `to_cpu=True` while all other ranks set
`to_meta=True`.

``` python
def load_and_quantize(module:nn.Module, name:str, value:Tensor, device:torch.device,
                      dtype:torch.dtype, to_cpu:bool=False, to_meta:bool=False):

    module_key, _, value_key = name.rpartition('.')
    submodule = module.get_submodule(module_key)

    try:
        param = submodule.get_parameter(value_key)
        if isinstance(param, Params4bit):
            # quantize the layer using bitsandbytes
            value = type(param)(value.to(device, dtype).data, **param.__dict__).cuda(device)
            if to_meta:
                value = type(param)(value.data.to("meta"), **value.__dict__)
            elif to_cpu:
                value = type(param)(value.data.to("cpu"), **value.__dict__)
        else:
            value = type(param)(place_on_device(value).data)

    except AttributeError:
        value = place_on_device(value) # it's a buffer
        pass
    setattr(submodule, value_key, value)
```

After grabbing the correct parameter from the model, we quantize the
parameter data using `Params4bit.cuda` which in addition creates the
quantization metadata on all ranks.

We cannot call `Linear4bit.to` or `Params4bit.to` as this will place the
quantization metadata on `"meta"` or `"cpu"`, so we instead recreate
`Params4bit` on the final device, leaving the quantization metadata on
the current device. If the layer is a non-quantized parameter or buffer,
it is placed on the correct device using the `place_on_device` method,
whose definition is excluded for brevity, and added to the model using
`setattr`.

This entire process leaves us with a fully loaded and quantized model in
CPU RAM for the first process, and meta models with the correct layer
shapes and quantization metadata for all other ranks.
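
The effect of the two placement options on the weight data can be sketched with plain tensors. The `place_for_rank` helper below is hypothetical and leaves out the quantization metadata handling that `load_and_quantize` performs; it only shows that moving a tensor to the meta device keeps its shape and dtype while dropping its storage.

``` python
import torch

def place_for_rank(quantized: torch.Tensor, rank: int) -> torch.Tensor:
    """Keep real data on rank 0; keep only shape and dtype on other ranks."""
    return quantized.to("cpu") if rank == 0 else quantized.to("meta")

# Stand-in for a packed, quantized weight stored as bfloat16.
packed = torch.zeros(1024, 1, dtype=torch.bfloat16)

rank0_copy = place_for_rank(packed, rank=0)
rank1_copy = place_for_rank(packed, rank=1)

print(rank0_copy.is_meta, rank1_copy.is_meta)
print(rank1_copy.shape == packed.shape, rank1_copy.dtype == packed.dtype)
```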

### Initializing the LoRA Layers

The last thing to do before configuring FSDP is to initialize the LoRA
layers. We use PEFT’s `LoraConfig`, targeting all linear layers, and
apply the config to the model using `get_peft_model`[7].

``` python
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=lora_target_modules,
)

if rank!=0:
    setup_quantized_meta_for_peft(model)

model = get_peft_model(model, peft_config)

setup_quantized_peft_meta_for_training(model)
```

The current release of PEFT will unhelpfully move the quantization
metadata `quant_state` to the meta device, undoing our careful model
loading[8]. As a temporary workaround we disable the ability to change
the device of `quant_state`, apply PEFT to prepare the model for LoRA
finetuning, and then restore the normal `quant_state.to` behavior.

``` python
def setup_quantized_meta_for_peft(model:nn.Module):
    def temp_to_method(self, *args, **kwargs):
        return self
    for param in model.parameters():
        if isinstance(param, Params4bit):
            param.quant_state._orig_to = param.quant_state.to
            param.quant_state.to = types.MethodType(temp_to_method, param.quant_state)

def setup_quantized_peft_meta_for_training(model:nn.Module):
    for param in model.parameters():
        if isinstance(param, Params4bit) and hasattr(param.quant_state, '_orig_to'):
            param.quant_state.to = param.quant_state._orig_to
            param.quant_state._orig_to = None
```

This step isn’t needed with a custom LoRA implementation, and hopefully
won’t be needed in the future as PEFT adds support for FSDP+QLoRA.
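
The patching pattern itself is plain Python and can be tried in isolation. The sketch below uses a hypothetical `ToyQuantState` class in place of the real `quant_state`, and it restores the method by deleting the instance-level override rather than by saving `_orig_to`, which is a simplification of the workaround shown above.

``` python
import types

class ToyQuantState:
    """Hypothetical stand-in for an object whose .to() moves it between devices."""

    def __init__(self, device: str = "cuda"):
        self.device = device

    def to(self, device: str):
        self.device = device
        return self

def freeze_to(state: ToyQuantState) -> None:
    # Shadow the class method with an instance-level no-op.
    def noop_to(self, *args, **kwargs):
        return self
    state.to = types.MethodType(noop_to, state)

def restore_to(state: ToyQuantState) -> None:
    # Remove the override so attribute lookup falls back to the class method.
    del state.to

state = ToyQuantState()
freeze_to(state)
state.to("meta")                  # ignored while frozen
device_while_frozen = state.device
restore_to(state)
state.to("cpu")                   # normal behavior again
device_after_restore = state.device
print(device_while_frozen, device_after_restore)
```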

## Setting Up FSDP for QLoRA Finetuning

Configuring FSDP for QLoRA finetuning requires a few modifications to a
standard FSDP setup.

- LoRA layers require a custom FSDP wrapping policy
- Different mixed precision setups require specific quantization weight
  types
- Low memory model loading requires specific FSDP initialization
  arguments

The next three sections will cover these modifications in detail.

### Wrapping Policy

FSDP uses wrapping policies to determine how to split models for
sharding. As a rule of thumb, each shard should have at least a million
parameters. However, the LoRA layers will need to be wrapped separately
due to wrapping restrictions. The two wrapping restrictions are:

- Wrapped layers cannot have a mixture of `requires_grad=True` and
  `requires_grad=False`
- Wrapped layers cannot have a mixture of different floating types
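
To see why LoRA layers trip these restrictions, the sketch below checks a toy LoRA-style layer for uniform `requires_grad` and dtype. Both `ToyLoRALinear` and `has_uniform_params` are hypothetical helpers written for this illustration; they are not part of FSDP or PEFT.

``` python
import torch
from torch import nn

def has_uniform_params(module: nn.Module) -> bool:
    """True if all parameters share one requires_grad value and one dtype."""
    params = list(module.parameters())
    grad_flags = {p.requires_grad for p in params}
    dtypes = {p.dtype for p in params}
    return len(grad_flags) <= 1 and len(dtypes) <= 1

class ToyLoRALinear(nn.Module):
    """Hypothetical frozen base layer with trainable low-rank adapters."""

    def __init__(self, dim: int = 16, rank: int = 4):
        super().__init__()
        self.base = nn.Linear(dim, dim, bias=False)
        self.base.weight.requires_grad_(False)  # frozen pretrained weight
        self.lora_A = nn.Linear(dim, rank, bias=False)
        self.lora_B = nn.Linear(rank, dim, bias=False)

layer = ToyLoRALinear()
print(has_uniform_params(layer))         # frozen and trainable weights mixed
print(has_uniform_params(layer.base))    # frozen only
print(has_uniform_params(layer.lora_A))  # trainable only
```

Wrapping the whole layer as one unit would mix frozen and trainable parameters, whereas wrapping the adapters separately leaves every wrapped unit uniform.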

To wrap our LoRA model, we’ll use a wrapping method straight from Meta’s
[llama-recipes](https://github.com/facebookresearch/llama-recipes).
`get_wrapping_policy` defines `lambda_policy_fn` to identify any LoRA
layer implementation. Then it passes that policy into the FSDP
`lambda_auto_wrap_policy` and the Transformers `LlamaDecoderLayer` into
FSDP `transformer_auto_wrap_policy`. This will wrap each decoder layer
as a separate shard while splitting out the LoRA layers into their own
shards.

``` python
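from functools import partial

from torch import nn
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

def get_wrapping_policy(transformer_layer_name:type[nn.Module]=LlamaDecoderLayer):
    def lambda_policy_fn(module):
        # match leaf modules with a trainable weight, i.e. the LoRA layers
        return (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)

    # transformer_layer_cls expects a set of classes; note there is no trailing
    # comma after the call, which would turn the policy into a tuple
    transformer_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={transformer_layer_name}
    )

    return partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])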
```

The function’s final line puts it all together by combining the two
policies with FSDP’s `_or_policy`.

The quantized weight storage type `quant_storage` should match the
model’s non-quantized weight datatype so the FSDP wrapping policy will
combine the quantized layers with non-quantized layers. Should these
data types differ, we’d need to adjust our wrapping policy to separate
our quantized layers from the model’s non-quantized non-LoRA layers.

### Mixed Precision Policy

FSDP’s mixed precision differs from PyTorch’s automatic mixed precision
in that the user has complete control over the parameter, gradient
reduction, and buffer data types, although the two can be combined.

For example, the following two mixed precision policies are designed to
work with FSDP’s mixed precision and AMP’s `autocast` respectively. The
first policy performs all mixed precision operations in bfloat16, while
the second, when combined with `amp.autocast` during the model’s forward
pass, dynamically casts to bfloat16 on a per-layer basis.

``` python
MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16
)

# cast to bfloat16 using amp.autocast
MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32
)
```

It’s important to set the quantized weight storage type `quant_storage`
to match the `param_dtype`, otherwise the quantized weights can be cast
to a different floating-point type, rendering them into random weights
or causing a computation type mismatch. The first mixed precision
example requires `quant_storage=torch.bfloat16` while the second
requires `quant_storage=torch.float32`.
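
The failure mode can be demonstrated with plain tensors. In the sketch below, sixteen placeholder bytes stand in for packed 4-bit weights, and `Tensor.to` plays the role of the cast that mixed precision would apply to `param_dtype`; this is an illustration of the dtype mismatch, not of FSDP’s internal code path.

``` python
import torch

packed = torch.arange(16, dtype=torch.uint8)  # stand-in for packed 4-bit weights

# Matching case: stored as bfloat16 and param_dtype is bfloat16.
stored_bf16 = packed.view(torch.bfloat16)
after_cast = stored_bf16.to(torch.bfloat16)  # same dtype, so nothing is converted
intact = torch.equal(after_cast.view(torch.uint8), packed)

# Mismatched case: stored as float32 but param_dtype is bfloat16.
stored_fp32 = packed.view(torch.float32)
downcast = stored_fp32.to(torch.bfloat16)    # numeric conversion, not a reinterpretation
surviving_bytes = downcast.view(torch.uint8).numel()

print(intact, packed.numel(), surviving_bytes)
```

In the mismatched case half of the packed bytes are discarded and the rest are rounded as if they were float values, so the original quantized weights cannot be recovered.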

It’s also possible to exclude layers from the mixed precision policy,
which will override our wrapping policy and wrap these layers
separately. This setting can be used for quantization libraries that
cannot set the quantized weights to a specific floating-point type so
the quantized weights are not modified by FSDP’s mixed precision
policies.

### FSDP for Low Memory Loading

With our LoRA-aware wrapping and optional mixed precision policies
defined, we can shard our model using FSDP.

Beyond our wrapping and optional mixed precision policies, we enable
FSDP’s `CPUOffload` when training larger-than-GPU-memory models, and
`sync_module_states` so that only one copy of the model needs to be
loaded into CPU or GPU RAM.

``` python
model = FSDP(
    model,
    auto_wrap_policy=wrapping_policy,
    mixed_precision=mixed_precision_policy,
    cpu_offload=CPUOffload(offload_params=True) if cpu_offload else None,
    sync_module_states=low_memory,
    device_id=torch.cuda.current_device(),
    param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
        if (rank!=0 and low_memory) else None,
)
```

The last two arguments, `device_id` and `param_init_fn`, are required if
we are using `sync_module_states` to lower the model loading memory
requirements. These arguments must be initialized in this manner so FSDP
can convert our meta device models into regular models while sharding
and syncing the model weights from the rank 0 process.
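
What `to_empty` does to a meta module can be seen without FSDP or a GPU. The sketch below materializes a meta `nn.Linear` on the CPU so it runs anywhere; the FSDP call above targets `"cuda"` and passes `recurse=False`, and the layer sizes here are arbitrary.

``` python
import torch
from torch import nn

# Create a layer with shapes but no storage, as on the non-zero ranks.
with torch.device("meta"):
    layer = nn.Linear(16, 16, bias=False)
was_meta = layer.weight.is_meta

# Allocate uninitialized storage of the right shape on a real device.
layer.to_empty(device="cpu")

print(was_meta, layer.weight.is_meta, layer.weight.shape, layer.weight.device)
```

The allocated memory holds arbitrary values, which is acceptable here because `sync_module_states` then overwrites it with the weights broadcast from rank 0.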

Sharding our quantized 70B model with FSDP across two GPUs requires just
over 128GB of CPU RAM[9]. Memory usage settles down to ~110GB during
training due to CPU offloading.

# Conclusion

At the time of publication, initial support for FSDP+QLoRA has been
added to development builds of
[Axolotl](https://github.com/OpenAccess-AI-Collective/axolotl/pull/1378)
and the Hugging Face ecosystem[10].

We hope that this technical deep dive will accelerate adding FSDP+QLoRA
support to more quantization methods and training frameworks.

If you are interested in integrating FSDP+QLoRA support into your
quantization library or training framework, feel free to reach out to
us. Otherwise stay tuned for our upcoming benchmarking post where we
explore how to make the most of this new approach.

[1] If you implement a new method for your library, please let us know
so we can update this list.

[2] If you are not familiar with FSDP, PyTorch has two tutorials which
can catch you up to speed: [Getting Started with Fully Sharded Data
Parallel
(FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html) &
[Advanced Model Training with Fully Sharded Data Parallel
(FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html).

[3] A 70 billion parameter model at half precision requires two bytes
per parameter, so roughly 140GB just to load one copy of the model.

[4] FSDP has additional overhead for sharding. In practice sharding a
quantized 70B model across two GPUs requires a minimum of 128GB plus
10-20GB of scratch disk space.

[5] Non-persistent buffers are not saved by default, so we need to make
sure to recreate them when initializing our model.

[6] Other training frameworks, such as PyTorch Lightning and Composer,
also provide similar initialization methods.

[7] We cannot use PEFT’s `prepare_for_kbit_training` because it assumes
quantized weights are stored as int8 and upcasts everything else to
float32, and by default sets up gradient checkpointing which needs to be
handled separately when using FSDP.

[8] The current release at the time of publication is PEFT 0.9.0.

[9] A 10-20GB scratch disk is enough to handle this initial memory
spike.

[10] Initial FSDP+QLoRA support has been added to
[Accelerate](https://github.com/huggingface/accelerate/pull/2544),
[Transformers](https://github.com/huggingface/transformers/pull/29587),
[TRL](https://github.com/huggingface/trl/pull/1416), and
[PEFT](https://github.com/huggingface/peft/pull/1550).
