PEFT Methods Overview [Implementing Adapters in PyTorch]

In the rapidly evolving landscape of transformer-based architectures, a significant challenge has emerged: how do we customize these increasingly massive models for specific tasks without breaking the bank on computational resources?

Enter Parameter-Efficient Fine-Tuning (PEFT), a family of techniques that has revolutionized how we adapt pre-trained models to downstream tasks.

The Fine-Tuning Dilemma

So, you’ve got access to a state-of-the-art (SoTA) language model with billions of parameters.

Perhaps it’s an open-weight model such as LLaMA 3, Mistral or Qwen. You want to adapt this model to a specialized domain like medical text analysis or legal document processing.

The traditional approach would involve fine-tuning the entire model on your domain-specific data.

Full fine-tuning comes with substantial costs:

  1. Computational Expense: Training billions of parameters requires significant GPU resources.
  2. Storage Overhead: Each fine-tuned version requires storing a complete copy of the model.
  3. Catastrophic Forgetting: Aggressive fine-tuning might cause the model to lose its general capabilities.
  4. Limited Scalability: Maintaining a separate full-size model for every task becomes unmanageable as the number of tasks grows.

This is where PEFT techniques come to the rescue. Rather than updating all parameters, PEFT methods train only a small number of parameters (either newly added ones or a selected subset of the existing ones) while keeping the rest of the pre-trained model frozen. These methods typically train less than 1% of the model’s parameters and, in many settings, achieve performance comparable to full fine-tuning.

Let’s walk through the most significant PEFT methods and their core principles, and implement them in PyTorch to understand how they work. Then we’ll explore how to use these techniques with the Hugging Face peft library for practical applications.

General Overview of PEFT Methods

Figure: Taxonomy of parameter-efficient fine-tuning methods. Source: Scaling Down to Scale Up: A Guide to Parameter-Efficient Fine-Tuning.

1. Addition-Based Methods

These approaches add new, lightweight modules to a pre-trained model while keeping the original weights frozen. This is the most widely explored category; its two main subtypes are adapters and soft prompts, though methods such as LoRA are sometimes grouped here as well.

  • Adapters (e.g., Bottleneck Adapters, Parallel Adapters): Small neural layers inserted within Transformer blocks that are trained while the rest of the model remains unchanged. Variants differ in placement, structure, and compression strategies.
  • LoRA (Low-Rank Adaptation): Instead of fine-tuning full weight matrices, LoRA introduces low-rank decompositions for weight updates (e.g., replacing a full-rank update with W_down * W_up).
  • Prefix Tuning / Prompt Tuning: Add trainable vectors (prefixes or prompts) to the model’s input or internal layers. These methods steer model behavior without changing its core parameters.
  • Soft Prompts: Instead of using discrete tokens, these train continuous embeddings that are prepended to the input. Can be applied to input embeddings or even across all Transformer layers.

Despite adding new parameters, these methods often use significantly less training memory, because gradients and optimizer states only need to be stored for the small set of trainable parameters.


2. Selection-Based Methods

Selective approaches involve fine-tuning only a specific subset of the model’s original parameters, chosen either manually or via structural criteria.

  • BitFit: Fine-tunes only the bias terms of the model, drastically reducing the number of parameters involved.
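  • IA³ (Infused Adapter by Inhibiting and Amplifying Inner Activations): Adds learned vectors that rescale (inhibit or amplify) the attention keys, values, and feed-forward activations. Strictly speaking these are new parameters, so some taxonomies list IA³ as an additive method.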
  • Layer-wise Selection: Fine-tunes only the top or bottom layers of the model, or focuses on specific components (e.g., attention vs. FFN).
  • Sparse Fine-Tuning: Selects individual parameters to update based on certain criteria (e.g., magnitude or gradients), ignoring model structure. However, such unstructured sparse updates are hard to accelerate on current hardware.
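One common way to realize sparse fine-tuning is to fix a mask up front and zero out the gradients of every entry outside it. The sketch below picks the largest-magnitude weights of a single layer; magnitude is just one possible criterion, and the helper name `apply_magnitude_mask` and the `keep_fraction` value are illustrative assumptions. Note that the full weight matrix is still stored and multiplied densely, which is exactly why this style of update is hard to turn into real speedups.

```python
import torch
import torch.nn as nn


def apply_magnitude_mask(layer: nn.Linear, keep_fraction: float = 0.01) -> torch.Tensor:
    """Hypothetical helper: only the largest-magnitude weights receive updates."""
    weight = layer.weight
    k = max(1, int(keep_fraction * weight.numel()))
    # Threshold = k-th largest absolute value in the weight matrix
    threshold = weight.detach().abs().flatten().topk(k).values.min()
    mask = (weight.detach().abs() >= threshold).to(weight.dtype)

    # Zero out gradients outside the mask on every backward pass
    weight.register_hook(lambda grad: grad * mask)
    # Biases are simply frozen in this sketch
    if layer.bias is not None:
        layer.bias.requires_grad = False
    return mask


torch.manual_seed(0)
layer = nn.Linear(64, 64)
before = layer.weight.detach().clone()
mask = apply_magnitude_mask(layer, keep_fraction=0.05)

# Plain SGD: with weight decay (e.g. AdamW), masked-out entries would still shrink
optimizer = torch.optim.SGD([layer.weight], lr=0.1)
loss = layer(torch.randn(8, 64)).pow(2).mean()
loss.backward()
optimizer.step()

changed = layer.weight.detach() != before
print(f"mask size: {int(mask.sum())}, changed entries: {int(changed.sum())}")
```

Only entries inside the mask can move, so the "update" can be stored as a sparse diff against the frozen pre-trained weights.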

These methods are particularly useful when model updates must be extremely lightweight or constrained due to storage, bandwidth, or privacy concerns.


3. Reparameterization-Based Methods

These techniques re-structure the parameter space to enable efficient updates with fewer trainable weights.

  • LoRA (also fits here): Uses low-rank matrices to model weight updates, greatly reducing parameter count.
  • Compacter: Builds on adapters but compresses them using low-rank decomposition and parameter sharing.
  • (IA)³: Combines gating and reparameterization ideas to modulate specific subcomponents of the model.
  • KronA / Kron Adapter: Uses Kronecker product decomposition to represent weight updates with a favorable trade-off between expressiveness and size.
  • Intrinsic SAID: Employs the Fastfood transform to project updates from a low-dimensional subspace into the full parameter space, based on the observation that fine-tuning has a low intrinsic dimension.

These methods often target the attention projection weights (W_Q, W_K, W_V, W_O); for example, the original LoRA paper reports good results from adapting W_Q and W_V together.


4. Hybrid Methods

Hybrid approaches combine strategies from multiple categories to balance trade-offs in memory, compute, and performance.

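  • MAM Adapter: Combines parallel adapters in the feed-forward layers with prefix tuning in the attention layers (a "mix-and-match" design).
  • UniPELT: Merges LoRA, adapters, and prefix tuning into a unified framework, using learned gates to control how much each component contributes.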
  • Compacter++ / Kron Adapter: Combines adapter-based methods with Kronecker reparameterization to reduce the number of trainable parameters further.

These methods allow researchers to adapt fine-tuning strategies to specific deployment constraints, whether that be edge devices, multi-task learning, or multilingual models.

Bottleneck Adapters

Adapters were among the first successful PEFT approaches, introduced in Parameter-Efficient Transfer Learning for NLP by Houlsby et al. in 2019.

The core idea is elegantly simple: insert small trainable modules into each layer of a pre-trained network while keeping the original parameters frozen.

Bottleneck adapters add lightweight feed-forward layers into each Transformer block. These adapter layers typically include:

  • a down-projection matrix that reduces the hidden state dimension from \(d\) to a smaller dimension \(b\),
  • a non-linear activation \(\sigma\),
  • an up-projection matrix that expands the representation back to the original size \(d\), and
  • a residual connection, so the original input is added back after transformation:
\[\text{Adapter}(x) = x + W_{\text{up}} \, \sigma(W_{\text{down}} x)\]

Depending on the specific configuration, these adapter layers can be placed at various points inside the Transformer block. Other components like residual connections, layer normalizations, activation functions, and the size of the bottleneck layer can also be customized.

The most important hyperparameter in this setup is the bottleneck dimension \(b\). Rather than setting \(b\) directly, it’s usually defined through a parameter called reduction_factor. This factor represents the ratio between the hidden layer size \(d\) and the bottleneck size \(b\), given by:

\[b = \frac{d}{\text{reduction\_factor}}\]
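To make the formula concrete, the snippet below computes \(b\) and the parameter count of a single adapter (down- and up-projection with biases, ignoring any extra layer norm), assuming a BERT-base hidden size of 768 and reduction_factor = 16. The helper name is illustrative.

```python
def bottleneck_adapter_size(hidden_size: int, reduction_factor: int) -> tuple[int, int]:
    """Return (bottleneck dim b, trainable parameters of one adapter)."""
    b = hidden_size // reduction_factor
    down = hidden_size * b + b          # W_down weights + bias
    up = b * hidden_size + hidden_size  # W_up weights + bias
    return b, down + up


b, params = bottleneck_adapter_size(768, 16)
print(b, params)  # 48 74544
```

With two adapters per layer across BERT-base’s 12 layers, that comes to 24 × 74,544 = 1,789,056 trainable parameters, roughly 1.6% of the ~110M in the base model.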
import torch.nn as nn

class BottleneckAdapter(nn.Module):
    def __init__(self, hidden_size, adapter_size, dropout_rate=0.1):
        """
        A bottleneck adapter module that can be inserted into a transformer.
        
        It projects hidden states down to a lower-dimensional space and then 
        back up again, with non-linearity and dropout in between. This helps 
        the model adapt to new tasks without updating the original transformer.
        
        Args:
            hidden_size: The dimension of the model's hidden states (e.g., 768 for BERT-base)
            adapter_size: The smaller bottleneck dimension (e.g., 64)
            dropout_rate: Regularization to improve generalization
        """
        super().__init__()
        
        self.down_project = nn.Linear(hidden_size, adapter_size)  # d -> b
        self.activation = nn.GELU()  # non-linearity
        self.up_project = nn.Linear(adapter_size, hidden_size)    # b -> d
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = nn.LayerNorm(hidden_size)
        
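        # Initialize adapter weights: they are not learned during pretraining, so init matters.
        # A zero up-projection makes the adapter branch contribute nothing at first,
        # so training starts from the pre-trained model's behavior (up to the extra LayerNorm).
        nn.init.xavier_uniform_(self.down_project.weight)
        nn.init.zeros_(self.down_project.bias)
        nn.init.zeros_(self.up_project.weight)
        nn.init.zeros_(self.up_project.bias)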

    def forward(self, hidden_states):
        # Store original input for residual connection
        residual = hidden_states

        # Apply adapter: down-project -> non-linear -> up-project -> dropout
        x = self.down_project(hidden_states)
        x = self.activation(x)
        x = self.up_project(x)
        x = self.dropout(x)

        # Add residual and normalize
        output = residual + x
        output = self.layer_norm(output)
        return output

But how do we integrate adapters into a pre-trained model?

Let’s see how to modify a standard transformer layer to include our bottleneck adapter:

class AdapterTransformerLayer(nn.Module):
    def __init__(self, transformer_layer, adapter_size):
        """
        A wrapper around an existing transformer layer that adds adapters after
        attention and after the feed-forward sublayers.

        Args:
            transformer_layer: One layer from a pre-trained transformer (e.g., Hugging Face's BertLayer)
            adapter_size: Bottleneck size for the adapters
        """
        super().__init__()
        self.layer = transformer_layer
        self.hidden_size = transformer_layer.attention.self.all_head_size  # Model-specific

        # Freeze all transformer weights (we don’t train them)
        for param in self.layer.parameters():
            param.requires_grad = False

        # Add bottleneck adapters at two key places:
        self.attention_adapter = BottleneckAdapter(self.hidden_size, adapter_size)
        self.ffn_adapter = BottleneckAdapter(self.hidden_size, adapter_size)

    def forward(self, hidden_states, attention_mask=None, *args, **kwargs):
        # BertEncoder passes extra arguments (head mask, cache, ...) that this
        # sketch ignores; accepting them keeps the call from failing.

        # Standard attention (output of frozen pre-trained layer)
        attention_output = self.layer.attention(hidden_states, attention_mask)[0]

        # Inject adapter after attention
        adapted_attention = self.attention_adapter(attention_output)

        # Apply frozen feed-forward network
        intermediate_output = self.layer.intermediate(adapted_attention)
        layer_output = self.layer.output(intermediate_output, adapted_attention)

        # Inject second adapter after feed-forward
        output = self.ffn_adapter(layer_output)

        # BertEncoder reads layer_outputs[0], so return a tuple like BertLayer does
        return (output,)

All set. Now we just need to load a pre-trained model and wrap its target layers:

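from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Define adapter size
adapter_size = 64

# Freeze the whole model first so embeddings and pooler stay frozen too;
# adapters created afterwards are trainable by default
for param in model.parameters():
    param.requires_grad = False

# Wrap all encoder layers with adapter-enabled versions
for i in range(len(model.encoder.layer)):
    original_layer = model.encoder.layer[i]
    model.encoder.layer[i] = AdapterTransformerLayer(original_layer, adapter_size)
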
# Check that only adapters will be trained
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params} / {total_params}")

# Now you can tokenize input and train like usual.
inputs = tokenizer("Adapters are lightweight and powerful.", return_tensors="pt")
outputs = model(**inputs)

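Now, let’s see how to use adapters with the adapter-transformers library, a fork of Hugging Face transformers with built-in adapter support. It’s faster, easier, and production-tested. Note that it has since been superseded by the standalone adapters package, whose API is similar but not identical.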

# pip install adapter-transformers

# `BertAdapterModel` = a special version of BERT that allows adapter injection.
import torch
from transformers import BertTokenizer, BertAdapterModel
from transformers.adapters import AdapterConfig

# Load model and tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertAdapterModel.from_pretrained("bert-base-uncased")

# Define adapter configuration
config = AdapterConfig.load(
    "pfeiffer",                    # Adapter type: "pfeiffer", "houlsby", etc.
    reduction_factor=16,          # Bottleneck size (768 / 16 = 48)
    leave_out=[0, 11],            # Skip layer 0 and 11 (i.e., don't inject there)
    non_linearity="gelu",
)

# Add adapter with a custom name
model.add_adapter("my_task_adapter", config=config)

# Activate + train this adapter
model.train_adapter("my_task_adapter")

# Tokenize input and run a forward pass
inputs = tokenizer("Adapters are efficient!", return_tensors="pt")
outputs = model(**inputs)

# Last hidden state (batch_size, seq_len, hidden_dim)
print(outputs.last_hidden_state.shape)

# Add a Classification Head (for downstream tasks)
model.add_classification_head("my_task_adapter", num_labels=2)

# Re-activate the adapter for training (this freezes the base model weights)
model.train_adapter("my_task_adapter")

# Forward pass for classification
inputs = tokenizer("Adapters are awesome!", return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits

# Training Only Adapter Parameters
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5)

# Save / Load Adapters Separately

# Save adapter after training
model.save_adapter("saved/my_task_adapter", "my_task_adapter")

# Load later into a fresh copy of the base model
new_model = BertAdapterModel.from_pretrained("bert-base-uncased")
new_model.load_adapter("saved/my_task_adapter", load_as="my_task_adapter")
new_model.set_active_adapters("my_task_adapter")

Parallel Adapters

While bottleneck adapters are inserted sequentially into the model’s architecture, parallel adapters run alongside a sublayer, typically the feed-forward network of each Transformer block. The adapter reads the same input as the feed-forward network, and its output is added to the feed-forward output.

Figure: Parallel adapter.

Let \(x\) be the input to the feed-forward sublayer. The original feed-forward output is \(\mathrm{FFN}(x)\), and the adapter path is:

\[\mathrm{Adapter}(x) = W_\text{up} \, \sigma(W_\text{down} \, x)\]

The final output becomes:

\[y = \mathrm{FFN}(x) + \mathrm{Adapter}(x)\]

This allows the adapter to independently learn task-specific modifications without disrupting the main path.

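Like the sequential design, the parallel design adds a small amount of extra computation. Because the adapter branch sits beside the frozen sublayer rather than in front of it, the pre-trained computation path is left intact, which can help preserve the pre-trained representations.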

import torch
import torch.nn as nn

class ParallelAdapter(nn.Module):
    def __init__(self, hidden_size, adapter_size, dropout_rate=0.1):
        super().__init__()
        self.down_project = nn.Linear(hidden_size, adapter_size)
        self.activation = nn.GELU()
        self.up_project = nn.Linear(adapter_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        
        # Initialize weights; a zero up-projection makes the branch start as a no-op
        nn.init.xavier_uniform_(self.down_project.weight)
        nn.init.zeros_(self.down_project.bias)
        nn.init.zeros_(self.up_project.weight)
        nn.init.zeros_(self.up_project.bias)
        
        # Scale factor: learnable here, but it can also be a fixed constant
        self.scaling_factor = nn.Parameter(torch.tensor(0.1))
    
    def forward(self, hidden_states):
        x = self.down_project(hidden_states)
        x = self.activation(x)
        x = self.up_project(x)
        x = self.dropout(x)
        
        # Return only the scaled adapter branch; the caller adds it to the
        # output of the frozen sublayer: y = FFN(x) + Adapter(x)
        return self.scaling_factor * x

The integration into a transformer layer would be similar to the bottleneck adapter, but the adapter would be applied in parallel rather than sequentially.
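As a minimal sketch of that integration, the wrapper below runs a frozen feed-forward block and a small adapter branch side by side on the same input and sums their outputs, i.e. \(y = \mathrm{FFN}(x) + \mathrm{Adapter}(x)\). The `ParallelFFN` name and the toy two-layer FFN are assumptions for illustration; in a real model you would wrap the existing FFN submodule instead. Zero-initializing the up-projection makes the wrapped block reproduce the frozen FFN exactly at the start of training.

```python
import torch
import torch.nn as nn


class ParallelFFN(nn.Module):
    """Hypothetical wrapper: y = FFN(x) + scale * W_up(GELU(W_down x))."""

    def __init__(self, ffn: nn.Module, hidden_size: int, adapter_size: int, scale: float = 1.0):
        super().__init__()
        self.ffn = ffn
        for p in self.ffn.parameters():
            p.requires_grad = False          # the pre-trained FFN stays frozen

        self.down = nn.Linear(hidden_size, adapter_size)
        self.act = nn.GELU()
        self.up = nn.Linear(adapter_size, hidden_size)
        nn.init.zeros_(self.up.weight)       # adapter branch starts as a no-op
        nn.init.zeros_(self.up.bias)
        self.scale = scale

    def forward(self, x):
        # Both paths read the same input x; their outputs are summed
        return self.ffn(x) + self.scale * self.up(self.act(self.down(x)))


torch.manual_seed(0)
hidden = 32
ffn = nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden))
block = ParallelFFN(ffn, hidden_size=hidden, adapter_size=8)

x = torch.randn(2, 5, hidden)
print(torch.allclose(block(x), ffn(x)))  # True at initialization
```

Only the 552 adapter parameters (down- and up-projection with biases) are trainable here; the FFN’s weights never receive gradient updates.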

Low-Rank Adaptation (LoRA)

LoRA (Low-Rank Adaptation), introduced by Hu et al. (2021), adds a trainable low-rank update to frozen weight matrices. Instead of fine-tuning a full matrix \(W \in \mathbb{R}^{d \times d}\), LoRA learns two smaller matrices:

\[W' = W + \frac{\alpha}{r} A B \quad \text{with} \quad A \in \mathbb{R}^{d \times r}, \; B \in \mathbb{R}^{r \times d}\]

where \(r \ll d\) (common choices are \(r = 4\), \(8\), or \(16\)) and \(\alpha\) is a constant scaling hyperparameter. This drastically reduces the number of trainable parameters. LoRA is usually applied to the attention projection layers (query/key/value/output).

Figure: LoRA.

Intuitively, LoRA adds a “low-rank path” through which task-specific information can flow, while keeping the rest of the model fixed.
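The equation and the “low-rank path” intuition describe the same computation. The quick check below is a sketch under stated assumptions: it uses the row-vector convention \(y = xW\) (so \(W\) is \(d \times d\)), random values for A and B (real LoRA initializes B to zero), and arbitrary sizes and α. It confirms numerically that the frozen path plus the low-rank path scaled by α/r, as the LoRALayer code below does, equals one multiplication by the merged matrix \(W + \frac{\alpha}{r}AB\), and that the update itself has rank r.

```python
import torch

torch.manual_seed(0)
d, r, alpha = 64, 4, 8
scaling = alpha / r

W = torch.randn(d, d)              # frozen pre-trained weight (y = x @ W)
A = torch.randn(d, r) * 0.01       # d x r
B = torch.randn(r, d) * 0.01       # r x d (random here; LoRA initializes B to zero)
x = torch.randn(3, d)

two_paths = x @ W + scaling * (x @ A) @ B   # frozen path + low-rank path
merged = x @ (W + scaling * A @ B)          # single merged matrix W'

print(torch.allclose(two_paths, merged, atol=1e-5))   # True
print(torch.linalg.matrix_rank(A @ B).item())         # 4, i.e. rank r
```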

import math
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
    def __init__(self, in_features, out_features, rank=8, alpha=32):
        """
        LoRA implementation for linear layers.
        
        Args:
            in_features: Input dimension
            out_features: Output dimension
            rank: Rank of the low-rank decomposition
            alpha: Scaling factor for the LoRA contribution
        """
        super().__init__()
        self.rank = rank
        self.scaling = alpha / rank
        
        # LoRA weights
        self.lora_A = nn.Parameter(torch.zeros(in_features, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        
        # Initialize weights
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)
    
    def forward(self, x):
        # LoRA contribution: scaling * (x @ A) @ B
        return self.scaling * (x @ self.lora_A) @ self.lora_B

Now, let’s apply LoRA to a pre-trained linear layer:

class LoRALinear(nn.Module):
    def __init__(self, linear_layer, rank=8, alpha=32):
        """
        Wraps a pre-trained linear layer with LoRA functionality.
        
        Args:
            linear_layer: The pre-trained nn.Linear module to adapt
            rank: Rank of the low-rank decomposition
            alpha: Scaling factor
        """
        super().__init__()
        self.linear = linear_layer
        
        # Freeze original weights
        self.linear.weight.requires_grad = False
        if self.linear.bias is not None:
            self.linear.bias.requires_grad = False
            
        # Add LoRA components
        self.lora = LoRALayer(
            linear_layer.in_features, 
            linear_layer.out_features,
            rank=rank,
            alpha=alpha
        )
    
    def forward(self, x):
        # Combine original output with LoRA contribution
        return self.linear(x) + self.lora(x)

The genius of LoRA is in its efficiency.

If the original weight matrix has dimensions n×m, full fine-tuning would require updating n×m parameters. With LoRA at rank r, we only need to train r×(n+m) parameters. For large matrices where r ≪ min(n, m), this represents a massive reduction in trainable parameters.
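As a worked example, take a single 4096 × 4096 projection (the hidden size of typical 7B-parameter decoder models; treat the exact size as an illustrative assumption) with rank r = 8. The helper name is hypothetical.

```python
def lora_savings(n: int, m: int, r: int) -> tuple[int, int, float]:
    full = n * m          # parameters updated by full fine-tuning
    lora = r * (n + m)    # parameters in A (n x r) and B (r x m)
    return full, lora, lora / full


full, lora, ratio = lora_savings(4096, 4096, 8)
print(full, lora, f"{ratio:.2%}")  # 16777216 65536 0.39%
```

So rank-8 LoRA trains under 0.4% of the parameters of that matrix, and the ratio shrinks further as the matrix grows.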

Applying LoRA to a Transformer

In practice, LoRA is typically applied to specific weight matrices within a transformer, most commonly the query and value projection matrices in attention layers. Here’s how to adapt a transformer’s attention mechanism with LoRA:

import math
import torch.nn as nn
from transformers import AutoModel

def apply_lora_to_model(model, rank=8, alpha=32, target_modules=("q_proj", "v_proj")):
    """
    Apply LoRA to specific modules in a transformer model.
    
    Args:
        model: A Hugging Face transformer model
        rank: Rank for LoRA decomposition
        alpha: Scaling factor
        target_modules: Names of the nn.Linear submodules to wrap. Names are
            architecture-specific: LLaMA-style models use "q_proj"/"v_proj",
            while BERT uses "query"/"value".
    """
    # First, freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    
    # Collect targets first: replacing modules while iterating over
    # named_modules() is fragile
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if name.split(".")[-1] in target_modules and isinstance(module, nn.Linear)
    ]
    
    # Then replace each target with its LoRA-wrapped version
    for name, module in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, child_name, LoRALinear(module, rank=rank, alpha=alpha))
    
    return model

model = AutoModel.from_pretrained("bert-base-uncased")
lora_model = apply_lora_to_model(model, target_modules=("query", "value"))

trainable = sum(p.numel() for p in lora_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in lora_model.parameters())
print(f"Trainable parameters: {trainable} / {total}")

Quantized LoRA (QLoRA)

QLoRA (Dettmers et al., 2023) takes LoRA’s efficiency to the next level by combining it with quantization. The key insight is to keep the frozen base model in a quantized format (4-bit NormalFloat, or NF4) while training LoRA adapters in higher precision (typically 16-bit), with gradients flowing through the quantized weights into the adapters.

QLoRA has been a game-changer for democratizing LLM fine-tuning: the original paper reports fine-tuning a 65-billion-parameter model on a single 48 GB GPU.
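The core QLoRA idea (frozen low-bit base weights, higher-precision trainable LoRA factors) can be sketched in plain PyTorch. This is a simplification: it swaps QLoRA’s NF4 data type, block-wise scales and double quantization for a plain per-row absmax 4-bit integer quantizer, and it stores each 4-bit value in an int8 instead of packing two per byte. `QuantLoRALinear` is a hypothetical name; for real training, libraries such as bitsandbytes and peft provide the actual quantized kernels.

```python
import torch
import torch.nn as nn


class QuantLoRALinear(nn.Module):
    """Hypothetical sketch: 4-bit frozen base weight + trainable LoRA path."""

    def __init__(self, linear: nn.Linear, rank: int = 8, alpha: int = 16):
        super().__init__()
        w = linear.weight.detach()                                   # (out, in)
        # Per-row absmax quantization to signed 4-bit integers in [-7, 7]
        scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 7
        self.register_buffer("w_q", torch.round(w / scale).to(torch.int8))
        self.register_buffer("w_scale", scale)
        bias = None if linear.bias is None else linear.bias.detach().clone()
        self.register_buffer("bias", bias)

        # LoRA factors stay in full precision and are the only trainable tensors
        self.lora_A = nn.Parameter(torch.randn(rank, linear.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(linear.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x):
        w = self.w_q.to(x.dtype) * self.w_scale          # dequantize on the fly
        base = nn.functional.linear(x, w, self.bias)
        return base + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)


torch.manual_seed(0)
base_layer = nn.Linear(32, 16)
qlayer = QuantLoRALinear(base_layer)
x = torch.randn(4, 32)

err = (qlayer(x) - base_layer(x)).abs().max().item()
trainable = [n for n, p in qlayer.named_parameters() if p.requires_grad]
print(trainable)                              # ['lora_A', 'lora_B']
print(f"max quantization error: {err:.4f}")
```

Because lora_B starts at zero, the only difference from the original layer at initialization is the quantization error; training then lets the full-precision adapters compensate for it.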

Prefix Tuning: Virtual Tokens in Hidden Space

Now let’s shift our focus to another family of PEFT methods that operate by introducing trainable tokens to the input sequence or hidden states: Prefix Tuning and Prompt Tuning.

Prefix Tuning, introduced by Li and Liang (2021), prepends a small number of learned key-value vectors (“prefixes”) to the attention mechanism.

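Figure: Prefix Tuning.

Instead of modifying weights, it extends the keys and values that each attention layer attends over:

\[\text{head} = \text{Attention}\big(x W_Q,\; [P_K;\, x W_K],\; [P_V;\, x W_V]\big)\]

where \(P_K, P_V \in \mathbb{R}^{l \times d}\) are learned key and value prefixes of length \(l\), and \([\cdot\,;\cdot]\) denotes concatenation along the sequence dimension.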

This means the model sees the learned prefix as a pseudo-context for every input, influencing the attention output without changing the underlying Transformer parameters.

Prefix tuning was originally proposed for generation tasks (table-to-text generation and summarization), where steering the attention context is often sufficient.
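The shapes are the key detail: prefixes lengthen the key and value sequences, but the number of queries, and therefore of output positions, stays the same. Below is a minimal single-layer sketch in which random tensors stand in for the frozen projections and for the learned prefixes (all sizes are arbitrary assumptions), using PyTorch’s built-in `scaled_dot_product_attention`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim, prefix_len = 2, 4, 10, 16, 5

# Queries, keys, values computed from the real input by the frozen model
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Learned prefixes for this layer, shared across the batch
prefix_k = torch.nn.Parameter(torch.randn(heads, prefix_len, head_dim) * 0.02)
prefix_v = torch.nn.Parameter(torch.randn(heads, prefix_len, head_dim) * 0.02)

# Prepend prefixes along the sequence axis: [P_K; K] and [P_V; V]
k_ext = torch.cat([prefix_k.expand(batch, -1, -1, -1), k], dim=2)
v_ext = torch.cat([prefix_v.expand(batch, -1, -1, -1), v], dim=2)

out = F.scaled_dot_product_attention(q, k_ext, v_ext)
print(k_ext.shape, out.shape)
# torch.Size([2, 4, 15, 16]) torch.Size([2, 4, 10, 16])
```

Gradients from any loss on `out` flow into `prefix_k` and `prefix_v`, which are the only tensors trained in prefix tuning.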

class PrefixTuningModule(nn.Module):
    def __init__(self, hidden_size, prefix_length=20, num_layers=12, num_heads=12, head_dim=64):
        """
        Implementation of Prefix Tuning.
        
        Args:
            hidden_size: Model's hidden size
            prefix_length: Number of virtual tokens to add
            num_layers: Number of transformer layers
            num_heads: Number of attention heads
            head_dim: Dimension of each attention head
        """
        super().__init__()
        self.prefix_length = prefix_length
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        
        # Create a prefix for each layer for both key and value states
        # Shape: [num_layers, 2, prefix_length, num_heads, head_dim]
        self.prefix_tokens = nn.Parameter(
            torch.zeros(num_layers, 2, prefix_length, num_heads, head_dim)
        )
        
        # Initialize with a small standard deviation
        nn.init.normal_(self.prefix_tokens, std=0.02)
    
    def forward(self, key_value_states, layer_idx):
        """
        Prepend prefix to key and value states for a specific layer.
        
        Args:
            key_value_states: Tuple of (key, value) states from the model
            layer_idx: Current transformer layer index
        """
        key_states, value_states = key_value_states
        batch_size = key_states.shape[0]
        
        # Get the prefix for the current layer
        # Shape: [2, prefix_length, num_heads, head_dim]
        prefix = self.prefix_tokens[layer_idx]
        
        # Extract key and value prefixes
        key_prefix = prefix[0].expand(batch_size, -1, -1, -1)
        value_prefix = prefix[1].expand(batch_size, -1, -1, -1)
        
        # Reshape to match model's key and value shapes
        # From: [batch_size, prefix_length, num_heads, head_dim]
        # To: [batch_size, num_heads, prefix_length, head_dim]
        key_prefix = key_prefix.permute(0, 2, 1, 3)
        value_prefix = value_prefix.permute(0, 2, 1, 3)
        
        # Concatenate with original states
        # Original shape: [batch_size, num_heads, seq_length, head_dim]
        new_key_states = torch.cat([key_prefix, key_states], dim=2)
        new_value_states = torch.cat([value_prefix, value_states], dim=2)
        
        return (new_key_states, new_value_states)

To integrate this with a transformer model, we need to modify each attention layer to incorporate the prefixes:

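import math
import torch
import torch.nn as nn

class PrefixTransformerLayer(nn.Module):
    def __init__(self, transformer_layer, prefix_module, layer_idx):
        super().__init__()
        self.layer = transformer_layer
        self.prefix_module = prefix_module
        self.layer_idx = layer_idx
        
        # Freeze the original layer
        for param in self.layer.parameters():
            param.requires_grad = False
    
    def forward(self, hidden_states, attention_mask=None, *args, **kwargs):
        # Extract the attention module (attribute names follow Hugging Face's BertLayer)
        attention = self.layer.attention.self
        
        # Prepare key, query, value states as in the original attention
        query_states = attention.query(hidden_states)
        key_states = attention.key(hidden_states)
        value_states = attention.value(hidden_states)
        
        # Reshape for multi-head attention
        batch_size, seq_length = hidden_states.shape[:2]
        head_dim = attention.attention_head_size
        num_heads = attention.num_attention_heads
        
        query_states = query_states.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
        
        # Apply prefix
        key_states, value_states = self.prefix_module((key_states, value_states), self.layer_idx)
        
        # Extend the attention mask for the prefix tokens. BertModel hands its layers an
        # additive mask (0 = attend, large negative = masked), so prefix positions get
        # zeros, i.e. they are always visible
        if attention_mask is not None:
            prefix_attention_mask = attention_mask.new_zeros(
                *attention_mask.shape[:-1], self.prefix_module.prefix_length
            )
            extended_attention_mask = torch.cat([prefix_attention_mask, attention_mask], dim=-1)
        else:
            extended_attention_mask = None
        
        # Scaled dot-product attention over [prefix; sequence] keys and values
        scores = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(head_dim)
        if extended_attention_mask is not None:
            scores = scores + extended_attention_mask
        probs = attention.dropout(scores.softmax(dim=-1))
        context = torch.matmul(probs, value_states)  # [batch, heads, seq_length, head_dim]
        context = context.transpose(1, 2).reshape(batch_size, seq_length, num_heads * head_dim)
        
        # Rest of the frozen layer: attention output projection, then the FFN
        attention_output = self.layer.attention.output(context, hidden_states)
        intermediate_output = self.layer.intermediate(attention_output)
        output = self.layer.output(intermediate_output, attention_output)
        
        # Return a tuple, as the encoder expects from a BertLayer
        return (output,)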

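The above implementation follows the module layout of Hugging Face’s BERT (attributes such as attention.self.query and attention_head_size) and ignores extras like head masks and key/value caching. Other architectures, and other library versions, will need adjustments.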

Prompt Tuning

Prompt Tuning (Lester et al., 2021) can be seen as a simplified version of Prefix Tuning. Rather than adding virtual tokens at every layer, Prompt Tuning only prepends trainable embeddings to the input sequence embeddings at the first layer.
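The difference shows up directly in the parameter count. Assuming BERT-base-like sizes (12 layers, hidden size 768) and 20 virtual tokens, matching the defaults of the PrefixTuningModule above and the PromptTuning class below, a quick comparison (helper names are illustrative):

```python
def soft_prompt_params(prompt_length: int, embed_dim: int) -> int:
    # One trainable vector per virtual token, at the input layer only
    return prompt_length * embed_dim


def prefix_params(prefix_length: int, num_layers: int, hidden_size: int) -> int:
    # A key prefix and a value prefix in every layer (num_heads * head_dim = hidden_size)
    return num_layers * 2 * prefix_length * hidden_size


print(soft_prompt_params(20, 768))   # 15360
print(prefix_params(20, 12, 768))    # 368640
```

Prompt tuning is 24× smaller here (one input-level prompt versus a key and a value prefix in each of 12 layers), at the cost of influencing the model only through its input.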

Here’s a straightforward implementation:

import torch
import torch.nn as nn

class PromptTuning(nn.Module):
    def __init__(self, model, prompt_length=20):
        """
        Implementation of Prompt Tuning.
        
        Args:
            model: The pre-trained transformer model
            prompt_length: Number of virtual tokens to add
        """
        super().__init__()
        self.model = model
        self.prompt_length = prompt_length
        
        # Freeze model parameters
        for param in model.parameters():
            param.requires_grad = False
        
        # Get embedding dimension from the model
        embed_dim = model.get_input_embeddings().weight.shape[1]
        
        # Create soft prompt embeddings
        self.soft_prompts = nn.Parameter(torch.randn(prompt_length, embed_dim))
        
        # Initialize with embeddings of random tokens from the vocabulary
        with torch.no_grad():
            vocab_size = model.get_input_embeddings().weight.shape[0]
            random_indices = torch.randint(0, vocab_size, (prompt_length,))
            self.soft_prompts.data = model.get_input_embeddings().weight[random_indices].clone()
    
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Get input embeddings (from token ids, or passed in directly)
        if input_ids is not None:
            inputs_embeds = self.model.get_input_embeddings()(input_ids)
        else:
            inputs_embeds = kwargs.pop("inputs_embeds")
        batch_size = inputs_embeds.shape[0]
        
        # Expand soft prompts for batch size and prepend to input embeddings
        prompt_embeds = self.soft_prompts.unsqueeze(0).expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([prompt_embeds, inputs_embeds], dim=1)
        
        # Adjust attention mask for the added prompt tokens (always attended)
        if attention_mask is not None:
            prompt_mask = attention_mask.new_ones(batch_size, self.prompt_length)
            attention_mask = torch.cat([prompt_mask, attention_mask], dim=1)
        
        # Token-level labels (e.g. for language modeling) must line up with the
        # longer sequence; -100 is the index Hugging Face losses ignore
        if labels is not None:
            if labels.dim() == 2:
                prompt_labels = labels.new_full((batch_size, self.prompt_length), -100)
                labels = torch.cat([prompt_labels, labels], dim=1)
            kwargs["labels"] = labels
        
        # Forward pass through the model without input_ids
        return self.model(
            input_ids=None,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

BitFit Adapter

BitFit, proposed by Ben Zaken et al. (2021), takes a radically different approach from the methods we’ve discussed so far. Instead of adding new parameters, BitFit selectively trains only the bias terms in the pre-trained model, leaving all other parameters frozen.

Despite its extreme parameter efficiency, BitFit is competitive with full fine-tuning on small-to-medium-sized datasets, according to the original paper.

def apply_bitfit_to_model(model):
    """
    Apply BitFit to a transformer model by only training bias terms.
    
    Args:
        model: A PyTorch model
    """
    # First, freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
    
    # Then unfreeze only bias parameters
    for name, param in model.named_parameters():
        if "bias" in name:
            param.requires_grad = True
    
    return model

The implementation is remarkably simple, yet BitFit can achieve competitive performance while training less than 0.1% of the original model parameters in many cases.
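To see which tensors the `"bias" in name` check actually selects, here is a small self-contained run on a PyTorch `nn.TransformerEncoder`, used as a stand-in for a pre-trained model (an assumption for illustration). The function body mirrors `apply_bitfit_to_model` above in a single loop. The substring match also picks up LayerNorm biases and the packed `in_proj_bias` of `nn.MultiheadAttention`.

```python
import torch.nn as nn


def apply_bitfit_to_model(model: nn.Module) -> nn.Module:
    # Freeze everything, except parameters whose name contains "bias"
    for name, param in model.named_parameters():
        param.requires_grad = "bias" in name
    return model


layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256, batch_first=True)
model = apply_bitfit_to_model(nn.TransformerEncoder(layer, num_layers=2))

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(trainable[:4])
print(f"{n_train} / {n_total} = {n_train / n_total:.2%}")
```

On such a tiny model the bias share is about 1.4%; it drops toward the ~0.1% range for large models because weight matrices grow quadratically with width while biases grow linearly. In practice, a newly added task head is usually left trainable as well.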

IA³: Infused Adapter by Inhibiting and Amplifying Inner Activations

IA³ (Infused Adapter by Inhibiting and Amplifying Inner Activations), introduced by Liu et al. (2022), rescales activations element-wise with learned vectors. For an activation \(x\), IA³ computes:

\[x' = l \odot x\]

Here, \(l\) is a trainable vector initialized to ones, so the adapted model starts out identical to the pre-trained one. In the original paper, such vectors rescale the attention keys and values and the intermediate activations of the feed-forward network. This is similar to fine-tuning just the scale of selected activations and can be extremely efficient.

IA³ is useful when slight shifts in activation distributions are enough to steer the model to the new task.

Figure: IA³.

Let’s see what this looks like in code.

import torch
import torch.nn as nn

class IA3Module(nn.Module):
    def __init__(self, hidden_size, ia3_type="feed_forward"):
        """
        Implementation of IA³ scaling vectors.
        
        Args:
            hidden_size: Size of the last dimension to scale
            ia3_type: Where to apply IA³ ('feed_forward', 'attention_output', 'attention_value')
        """
        super().__init__()
        valid_types = ("feed_forward", "attention_output", "attention_value")
        if ia3_type not in valid_types:
            raise ValueError(f"ia3_type must be one of {valid_types}, got {ia3_type!r}")
        self.ia3_type = ia3_type
        
        # One learned vector per insertion point, initialized to ones so the
        # module starts out as the identity and the frozen model is unchanged
        self.scaling_vector = nn.Parameter(torch.ones(hidden_size))
    
    def forward(self, x):
        """
        Scale the last dimension of x element-wise.
        Broadcasting covers any leading dimensions, e.g. [batch, seq, hidden].
        """
        return x * self.scaling_vector

Integrating IA³ with a transformer model requires injecting the scaling at specific points. The layer below assumes a BERT-style layout as in Hugging Face Transformers (attention.self.value, attention, intermediate and output submodules); other architectures name these modules differently:

class IA3TransformerLayer(nn.Module):
    def __init__(self, transformer_layer, hidden_size):
        super().__init__()
        self.layer = transformer_layer
        
        # Freeze original parameters (before adding IA³, so the new vectors stay trainable)
        for param in self.layer.parameters():
            param.requires_grad = False
        
        # Add IA³ modules
        self.attention_value_ia3 = IA3Module(hidden_size, ia3_type="attention_value")
        self.attention_output_ia3 = IA3Module(hidden_size, ia3_type="attention_output")
        self.feed_forward_ia3 = IA3Module(hidden_size, ia3_type="feed_forward")
        
        # Rescale the values *inside* attention by chaining IA³ after the value
        # projection. Computing the value separately and then calling the attention
        # module would discard the scaled tensor, because attention recomputes it.
        attention = self.layer.attention.self  # model-specific
        attention.value = nn.Sequential(attention.value, self.attention_value_ia3)
    
    def forward(self, hidden_states, attention_mask=None):
        # Attention block (values are already rescaled by the wrapped projection)
        attention_output = self.layer.attention(hidden_states, attention_mask)[0]
        
        # Apply IA³ to attention output
        attention_output = self.attention_output_ia3(attention_output)
        
        # Feed-forward network
        intermediate_output = self.layer.intermediate(attention_output)
        layer_output = self.layer.output(intermediate_output, attention_output)
        
        # Apply IA³ to feed-forward output
        return self.feed_forward_ia3(layer_output)
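If rewriting each layer's forward method is impractical, one alternative way to inject the scaling at specific points is a PyTorch forward hook, which can replace a submodule's output without touching the model's code. The sketch below is illustrative only: a toy nn.Sequential stands in for a frozen feed-forward block, and IA3Scale and attach_ia3 are hypothetical helpers. Following the original IA³ paper, it rescales the feed-forward activations after the nonlinearity.

```python
import torch
import torch.nn as nn


class IA3Scale(nn.Module):
    """One learned IA³ vector, initialized to ones (identity at the start)."""
    def __init__(self, size):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))


def attach_ia3(model: nn.Module, targets: dict) -> nn.ModuleDict:
    """Hypothetical helper: freeze `model` and rescale named submodule outputs via hooks.

    `targets` maps a submodule name (as in model.named_modules()) to the size of
    its output's last dimension. Returns the new, trainable IA³ vectors.
    """
    for p in model.parameters():
        p.requires_grad = False
    modules = dict(model.named_modules())
    ia3 = nn.ModuleDict()
    for name, size in targets.items():
        scaler = IA3Scale(size)
        ia3[name.replace(".", "_")] = scaler
        # A forward hook that returns a value replaces the submodule's output
        modules[name].register_forward_hook(lambda mod, inp, out, s=scaler: out * s.scale)
    return ia3


# Toy stand-in for a frozen feed-forward block: Linear -> GELU -> Linear
model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
# Rescale the activations after the nonlinearity (submodule "1")
ia3 = attach_ia3(model, {"1": 64})

x = torch.randn(2, 5, 16)
model(x).sum().backward()

print(sum(p.numel() for p in ia3.parameters()))                       # 64
print(sum(p.numel() for p in model.parameters() if p.requires_grad))  # 0
print(ia3["1"].scale.grad.shape, model[0].weight.grad)                # torch.Size([64]) None
```

Only the IA³ vector receives a gradient; the frozen weights stay untouched, which is what an optimizer built from ia3.parameters() would rely on.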

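IA³’s efficiency is remarkable: for a model with hidden size \(h\), the layer above adds only \(3h\) parameters, whereas a standard transformer layer with the usual \(4h\) feed-forward width holds roughly \(12h^2\) frozen weights, which amounts to millions of parameters for typical hidden sizes.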
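The arithmetic can be checked on PyTorch's built-in encoder layer. This is a sketch with BERT-base-like sizes chosen for illustration (h = 768, 12 heads, feed-forward width 3072); it simply counts the frozen weights against three IA³ vectors of size h.

```python
import torch
import torch.nn as nn

h, d_ff = 768, 3072   # BERT-base-like sizes, chosen for illustration
layer = nn.TransformerEncoderLayer(d_model=h, nhead=12, dim_feedforward=d_ff, batch_first=True)
for p in layer.parameters():
    p.requires_grad = False

# Three IA³ vectors per layer, matching the placement in the layer above
# (values, attention output, feed-forward output)
ia3_vectors = nn.ParameterList([nn.Parameter(torch.ones(h)) for _ in range(3)])

frozen = sum(p.numel() for p in layer.parameters())
added = sum(p.numel() for p in ia3_vectors)
print(frozen, added, f"{100 * added / frozen:.3f}%")  # 7087872 2304 0.033%
```

The frozen count is exactly \(12h^2 + 13h\) (weights plus biases and LayerNorm parameters), so the IA³ vectors add about 0.03% per layer.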

Compacter: Kronecker Products for Ultimate Efficiency

Compacter, proposed by Mahabadi et al. (2021), builds on the adapter idea, but instead of learning full matrices for down/up projection, it composes them from Kronecker products of smaller matrices.

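\[W = \sum_{i=1}^{n} A_i \otimes B_i\]

For an adapter projection from dimension \(d\) to \(k\), each \(A_i\) is a small \(n \times n\) matrix and each \(B_i\) has shape \(\frac{d}{n} \times \frac{k}{n}\), so \(W\) has the full \(d \times k\) shape while storing about \(n^3 + dk/n\) parameters instead of \(dk\). Compacter additionally shares the \(A_i\) across layers and factorizes each \(B_i\) as a low-rank product. Because \(\operatorname{rank}(A \otimes B) = \operatorname{rank}(A)\,\operatorname{rank}(B)\), a Kronecker-structured matrix can be full-rank while remaining cheap to store, so Compacter adapters can express transformations that a low-rank matrix with a similar budget cannot, without adding much overhead.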
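The construction can be checked directly with torch.kron. In the PHM formulation used by Compacter, the weight is a sum of n Kronecker products of an n × n matrix with a d/n × k/n matrix. The sizes below (d = 768, k = 64, n = 4) are illustrative assumptions, not values from the paper.

```python
import torch

torch.manual_seed(0)
d, k, n = 768, 64, 4                    # adapter down-projection d -> k, n Kronecker terms

A = torch.randn(n, n, n)                # n small matrices A_i of shape (n, n)
B = torch.randn(n, d // n, k // n)      # n matrices B_i of shape (d/n, k/n)

# W = sum_i A_i ⊗ B_i has shape (n * d/n, n * k/n) = (d, k)
W = sum(torch.kron(A[i], B[i]) for i in range(n))

print(W.shape)                                   # torch.Size([768, 64])
print(A.numel() + B.numel(), "vs dense", d * k)  # 12352 vs dense 49152
```

Entry \((a \cdot \tfrac{d}{n} + c,\; b \cdot \tfrac{k}{n} + e)\) of \(A_i \otimes B_i\) equals \(A_i[a, b]\,B_i[c, e]\), so each small \(A_i\) entry scales a whole block of \(B_i\).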

Compacter Adapter

Let’s implement Compacter:

import torch
import torch.nn as nn

class PHM(nn.Module):
    """
    Parameterized Hypercomplex Multiplication (PHM) layer.
    
    The weight is a sum of Kronecker products, W = sum_i A_i ⊗ B_i, where each
    A_i is (n x n), each B_i is (in_features/n x out_features/n), and n = rank.
    With factorized_phm=True, each B_i is further factorized as a low-rank
    product B_i = U_i @ V_i, as in Compacter.
    """
    def __init__(self, in_features, out_features, rank=4, factorized_phm=True, phm_rank=1):
        super().__init__()
        assert in_features % rank == 0 and out_features % rank == 0, \
            "in_features and out_features must be divisible by rank"
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.factorized_phm = factorized_phm
        
        # Block sizes of the B_i factors
        in_block, out_block = in_features // rank, out_features // rank
        
        # A_i: `rank` matrices of shape (rank x rank). Compacter shares these
        # across layers; this sketch keeps them local for simplicity.
        self.A = nn.Parameter(torch.empty(rank, rank, rank))
        
        if factorized_phm:
            # Low-rank factors: B_i = U_i @ V_i with inner dimension phm_rank
            self.U = nn.Parameter(torch.empty(rank, in_block, phm_rank))
            self.V = nn.Parameter(torch.empty(rank, phm_rank, out_block))
        else:
            # Full (unfactorized) B_i factors
            self.B = nn.Parameter(torch.empty(rank, in_block, out_block))
        
        self.bias = nn.Parameter(torch.zeros(out_features))
        
        # Initialize parameters
        self._init_parameters()
    
    def _init_parameters(self):
        """Initialize the parameters with small random values."""
        nn.init.normal_(self.A, mean=0.0, std=0.02)
        if self.factorized_phm:
            nn.init.normal_(self.U, mean=0.0, std=0.02)
            nn.init.normal_(self.V, mean=0.0, std=0.02)
        else:
            nn.init.normal_(self.B, mean=0.0, std=0.02)
    
    def compute_weight(self):
        """
        Assemble W = sum_i A_i ⊗ B_i with shape (in_features, out_features).
        The weight is shared by all inputs, so it is built once per forward pass.
        """
        B = torch.bmm(self.U, self.V) if self.factorized_phm else self.B
        return sum(torch.kron(self.A[i], B[i]) for i in range(self.rank))
    
    def forward(self, x):
        """
        Forward pass using PHM.
        x: Input tensor of shape [..., in_features], e.g. [batch, seq, in_features]
        Returns a tensor of shape [..., out_features].
        """
        return x @ self.compute_weight() + self.bias

class CompacterLayer(nn.Module):
    """
    Compacter adapter implementation using PHM for weight parameterization.
    """
    def __init__(self, hidden_size, adapter_size=64, rank=4, factorized_phm=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.adapter_size = adapter_size
        
        # Down projection using PHM
        self.down_proj = PHM(hidden_size, adapter_size, rank=rank, factorized_phm=factorized_phm)
        
        # Activation function
        self.activation = nn.GELU()
        
        # Up projection using PHM
        self.up_proj = PHM(adapter_size, hidden_size, rank=rank, factorized_phm=factorized_phm)
        
        # Additional components
        self.dropout = nn.Dropout(0.1)
        self.layer_norm = nn.LayerNorm(hidden_size)
        
        # Scaling factor for the adapter output
        self.scaling_factor = nn.Parameter(torch.tensor(0.1))
    
    def forward(self, hidden_states):
        """
        Forward pass through the Compacter adapter.
        """
        residual = hidden_states
        
        # Down projection with PHM
        x = self.down_proj(hidden_states)
        x = self.activation(x)
        
        # Up projection with PHM
        x = self.up_proj(x)
        x = self.dropout(x)
        
        # Apply scaling and add residual
        output = residual + self.scaling_factor * x
        output = self.layer_norm(output)
        
        return output

# Integrating Compacter with a transformer layer would be similar to the adapter implementation
class CompacterTransformerLayer(nn.Module):
    def __init__(self, transformer_layer, adapter_size=64, rank=4):
        super().__init__()
        self.layer = transformer_layer
        hidden_size = transformer_layer.attention.self.all_head_size  # Model specific
        
        # Freeze original parameters
        for param in self.layer.parameters():
            param.requires_grad = False
            
        # Add Compacter adapters
        self.attention_adapter = CompacterLayer(hidden_size, adapter_size, rank)
        self.ffn_adapter = CompacterLayer(hidden_size, adapter_size, rank)
    
    def forward(self, hidden_states, attention_mask=None):
        # Original attention mechanism
        attention_output = self.layer.attention(hidden_states, attention_mask)[0]
        
        # Apply attention adapter
        adapted_attention = self.attention_adapter(attention_output)
        
        # Original feed-forward network
        intermediate_output = self.layer.intermediate(adapted_attention)
        layer_output = self.layer.output(intermediate_output, adapted_attention)
        
        # Apply ffn adapter
        output = self.ffn_adapter(layer_output)
        
        return output
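To make the earlier point about expressiveness concrete, the sketch below compares a single Kronecker product with a LoRA-style low-rank product U @ V of the same 768 × 64 shape and a slightly larger parameter budget. The sizes and the rank r = 4 are illustrative assumptions; float64 is used only so the numerical rank computation is reliable.

```python
import torch

torch.manual_seed(0)
d, k = 768, 64
opts = dict(dtype=torch.float64)  # float64 keeps the numerical rank check reliable

# Kronecker-structured d x k matrix from a (4 x 4) and a (192 x 16) factor
A = torch.randn(4, 4, **opts)
B = torch.randn(d // 4, k // 4, **opts)
W_kron = torch.kron(A, B)

# Low-rank d x k matrix U @ V with inner dimension r
r = 4
U = torch.randn(d, r, **opts)
V = torch.randn(r, k, **opts)
W_low = U @ V

for name, W, n_params in [("kron", W_kron, A.numel() + B.numel()),
                          ("low-rank", W_low, U.numel() + V.numel())]:
    print(name, tuple(W.shape), n_params, torch.linalg.matrix_rank(W).item())
```

With 3,088 parameters the Kronecker matrix comes out full rank (64 = 4 · 16), while U @ V uses 3,328 parameters and has rank 4, since rank(A ⊗ B) = rank(A) · rank(B) and random Gaussian factors are full rank with probability 1.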


