# Source code for FairLangProc.algorithms.inprocessors.adapter

"""Submodule of the FairLangProc.algorithms.inprocessors module that stores all
processors related to the addition of adapters.

The supported method is ADELE.
"""

# Standard imports
from typing import Union

# Pytorch
import torch
import torch.nn as nn

# Adapters
import adapters


class DebiasAdapter(nn.Module):
    """Implements ADELE debiasing based on a bottleneck adapter.

    The wrapped pretrained model is extended with the ``adapters`` library:
    an adapter is added under ``adapter_name``, activated, and set up for
    training via ``train_adapter`` (in the ``adapters`` library this freezes
    the base model weights so that only the adapter is updated).

    Example
    -------
    >>> from torch.optim import AdamW
    >>> from transformers import AutoModel
    >>> from adapters import AdapterTrainer
    >>> from FairLangProc.algorithms.inprocessors import DebiasAdapter
    >>>
    >>> debias_adapter = DebiasAdapter(
    ...     model=AutoModel.from_pretrained('bert-base-uncased'),
    ...     adapter_config="seq_bn",
    ... )
    >>> adele_model = debias_adapter.get_model()
    >>>
    >>> trainer = AdapterTrainer(
    ...     model=adele_model,
    ...     args=training_args,
    ...     train_dataset=train_CDA,
    ...     eval_dataset=val_dataset,
    ...     optimizers=(
    ...         AdamW(adele_model.parameters(), lr=1e-5, weight_decay=0.1),
    ...         None,
    ...     ),
    ... )
    >>> trainer.train()
    >>> results = trainer.evaluate()
    >>> print(results)
    """

    def __init__(
        self,
        model: nn.Module,
        adapter_name: str = "debias_adapter",
        adapter_config: Union[str, dict] = "seq_bn",
    ) -> None:
        r"""Constructor of the DebiasAdapter class.

        Parameters
        ----------
        model : nn.Module
            Pretrained model (e.g., BERT, GPT-2). It is modified in place:
            ``adapters.init`` is applied to it and the adapter is added to it.
        adapter_name : str
            Name under which the adapter is registered, activated, trained,
            saved and loaded.
        adapter_config : Union[str, dict]
            Adapter configuration (bottleneck adapter ``"seq_bn"`` by default).
            A string identifier or a configuration dict is resolved with
            ``adapters.AdapterConfig.load``; any other object (e.g. an
            already-built config instance) is passed through unchanged.

        Raises
        ------
        ValueError
            If the model exposes no ``add_adapter`` method after
            ``adapters.init`` has been applied.
        """
        super().__init__()
        self.adapter_name = adapter_name

        # Attach the adapter API (add_adapter, train_adapter, ...) to the model.
        adapters.init(model)
        self.model = model

        # Verify support
        if not hasattr(self.model, "add_adapter"):
            raise ValueError("Model does not support adapters.")

        if isinstance(adapter_config, (str, dict)):
            # AdapterConfig is an abstract base class and cannot be
            # instantiated directly (the previous `AdapterConfig(**dict)`
            # always failed); `load` resolves both identifiers and dicts to
            # the matching concrete config class.
            config = adapters.AdapterConfig.load(adapter_config)
        else:
            config = adapter_config

        # Add adapter, make it the active one and set it up for training.
        self.model.add_adapter(adapter_name, config=config)
        self.model.set_active_adapters(adapter_name)
        self.model.train_adapter(adapter_name)

    def forward(self, *args, **kwargs):
        """Run the wrapped model, forwarding all positional and keyword arguments.

        Returns
        -------
        Whatever the wrapped model returns for these arguments.
        """
        return self.model(*args, **kwargs)

    def get_model(self):
        """Return the wrapped model (with the adapter attached)."""
        return self.model

    def save_adapter(self, save_path: str):  # pragma: no cover
        """Save the adapter registered as ``self.adapter_name`` to ``save_path``."""
        self.model.save_adapter(save_path, self.adapter_name)

    def load_adapter(self, path: str):  # pragma: no cover
        """Load an adapter from ``path`` under ``self.adapter_name`` and activate it."""
        self.model.load_adapter(path, load_as=self.adapter_name)
        self.model.set_active_adapters(self.adapter_name)