Source code for FairLangProc.metrics.embedding

"""Submodule inside of the FairLangProc.metrics module which stores all methods and metrics related
with the embeddings of a Language Model.

The WEAT class is flexible enough to implement other embedding metrics like SEAT or CEAT.
"""
# Standard imports
from abc import ABC, abstractmethod
from tqdm import tqdm
from typing import TypeVar

# numpy
import numpy as np

# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F


TokenizerType = TypeVar("TokenizerType", bound = "PreTrainedTokenizer")


[docs] class WEAT(ABC): """Class for handling WEAT metric with a PyTorch model and tokenizer. Attributes ---------- model : nn.Module PyTorch model (e.g., BERT, GPT from HuggingFace). tokenizer : TokenizerType Tokenizer for the model. device : str Device to run the WEAT test on. Methods ------- metric(W1_words, W2_words, A1_words, A2_words, n_perm, pval) Computation of the WEAT effect size between W1, W2 and A1, A2. _get_embedding(outputs) Abstract method whose implementation is required and which aims to compute the embedding of an output given by the model. """
[docs] def __init__( self, model: nn.Module, tokenizer: TokenizerType, device: str='cuda' ) -> None: r"""Constructor for the WEAT class Parameters ---------- model : nn.Module PyTorch model (e.g., BERT, GPT from HuggingFace). tokenizer : TokenizerType Tokenizer for the model. device : str Device to run the WEAT test on. """ self.model = model self.tokenizer = tokenizer self.device = device self.model.to(device) self.model.eval()
def get_embeddings(self, words: list[str]) -> torch.Tensor: """Get embeddings for a list of words using the LLM.""" embeddings = [] for word in words: # Tokenize and get embeddings inputs = self.tokenizer(word, return_tensors='pt').to(self.device) with torch.no_grad(): outputs = self.model(**inputs, output_hidden_states=True) # Get hidden states from specified layer word_embedding = self._get_embedding(outputs) embeddings.append(word_embedding) return torch.stack(embeddings)
[docs] @abstractmethod def _get_embedding(self, outputs): r"""Abstract method that instructs the class on how to obtain the embedding of a given input.""" pass
def cosine_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Compute cosine similarity between two tensors.""" return F.cosine_similarity(x.unsqueeze(1), y.unsqueeze(0), dim=-1) def effect_size(self, X: torch.Tensor, Y: torch.Tensor, A: torch.Tensor, B: torch.Tensor ) -> float: r"""Compute WEAT effect size. Parameters ---------- X : torch.Tensor Target concept 1 embeddings (n_X, dim) Y : torch.Tensor Target concept 2 embeddings (n_Y, dim) A : torch.Tensor Attribute 1 embeddings (n_A, dim) B : torch.Tensor Attribute 2 embeddings (n_B, dim) Returns ------- Effect size : float Example ------- >>> from transformers import AutoTokenizer, AutoModel >>> from FairLangProc.metrics import WEAT >>> class BertWEAT(WEAT): ... def _get_embedding(self, outputs): ... return outputs.last_hidden_state[:, 0, :] >>> tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') >>> model = AutoModel.from_pretrained('bert-base-uncased') >>> math = ['math', 'algebra', 'geometry', 'calculus', 'equations'] >>> arts = ['poetry', 'art', 'dance', 'literature', 'novel'] >>> masc = ['male', 'man', 'boy', 'brother', 'he'] >>> femn = ['female', 'woman', 'girl', 'sister', 'she'] >>> >>> weatClass = BertWEAT(model = model, tokenizer = tokenizer) >>> weatClass.metric( W1_words = math, W2_words = arts, A1_words = masc, A2_words = femn, pval = False ) """ # Compute similarities x_a = self.cosine_similarity(X, A).mean() x_b = self.cosine_similarity(X, B).mean() y_a = self.cosine_similarity(Y, A).mean() y_b = self.cosine_similarity(Y, B).mean() # Difference in mean similarities diff_x = x_a - x_b diff_y = y_a - y_b # Pooled standard deviation x_diffs = self.cosine_similarity(X, A) - self.cosine_similarity(X, B) y_diffs = self.cosine_similarity(Y, A) - self.cosine_similarity(Y, B) std_x = x_diffs.std(unbiased=False) std_y = y_diffs.std(unbiased=False) pooled_std = torch.sqrt((std_x**2 + std_y**2) / 2) return ((diff_x - diff_y) / pooled_std).item() def p_value(self, X: torch.Tensor, Y: torch.Tensor, A: torch.Tensor, B: torch.Tensor, n_perm: int = 10000) -> float: r"""Compute p-value using permutation test. Parameters ---------- X, Y, A, B : torch.Tensor Embedding tensors n_perm : int Number of permutations Returns ------- p-value : float """ combined = torch.cat([X, Y]) size_X = X.size(0) observed_effect = self.effect_size(X, Y, A, B) count = 0 for _ in tqdm(range(n_perm), desc="Running permutations"): # Shuffle and split perm = combined[torch.randperm(combined.size(0))] X_perm = perm[:size_X] Y_perm = perm[size_X:] # Compute effect for this permutation effect = self.effect_size(X_perm, Y_perm, A, B) if effect > observed_effect: count += 1 return (count + 1) / (n_perm + 1) # Add 1 for smoothing
[docs] def metric( self, W1_words: list[str], W2_words: list[str], A1_words: list[str], A2_words: list[str], n_perm: int = 10000, pval: bool = True ) -> dict[str, float]: r"""Run WEAT test. Parameters ---------- W1_words : list[str] Target concept 1 words/sentences W2_words : list[str] Target concept 2 words A1_words : list[str] Attribute 1 words/sentences A2_words : list[str] Attribute 2 words/sentences n_perm : int Number of permutations for p-value pval : bool Whether to compute or not the p-value Returns ------- results : dict[str, float] Dictionary with test results, namely mean similarity between W1, W2 and A1, A2; their sizes, the WEAT effect size and the p-value if needed. """ # Get embeddings X = self.get_embeddings(W1_words) Y = self.get_embeddings(W2_words) A = self.get_embeddings(A1_words) B = self.get_embeddings(A2_words) # Compute mean similarities x_a = self.cosine_similarity(X, A).mean().item() x_b = self.cosine_similarity(X, B).mean().item() y_a = self.cosine_similarity(Y, A).mean().item() y_b = self.cosine_similarity(Y, B).mean().item() results = { 'X-A_mean_sim': x_a, 'X-B_mean_sim': x_b, 'Y-A_mean_sim': y_a, 'Y-B_mean_sim': y_b, 'W1_size': len(W1_words), 'W2_size': len(W2_words), 'A1_size': len(A1_words), 'A2_size': len(A2_words) } # Compute statistics effect = self.effect_size(X, Y, A, B) results['effect_size'] = effect if pval: p_val = self.p_value(X, Y, A, B, n_perm) results['p_value']= p_val return results
class BertWEAT(WEAT): """Class with implementation of _get_embedding for bidirectional transformers """ def _get_embedding(self, outputs): return outputs.last_hidden_state[:, 0, :]