# Source code for FairLangProc.algorithms.intraprocessors.modular

"""Submodule of FairLangProc.algorithms.intraprocessors that contains all
processors related to the addition of modular subnetworks.

The supported method is Difference Pruning.
"""

# Standard library
import os
import math
import warnings
from enum import Enum, auto
from tqdm import tqdm
from functools import reduce
import contextlib
from typing import Union, List, Tuple, Optional, Dict, Callable
from collections import OrderedDict
from abc import ABC, abstractmethod

# Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.utils.data import DataLoader
import torch.nn.utils.parametrize as parametrize

# Transformers
from transformers import AutoModel

# Custom imports
from FairLangProc.algorithms.output import CustomOutput



# SOURCE:
# https://github.com/CPJKU/ModularizedDebiasing/tree/main




#=====================================================================
#                           UTILS
#=====================================================================


def dict_to_device(d: dict, device: Union[str, torch.device]) -> dict:
    """Build a new dict in which every value has been moved to ``device``.

    Parameters
    ----------
    d : dict
        Mapping whose values expose a ``.to(device)`` method.
    device : Union[str, torch.device]
        Target device handed to each value's ``.to``.

    Returns
    -------
    dict
        New mapping with the same keys (in the same order) and moved values.
    """
    moved = {}
    for key, value in d.items():
        moved[key] = value.to(device)
    return moved


def get_param_from_name(
    model: torch.nn.Module,
    param_name: str
):
    """Resolve a dotted attribute path starting at ``model``.

    Parameters
    ----------
    model : torch.nn.Module
        Object the lookup starts from.
    param_name : str
        Dot-separated attribute path, e.g. ``"0.weight"``.

    Returns
    -------
    The object reached after following every component of ``param_name``
    with ``getattr``.
    """
    target = model
    for attribute in param_name.split("."):
        target = getattr(target, attribute)
    return target


def concrete_stretched(
    alpha: torch.Tensor,
    l: Union[float, int] = -1.5,
    r: Union[float, int] = 1.5,
    deterministic: bool = False
) -> torch.Tensor:
    """Compute a gate in [0, 1] from the logits ``alpha``.

    The logits (optionally perturbed by logistic noise) go through a sigmoid,
    the result is stretched linearly to the interval ``[l, r]`` and finally
    clipped to ``[0, 1]``.

    Parameters
    ----------
    alpha : torch.Tensor
        Logits of the gate.
    l : Union[float, int]
        Lower end of the stretched interval.
    r : Union[float, int]
        Upper end of the stretched interval.
    deterministic : bool
        If True no noise is added; otherwise uniform noise clamped to
        ``[0.001, 0.999]`` is drawn and converted to logistic noise.

    Returns
    -------
    torch.Tensor
        Tensor with the shape of ``alpha`` and values in ``[0, 1]``.
    """
    if deterministic:
        logistic_noise = 0.
    else:
        # Clamp away from 0 and 1 so that both logarithms stay finite.
        uniform = torch.zeros_like(alpha).uniform_().clamp_(0.001, 0.999)
        logistic_noise = torch.log(uniform) - torch.log(1 - uniform)

    gate = torch.sigmoid(logistic_noise + alpha)
    stretched = gate * (r - l) + l
    return torch.clamp(stretched, min=0, max=1)



#=====================================================================
#                       PARAMETRIZATION CLASSES
#=====================================================================



class DiffWeightFinetune(nn.Module): # pragma: no cover
    """Parametrization holding a fine-tuned copy of a weight plus mask logits.

    Three parameters may be registered: ``finetune`` (a clone of the wrapped
    weight), ``alpha`` (one logit per weight entry) and, when ``structured``
    is True, ``alpha_group`` (a single shared logit).
    """

    def __init__(
        self,
        weight: nn.Parameter,
        alpha_init: float,
        concrete_lower: float,
        concrete_upper: float,
        structured: bool
    ):
        """Create the parametrization for ``weight``.

        Parameters
        ----------
        weight : nn.Parameter
            Weight being parametrized; it is cloned into ``finetune``.
        alpha_init : float
            Initial value of every mask logit.
        concrete_lower : float
            Lower stretch bound passed to ``concrete_stretched``.
        concrete_upper : float
            Upper stretch bound passed to ``concrete_stretched``.
        structured : bool
            If True, an additional single-element ``alpha_group`` logit is created.
        """
        super().__init__()
        self.concrete_lower = concrete_lower
        self.concrete_upper = concrete_upper
        self.structured = structured

        # Assigning a Parameter to an attribute registers it on the module.
        self.finetune = Parameter(weight.clone())
        self.alpha = Parameter(torch.full_like(weight, alpha_init))

        if structured:
            self.alpha_group = Parameter(
                torch.full((1,), float(alpha_init), device=weight.device)
            )

        self.active = True

    def forward(self, X):
        """Return ``X`` unchanged when inactive, otherwise the parametrized weight."""
        if not self.active:
            return X
        # The detached difference makes the first term equal X in value while
        # still depending on ``finetune`` in the autograd graph.
        detached_gap = (self.finetune - X).detach()
        base_term = self.finetune - detached_gap
        return base_term + self.diff_weight(X)

    def diff_weight(self, X):
        """Masked difference ``z * (finetune - X)``."""
        gap = self.finetune - X
        return self.z * gap

    @property
    def z(self) -> Parameter:
        """Current mask (a plain tensor, despite the annotation) built from the logits."""
        gate = self.dist(self.alpha)
        if self.structured:
            gate.mul_(self.dist(self.alpha_group))
        return gate

    @property
    def alpha_weights(self) -> list:
        """List with ``alpha`` and, if structured, ``alpha_group``."""
        if self.structured:
            return [self.alpha, self.alpha_group]
        return [self.alpha]

    def dist(self, alpha) -> torch.Tensor:
        """Evaluate ``concrete_stretched`` on ``alpha``; deterministic outside training mode."""
        return concrete_stretched(
            alpha,
            l=self.concrete_lower,
            r=self.concrete_upper,
            deterministic=(not self.training)
        )

    def set_frozen(self, frozen: bool) -> None:
        """Toggle ``requires_grad`` of all parameters and switch train/eval mode accordingly."""
        trainable = not frozen
        for param in (self.finetune, *self.alpha_weights):
            param.requires_grad = trainable
        # train(False) is what eval() does; train(True) is train().
        self.train(trainable)



class DiffWeightFixmask(nn.Module):
    """Parametrization that adds a masked, trainable difference to a weight.

    ``diff_weight`` is stored already multiplied by ``mask``; ``mask`` itself
    is registered as a parameter that never requires gradients.
    """

    def __init__(self, diff_weight: torch.Tensor, mask: torch.Tensor):
        """Store ``diff_weight * mask`` and ``mask`` as parameters and mark the module active."""
        super().__init__()
        # Attribute assignment registers the Parameters in this order.
        self.diff_weight = Parameter(diff_weight * mask)
        self.mask = Parameter(mask, requires_grad=False)
        self.active = True

    def forward(self, X):
        """Return ``X`` when inactive, else ``X`` plus the masked difference."""
        if not self.active:
            return X
        masked_diff = self.mask * self.diff_weight
        return X + masked_diff

    def set_frozen(self, frozen: bool) -> None:
        """Enable or disable gradients for ``diff_weight`` (the mask is left untouched)."""
        trainable = not frozen
        self.diff_weight.requires_grad = trainable


#=====================================================================
#                           MODEL CLASS
#=====================================================================

class ModelState(Enum):
    """Parametrization state of a pruning model."""

    # Values are spelled out; they equal what ``auto()`` assigns (1, 2, 3).
    INIT = 1        # no diff parametrizations attached
    FINETUNING = 2  # DiffWeightFinetune parametrizations attached
    FIXMASK = 3     # DiffWeightFixmask parametrizations attached


class BaseModel(nn.Module, ABC):
    """Abstract wrapper around an encoder module.

    Subclasses must implement ``_forward``. The properties below read
    information from the wrapped encoder and its ``config`` object.
    """

    @property
    def encoder_module(self) -> torch.nn.Module:
        """The encoder itself, unwrapped if it is inside ``torch.nn.DataParallel``."""
        wrapped = isinstance(self.encoder, torch.nn.DataParallel)
        return self.encoder.module if wrapped else self.encoder


    @property
    def device(self) -> torch.device:
        """Device of the first parameter of the encoder."""
        first_param = next(self.encoder.parameters())
        return first_param.device


    @property
    def model_type(self) -> str:
        """``model_type`` entry of the encoder config."""
        cfg = self.encoder_module.config
        return cfg.model_type


    @property
    def model_name(self) -> str:
        """``_name_or_path`` entry of the encoder config."""
        cfg = self.encoder_module.config
        return cfg._name_or_path


    @property
    def hidden_size(self) -> int:
        """Embedding dimension of the encoder's word embeddings."""
        word_embeddings = self.encoder_module.embeddings.word_embeddings
        return word_embeddings.embedding_dim


    @property
    def total_layers(self) -> int:
        """Layer count read from the config (first matching key) plus one.

        Raises
        ------
        Exception
            If the config defines neither ``num_hidden_layers`` nor ``n_layer``.
        """
        cfg = self.encoder_module.config
        for key in ("num_hidden_layers", "n_layer"):
            if key in cfg.__dict__:
                # One extra slot; index 0 is used for the embeddings below.
                return getattr(cfg, key) + 1
        raise Exception("number of layers of pre trained model could not be determined")


    def __init__(self, encoder: nn.Module = None, encoder_state_dict: OrderedDict = None):
        """Store the encoder and optionally load a state dict into it.

        Parameters
        ----------
        encoder : nn.Module
            Module to wrap.
        encoder_state_dict : OrderedDict
            If given, loaded into ``encoder``; ``state_dict_init`` records
            whether this happened.
        """
        super().__init__()
        self.encoder = encoder

        loaded = encoder_state_dict is not None
        if loaded:
            self.encoder.load_state_dict(encoder_state_dict)
        self.state_dict_init = loaded

    @abstractmethod
    def _forward(self, **x) -> torch.Tensor:
        """Compute the model output for the given inputs (to be implemented by subclasses)."""
        pass


    def get_layer_idx_from_module(self, module_name: str) -> int:
        """Map a module name to a layer index.

        Embedding modules map to 0, hidden layer ``i`` maps to ``i + 1`` and
        ``pooler.dense`` maps to ``total_layers - 1``. For any other name a
        warning is emitted and ``None`` is returned.
        """
        is_xlnet = self.model_type == "xlnet"
        embedding_marker = "word_embedding" if is_xlnet else "embeddings"
        hidden_marker = "layer" if is_xlnet else "encoder.layer"

        if embedding_marker in module_name:
            return 0
        if hidden_marker in module_name:
            # The layer number is the first path component after the marker.
            remainder = module_name.split(hidden_marker + ".")[1]
            return int(remainder.split(".")[0]) + 1
        if 'pooler.dense' in module_name:
            return self.total_layers - 1

        warnings.warn(f"layer idx could not be determined for module_name {module_name}")
        return None


#    def to(self, device: Union[list, Union[str, torch.device]], *args, **kwargs) -> None:
#        self._remove_parallel()
#        if isinstance(device, list):
#            super().to(device[0])
#            if len(device)>1:
#                asssert_fn = lambda x: x=="cuda" if isinstance(x, str) else x.type=="cuda"
#                assert all([asssert_fn(d) for d in device]), "if list of devices is given, all must be of type 'cuda'"
#                self.encoder = torch.nn.DataParallel(self.encoder, device_ids=device)
#        else:
#            super().to(device)
#
#
#    def cpu(self):
#        self._remove_parallel()
#        super().cpu()
#
#
#    def cuda(self, *args, **kwargs) -> None:
#        self._remove_parallel()
#        super().cuda(*args, **kwargs)
#
#
#    def _remove_parallel(self) -> None:
#        if isinstance(self.encoder, torch.nn.DataParallel):
#            self.encoder = self.encoder.module
#



class BasePruningModel(BaseModel):
    """``BaseModel`` whose encoder parameters can carry diff-pruning parametrizations.

    Each parameter of the encoder's base modules may be wrapped by one or more
    ``DiffWeightFinetune`` or ``DiffWeightFixmask`` parametrizations.
    ``model_state`` records which kind (if any) is currently attached.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``BaseModel`` and start in the ``INIT`` state."""
        super().__init__(*args, **kwargs)

        self.model_state = ModelState.INIT
        # Set by ``_finetune_to_fixmask``; None until that method has run.
        self.fixmask_pct = None


    @property
    def parametrized(self) -> bool:
        """True when the model is in the FINETUNING or FIXMASK state."""
        return (self.model_state == ModelState.FINETUNING or self.model_state == ModelState.FIXMASK)


    @property
    def fixmask_state(self) -> bool:
        """True when the model is in the FIXMASK state."""
        return self.model_state == ModelState.FIXMASK


    @property
    def finetune_state(self) -> bool:
        """True when the model is in the FINETUNING state."""
        return self.model_state == ModelState.FINETUNING


    @property
    def n_parametrizations(self) -> int:
        """Number of parametrizations on the first parametrized tensor of the first base module."""
        return len(list(self.get_encoder_base_modules()[0].parametrizations.values())[0])


    @staticmethod
    def get_log_ratio(concrete_lower: float, concrete_upper: float) -> float:
        """Return ``log(-concrete_lower / concrete_upper)``, or 0 if ``concrete_lower`` is 0."""
        return 0 if (concrete_lower == 0) else math.log(-concrete_lower / concrete_upper)


    @staticmethod
    def get_l0_norm_term(alpha: torch.Tensor, log_ratio: float) -> torch.Tensor:
        """Return ``sigmoid(alpha - log_ratio)`` summed over all entries."""
        return torch.sigmoid(alpha - log_ratio).sum()


    def get_encoder_base_modules(self, return_names: bool = False):
        """List the encoder modules that hold (parametrized) parameters.

        While parametrized, these are the modules that have a
        ``parametrizations`` attribute; otherwise the modules that own at
        least one direct parameter.

        Parameters
        ----------
        return_names : bool
            If True, return ``(name, module)`` tuples instead of modules.
        """
        if self.parametrized:
            check_fn = lambda m: hasattr(m, "parametrizations")
        else:
            check_fn = lambda m: len(m._parameters)>0
        return [(n,m) if return_names else m for n,m in self.encoder.named_modules() if check_fn(m)]


    def get_sparsity_pen(self, sparsity_pen: Union[float, List[float], Tuple[float]]) -> Union[List[float], Tuple[float]]:
        """Expand a sparsity penalty to one value per layer.

        A list/tuple must have ``total_layers`` entries and is returned as is;
        a scalar is divided evenly over ``total_layers`` entries.
        """
        if isinstance(sparsity_pen, (list, tuple)):
            assert len(sparsity_pen) == self.total_layers,  "invalid sparsity penalty per layer: # of layers mismatch"
            return sparsity_pen
        else:
            return [sparsity_pen/self.total_layers] * self.total_layers


    def _get_sparsity_loss(self, log_ratio: float, sparsity_pen: list, idx: int) -> torch.Tensor:
        """Sum the L0 terms of parametrization ``idx`` over all modules, weighted per layer.

        Requires the FINETUNING state, since only ``DiffWeightFinetune``
        exposes ``alpha_weights``.
        """
        assert self.finetune_state, "model needs to be in finetuning state"
        l0_pen = 0.
        for module_name, base_module in self.get_encoder_base_modules(return_names=True):
            layer_idx = self.get_layer_idx_from_module(module_name)
            module_pen = 0.
            for n, par_list in list(base_module.parametrizations.items()):
                for a in par_list[idx].alpha_weights:
                    module_pen += self.get_l0_norm_term(a, log_ratio)
            l0_pen += (module_pen * sparsity_pen[layer_idx])
        return l0_pen


    @torch.no_grad()
    def _count_non_zero_params(self, *args, **kwargs) -> Tuple[int, int, int]:
        """Return ``[total, zeros, ones]`` mask counts summed over all base modules.

        Extra arguments are forwarded to ``_count_non_zero_params_for_module``.
        """
        assert self.parametrized, "Function only implemented for diff pruning"

        l = [self._count_non_zero_params_for_module(m, *args, **kwargs) for m in self.get_encoder_base_modules()]
        return [sum(x) for x in list(zip(*l))]


    @torch.no_grad()
    def _count_non_zero_params_per_layer(self, *args, **kwargs) -> Dict[int, Tuple[int, int, int]]:
        """Return ``{layer_idx: [total, zeros, ones]}`` mask counts per layer."""
        assert self.parametrized, "Function only implemented for diff pruning"

        t = torch.zeros((self.total_layers, 3), dtype=int)
        for module_name, base_module in self.get_encoder_base_modules(return_names=True):
            layer_idx = self.get_layer_idx_from_module(module_name)
            counts = self._count_non_zero_params_for_module(base_module, *args, **kwargs)
            t[layer_idx] += torch.tensor(counts)
        return {i:v.tolist() for i,v in enumerate(t)}


    @torch.no_grad()
    def _count_non_zero_params_for_module(self, m: torch.nn.Module, idx: Optional[int] = None, merged: bool = False) -> Tuple[int, int, int]:
        """Count mask entries of one parametrized module.

        Parameters
        ----------
        m : torch.nn.Module
            Module that has a ``parametrizations`` attribute.
        idx : Optional[int]
            If given (and ``merged`` is False), only this parametrization is counted.
        merged : bool
            If True, the masks of all parametrizations of a tensor are summed
            and an entry counts as one when any of them is non-zero.

        Returns
        -------
        ``[total, zeros, ones]`` as a list of ints.
        """

        def count_fn(p, binary: bool):
            # Binary masks: everything that is not zero counts as one.
            if binary:
                p = p.bool()
                n_p = p.numel()
                n_p_zero = (~p).sum()
                n_p_one = (n_p - n_p_zero)
            else:
                # Continuous masks: only exact 0 and exact 1 are counted.
                n_p = p.numel()
                n_p_zero = (p == 0.).sum()
                n_p_one = (p == 1.).sum()
            return torch.tensor([n_p, n_p_zero, n_p_one])

        assert hasattr(m, "parametrizations"), "module has no parametrizations"
        p_counts = torch.zeros((3,), dtype=int)
        # Eval mode makes ``DiffWeightFinetune.z`` deterministic.
        with self.deterministic():
            for n, par_list in list(m.parametrizations.items()):
                if merged:
                    if isinstance(par_list[0], DiffWeightFixmask):
                        p = torch.stack([x.mask for x in par_list]).sum(0)
                    else:
                        p = torch.stack([(x.z != 0.) for x in par_list]).sum(0)
                    p_counts += count_fn(p, True)
                else:
                    if idx is not None: par_list = [par_list[idx]]
                    for par in par_list:
                        p = par.mask if isinstance(par, DiffWeightFixmask) else par.z
                        p_counts += count_fn(p, p.dtype==torch.bool)

        return p_counts.tolist()


    def _remove_parametrizations(self, leave_parametrized: bool = True) -> None:
        """Freeze and remove every parametrization, then return to the ``INIT`` state.

        ``leave_parametrized`` is passed to ``parametrize.remove_parametrizations``.
        """
        self._freeze_parametrizations(True)
        for module in self.get_encoder_base_modules():
            try:
                for n in list(module.parametrizations):
                    parametrize.remove_parametrizations(module, n, leave_parametrized=leave_parametrized)
            except AttributeError:
                # Module without a ``parametrizations`` attribute: nothing to remove.
                pass
        self.model_state = ModelState.INIT


    def _add_diff_parametrizations(self, n_parametrizations: int = 1, p_requires_grad: bool = False, fixmask_init: bool = False, **kwargs) -> None:
        """Attach ``n_parametrizations`` diff parametrizations to every base parameter.

        Parameters
        ----------
        n_parametrizations : int
            Number of diff networks added per parameter.
        p_requires_grad : bool
            ``requires_grad`` value assigned to the original parameters.
        fixmask_init : bool
            If True, ``DiffWeightFixmask`` with zero diff and all-ones mask is
            used; otherwise ``DiffWeightFinetune`` built from ``**kwargs``.
        **kwargs
            Forwarded to ``DiffWeightFinetune``.
        """
        assert not self.parametrized, "cannot add diff parametrizations because of existing parametrizations in the model"
        for base_module in self.get_encoder_base_modules():
            # recurse=False: only parameters owned directly by this module.
            # Nested names such as "child.weight" cannot be registered on the
            # parent, and the child is visited as a base module of its own.
            for n,p in list(base_module.named_parameters(recurse=False)):
                p.requires_grad = p_requires_grad
                for _ in range(n_parametrizations): # number of diff networks to add
                    if fixmask_init:
                        # in case of fixmask init, can only initialize with dummy values
                        parametrize.register_parametrization(base_module, n, DiffWeightFixmask(
                            torch.zeros_like(p), torch.ones_like(p, dtype=bool)
                        ))
                    else:
                        parametrize.register_parametrization(base_module, n, DiffWeightFinetune(p, **kwargs))
        if fixmask_init:
            self.model_state = ModelState.FIXMASK
        else:
            self.model_state = ModelState.FINETUNING


    def _parametrizations_fn(self, fn: Callable, idx: Optional[int] = None):
        """Apply ``fn`` to every parametrization, or only to index ``idx`` of each list.

        Modules without parametrizations and lists shorter than ``idx`` are skipped.
        """
        for base_module in self.get_encoder_base_modules():
            try:
                for par_list in base_module.parametrizations.values():
                    if idx is not None:
                        try:
                            fn(par_list[idx])
                        except IndexError:
                            pass
                    else:
                        for par in par_list:
                            fn(par)
            except AttributeError:
                pass


    def _parametrizations_set_attr(self, attr: str, v, idx: Optional[int] = None):
        """Set attribute ``attr`` to ``v`` on the selected parametrizations."""
        fn = lambda x: setattr(x, attr, v)
        self._parametrizations_fn(fn, idx)


    def _activate_parametrizations(self, active: bool, idx: int):
        """Set the ``active`` flag of the selected parametrizations."""
        self._parametrizations_set_attr("active", active, idx)


    def _freeze_parametrizations(self, frozen: bool, idx: Optional[int] = None):
        """Call ``set_frozen(frozen)`` on the selected parametrizations."""
        fn = lambda x: x.set_frozen(frozen)
        self._parametrizations_fn(fn, idx)


    def _freeze_original_parameters(self, frozen: bool):
        """Set ``requires_grad`` of every ``original`` tensor to ``not frozen``."""
        for base_module in self.get_encoder_base_modules():
            for par_list in base_module.parametrizations.values():
                par_list.original.requires_grad = not frozen


    @torch.no_grad()
    def _finetune_to_fixmask(
        self,
        pct: Optional[float] = None,
        sequential: Union[bool, list, tuple] = True,
        merged_cutoff: bool = False,
        merged_min_pct: float = 0.01
    ) -> None:
        """Replace all ``DiffWeightFinetune`` parametrizations by ``DiffWeightFixmask``.

        Parameters
        ----------
        pct : Optional[float]
            Fraction of diff weights to keep (largest magnitudes). If None,
            every diff weight that is not close to zero is kept.
        sequential : Union[bool, list, tuple]
            Per parametrization: if True, its diff weight is added to the
            weight seen by the following parametrization. A single bool is
            applied to all parametrizations.
        merged_cutoff : bool
            If True (and there is more than one parametrization), a shared
            cutoff is derived from ``merged_min_pct`` and ``pct``.
        merged_min_pct : float
            Fraction used for the per-parametrization minimum cutoffs when
            ``merged_cutoff`` is True.
        """

        if isinstance(sequential, (list, tuple)):
            assert len(sequential) == self.n_parametrizations, "if sequential is list, needs to equal self.n_parametrizations"
        else:
            sequential = [sequential] * self.n_parametrizations

        def _get_cutoff(values, pct, abs = True):
            # Value at rank k+1 (descending) where k is pct of the entries.
            k = int(round(len(values) * pct, 0))
            if abs: values = torch.abs(values)
            return torch.topk(values, k+1, largest=True, sorted=True)[0][-1]

        assert self.model_state == ModelState.FINETUNING, "model needs to be in finetuning state"

        with self.deterministic():

            if pct is not None:

                # First pass: collect |diff weight| of every parametrization
                # so the global magnitude cutoffs can be computed.
                diff_weights_abs = [torch.tensor([])] * self.n_parametrizations
                for base_module in self.get_encoder_base_modules():
                    for n, par_list in list(base_module.parametrizations.items()):
                        w = par_list.original.detach()
                        for idx, seq in enumerate(sequential):
                            diff_weight = par_list[idx].diff_weight(w)
                            diff_weights_abs[idx] = torch.cat([diff_weights_abs[idx], torch.abs(diff_weight.flatten().cpu())])
                            if seq: w = diff_weight + w

                if merged_cutoff and (self.n_parametrizations > 1):
                    min_cutoffs = [_get_cutoff(x, merged_min_pct, abs=False) for x in diff_weights_abs]
                    if merged_min_pct >= pct:
                        print(f"merged_min_pct >= pct, using target sparsity merged_min_pct={merged_min_pct}")
                        cutoffs = min_cutoffs
                    else:
                        remaining = torch.cat([x[x<c] for x,c in zip(diff_weights_abs, min_cutoffs)])
                        remaining_cutoff = _get_cutoff(remaining, pct - merged_min_pct)
                        cutoffs = [min(remaining_cutoff, c) for c in min_cutoffs]
                else:
                    cutoffs = [_get_cutoff(x, pct, abs=False) for x in diff_weights_abs]

            # Second pass: build the masks and swap the parametrizations.
            for base_module in self.get_encoder_base_modules():
                for n, par_list in list(base_module.parametrizations.items()):
                    diff_weights = []
                    w = par_list.original
                    for idx, seq in enumerate(sequential):
                        diff_weight = par_list[idx].diff_weight(w)
                        if pct is not None:
                            i = 0 if merged_cutoff else idx
                            diff_mask = (torch.abs(diff_weight) > cutoffs[i])
                        else:
                            diff_mask = ~torch.isclose(diff_weight, torch.tensor(0.), rtol=1e-8)
                        diff_weights.append((diff_weight, diff_mask))
                        if seq: w = diff_weight + w

                    # leave_parametrized=False restores the original tensor.
                    parametrize.remove_parametrizations(base_module, n, leave_parametrized=False)
                    for (diff_weight, diff_mask) in diff_weights:
                        parametrize.register_parametrization(base_module, n, DiffWeightFixmask(diff_weight, diff_mask))

        self.model_state = ModelState.FIXMASK
        self.fixmask_pct = pct


    def _get_diff_param_groups(
        self,
        learning_rate: float,
        weight_decay: float = 0.0,
        learning_rate_alpha: Optional[float] = None,
        idx: Optional[int] = None,
    ) -> list:
        """Build optimizer parameter groups for the diff parameters.

        Parameters are selected by the suffix of their name. In the FIXMASK
        state a single group with all ``diff_weight`` tensors is returned;
        otherwise one group with the ``finetune`` tensors and one with the
        ``alpha`` / ``alpha_group`` tensors (using ``learning_rate_alpha``).
        If ``idx`` is given, only names ending in ``{idx}.<param>`` match.
        """

        if idx is None:
            idx_len = 0
            idx = ""
        else:
            idx_len = len(str(idx))

        if self.model_state == ModelState.FIXMASK:
            return [
                {
                    "params": [p for n,p in self.encoder.named_parameters() if n[-(12+idx_len):] == f"{idx}.diff_weight"],
                    "weight_decay": weight_decay,
                    "lr": learning_rate
                }
            ]
        else:
            return [
                {
                    "params": [p for n,p in self.encoder.named_parameters() if n[-(9+idx_len):] == f"{idx}.finetune"],
                    "weight_decay": weight_decay,
                    "lr": learning_rate
                },
                {
                    "params": [p for n,p in self.encoder.named_parameters() if n[-(6+idx_len):]==f"{idx}.alpha" or n[-(12+idx_len):]==f"{idx}.alpha_group"],
                    "lr": learning_rate_alpha
                }
            ]


    @contextlib.contextmanager
    def deterministic(self):
        """Context manager that runs its body in eval mode.

        If the model was in training mode on entry it is switched back on
        exit, also when the body raises.
        """
        tmp_state = self.training
        if tmp_state: self.eval()
        try:
            yield
        finally:
            # Restore even on error so a failed call cannot leave the model in eval mode.
            if tmp_state: self.train()


    def _as_module(self, named_parameter_list: list):
        """Load a fresh ``AutoModel`` for ``model_name`` and copy the given named tensors into it."""
        new_model = AutoModel.from_pretrained(self.model_name)
        with torch.no_grad():
            for p_name, p in named_parameter_list:
                _p = reduce(lambda a,b: getattr(a,b), [new_model] + p_name.split("."))
                _p.copy_(p)
        return new_model


    def get_diff_weights(self, idx: int, as_module: bool = False):
        """Return the diff weights of parametrization ``idx``.

        Parameters
        ----------
        idx : int
            Index of the parametrization in each parametrization list.
        as_module : bool
            If True, return the result of ``_as_module`` for the diff
            weights instead of the list.

        Returns
        -------
        List of ``(parameter_name, diff_weight)`` tuples, or a module.

        Raises
        ------
        TypeError
            If a parametrization is neither ``DiffWeightFixmask`` nor
            ``DiffWeightFinetune``.
        """
        res = []
        p_names = [n[:-9] for n, _ in self.encoder.named_parameters() if n[-9:]==".original"]
        with torch.no_grad():
            for p_name in p_names:
                par_list = reduce(lambda a,b: getattr(a,b), [self.encoder] + p_name.split("."))
                par = par_list[idx]
                if isinstance(par, DiffWeightFixmask):
                    diff_weight = par.mask * par.diff_weight
                elif isinstance(par, DiffWeightFinetune):
                    w = par_list.original.detach()
                    diff_weight = par.diff_weight(w)
                else:
                    # Without this branch a stale value from the previous
                    # iteration (or an unbound name) would be used.
                    raise TypeError(f"unsupported parametrization type {type(par).__name__} for {p_name}")
                res.append((p_name.replace(".parametrizations", ""), diff_weight))

        if as_module:
            return self._as_module(res)
        else:
            return res


    def get_base_weights(self, as_module: bool = False):
        """Return the un-parametrized weights as ``(name, tensor)`` tuples, or as a module.

        While parametrized, these are the ``original`` tensors with the
        ``.parametrizations`` part removed from their names.
        """
        if self.parametrized:
            res = [(n[:-9].replace(".parametrizations", ""), p) for n,p in self.encoder.named_parameters() if n[-9:]==".original"]
        else:
            res = list(self.encoder.named_parameters())

        if as_module:
            return self._as_module(res)
        else:
            return res


    @torch.no_grad()
    def load_state_dict_to_parametrizations(
        self,
        encoder_state_dict: OrderedDict,
        idx: int = 0
    ):
        """Copy a plain (non-parametrized) state dict into parametrization ``idx``.

        For ``DiffWeightFinetune`` the value is copied into ``finetune``; for
        ``DiffWeightFixmask`` the masked difference to the original weight is
        copied into ``diff_weight``. Entries whose module has no
        parametrizations are skipped.
        """
        assert not any([("parametrizations" in k) for k in encoder_state_dict.keys()]), "cant use parametrized state dict"

        for k,v in encoder_state_dict.items():
            k_parts = k.split(".")
            try:
                par_list = get_param_from_name(self.encoder, ".".join(k_parts[:-1] + ["parametrizations"]))
            except AttributeError:
                # Module is missing or not parametrized: nothing to load for this key.
                continue
            par = get_param_from_name(par_list, ".".join([k_parts[-1], str(idx)]))
            if isinstance(par, DiffWeightFinetune):
                par.finetune.copy_(v)
            elif isinstance(par, DiffWeightFixmask):
                diff = v - par_list.original
                par.diff_weight.copy_(par.mask * diff)




class DiffPrunDebiasing(BasePruningModel, ABC):
    r"""Implements diff pruning for bias mitigation in pretrained models.

    Requires the implementation of the '_forward' method, similar to
    '_get_embedding' in other classes in that it should compute the embedding
    given some inputs.

    Example
    -------
    >>> from FairLangProc.algorithms.intraprocessors import DiffPrunBERT
    >>>
    >>> class DiffPrunBERT(DiffPrunDebiasing):
    ...     def _forward(self, input_ids, attention_mask=None, token_type_ids=None):
    ...         outputs = self.encoder(
    ...             input_ids = input_ids,
    ...             attention_mask = attention_mask,
    ...             token_type_ids = token_type_ids
    ...             )
    ...         return outputs.last_hidden_state[:,0,:]
    >>>
    >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    >>> model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
    >>>
    >>> gendered_pairs = [("manager", "manageress"), ("nephew", "niece"), ("prince", "princess"), ("baron", "baroness")]
    >>> tokens_male = [words[0] for words in gendered_pairs]
    >>> tokens_female = [words[1] for words in gendered_pairs]
    >>> inputs_male = tokenizer(tokens_male, padding = True, return_tensors = "pt")
    >>> inputs_female = tokenizer(tokens_female, padding = True, return_tensors = "pt")
    >>>
    >>> def normalize_by_column(x: torch.Tensor, eps: float = 1e-8):
    ...     mean = x.mean(dim=0, keepdim=True)
    ...     std = x.std(dim=0, keepdim=True)
    ...     return (x - mean) / (std + eps)
    >>>
    >>> ModularDebiasingBERT = DiffPrunBERT(
    ...     head = model.classifier,
    ...     encoder = model.bert,
    ...     loss_fn = torch.nn.CrossEntropyLoss(),
    ...     input_ids_A = inputs_male,
    ...     input_ids_B = inputs_female,
    ...     bias_kernel = normalize_by_column,
    ...     upper = 10,
    ...     lower = -0.001,
    ...     lambda_bias = 0.5,
    ...     lambda_sparse = 0.00001
    ...     )
    >>> trainer = Trainer(
    ...     model=ModularDebiasingBERT,
    ...     args=training_args,
    ...     train_dataset=train_dataset,
    ...     eval_dataset=val_dataset,
    ...     optimizers=(
    ...         AdamW(ModularDebiasingBERT.parameters(), lr=1e-5, weight_decay=0.1),
    ...         None
    ...         )
    ...     )
    >>> trainer.train()
    >>> results = trainer.evaluate()
    >>> print(results)
    """

    def __init__(
        self,
        head: nn.Module,
        encoder: nn.Module,
        loss_fn: Callable,
        input_ids_A: torch.Tensor,
        input_ids_B: torch.Tensor,
        lambda_sparse: float = 1.0,
        lambda_bias: float = 1.0,
        bias_kernel: Callable = None,
        fixmask_init: bool = False,
        alpha_init: Optional[Union[int, float]] = 5,
        structured_diff_prunning: Optional[bool] = False,
        upper: float = 1.1,
        lower: float = -0.1
    ):
        r"""Constructor of the DiffPrunDebiasing class.

        Parameters
        ----------
        head : nn.Module
            Head used for the task at hand (classification, question answering,...).
        encoder : nn.Module
            Pretrained model (e.g., BERT, GPT-2).
        loss_fn : Callable
            Loss function.
        input_ids_A : torch.Tensor
            Inputs with demographic information of group A. They are unpacked
            as keyword arguments into '_forward' and iterated with '.items()'
            in 'to', so a mapping of tensors (e.g. tokenizer output) is expected.
        input_ids_B : torch.Tensor
            Inputs with demographic information of group B (same format as
            'input_ids_A').
        lambda_sparse : float
            Weight for sparsity loss.
        lambda_bias : float
            Weight for bias mitigation loss.
        bias_kernel : Callable
            Kernel for the embeddings of the bias loss. If None, defaults to the identity.
        fixmask_init : bool
            If true, uses DiffWeightFixmask (i.e. only masks) instead of
            DiffWeightFinetune (i.e. smooth pruning).
        alpha_init : Optional[Union[int, float]]
            Initialization value for the log alpha parameters.
        structured_diff_prunning : Optional[bool]
            If true, adds a group structure to the diff pruning process
            (see DiffWeightFinetune).
        upper : float
            Parameter for concrete relaxation (has to be > 1).
        lower : float
            Parameter for concrete relaxation (has to be < 0).
        """
        super().__init__(encoder = encoder)

        self.head = head
        self.loss_fn = loss_fn

        self.lambda_sparse = lambda_sparse
        self.sparsity_pen = self.get_sparsity_pen(self.lambda_sparse)
        self.upper = upper
        self.lower = lower
        # Use the shared helper so the ratio is log(-lower/upper), the same
        # definition the rest of the pruning code uses (also handles lower == 0).
        self.log_ratio = self.get_log_ratio(self.lower, self.upper)

        self.lambda_bias = lambda_bias
        self.kernel = bias_kernel
        self.inputs_A = input_ids_A
        self.inputs_B = input_ids_B

        self._add_diff_parametrizations(
            n_parametrizations = 1,
            p_requires_grad = False,
            fixmask_init = fixmask_init,
            alpha_init = alpha_init,
            concrete_lower = self.lower,
            concrete_upper = self.upper,
            structured = structured_diff_prunning
        )

    def _get_bias_loss(self):
        """Compute debias loss as the difference of the kernel of the counterfactual pairs.

        The embeddings of both groups are (optionally) passed through the
        kernel, averaged over the first dimension and compared with MSE.
        """
        # Embeddings of both groups as computed by '_forward'
        group_a = self._forward(**self.inputs_A)
        group_b = self._forward(**self.inputs_B)

        if self.kernel is not None:
            group_a = self.kernel(group_a)
            group_b = self.kernel(group_b)

        group_a_mean = group_a.mean(dim=0)
        group_b_mean = group_b.mean(dim=0)

        return F.mse_loss(
            group_a_mean.requires_grad_(True),
            group_b_mean.requires_grad_(True)
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids = None, labels=None):
        """Forward pass.

        Returns a CustomOutput with the logits and, when ``labels`` is given,
        the loss: task loss + weighted bias loss + sparsity loss.
        """
        outputs = self._forward(input_ids = input_ids, attention_mask = attention_mask, token_type_ids = token_type_ids)
        logits = self.head(outputs)

        if labels is not None:
            task_loss = self._loss(logits, labels)
            bias_loss = self.lambda_bias*self._get_bias_loss()
            if self.finetune_state:
                sparse_loss = self._get_sparsity_loss(log_ratio = self.log_ratio, sparsity_pen = self.sparsity_pen, idx = 0)
            else:
                # The sparsity loss is only defined in the finetuning state
                # (it asserts on it), e.g. it does not exist with fixmask_init=True.
                sparse_loss = 0.
            loss = task_loss + bias_loss + sparse_loss
            return CustomOutput(loss = loss, logits = logits)
        else:
            return CustomOutput(logits = logits)

    def _loss(self, output, target):
        """Loss function."""
        return self.loss_fn(output, target)

    def to(self, device):
        """Override to() to handle device transfer consistently.

        Besides the module itself, the stored inputs of both groups are moved
        to ``device``. Returns ``self``.
        """
        super().to(device)

        # Move input tensors
        self.inputs_A = {k: v.to(device) for k, v in self.inputs_A.items()}
        self.inputs_B = {k: v.to(device) for k, v in self.inputs_B.items()}
        return self
class DiffPrunBERT(DiffPrunDebiasing):
    """Concrete implementation for the BERT model."""

    def _forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the encoder and return the hidden state at sequence position 0."""
        encoder_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
        encoded = self.encoder(**encoder_inputs)
        first_token_state = encoded.last_hidden_state[:, 0, :]
        return first_token_state