Source code for torchnaut.epistemic

import torch
from torch import nn
import numpy as np


[docs] def get_kl_term(net): """Calculate total KL divergence term for all Bayesian layers in network. Args: net: Neural network containing Bayesian layers Returns: Sum of KL divergence terms from all Bayesian layers """ kl = 0 for m in net.modules(): if hasattr(m, "get_kl_term"): kl += m.get_kl_term() return kl
[docs] class BayesianParameter(nn.Module): """Bayesian parameter that uses the reparameterization trick for sampling. Maintains a variational posterior distribution over the parameter values and computes KL divergence against a specified prior distribution. Args: shape: Shape of the parameter tensor prior_mu: Mean of the prior normal distribution prior_sigma: Standard deviation of the prior normal distribution """ def __init__(self, shape, prior_mu, prior_sigma): super().__init__() self.shape = shape self.mu = nn.Parameter( torch.ones(*shape) * prior_mu + torch.randn(*shape) / np.sqrt(shape[-1]) ) self.rho = nn.Parameter( torch.ones(*shape) * torch.log(torch.exp(torch.tensor(prior_sigma)) - 1) ) self.register_buffer("prior_mu", torch.zeros_like(self.mu) + prior_mu) self.register_buffer("prior_sigma", torch.zeros_like(self.rho) + prior_sigma)
[docs] def get_kl_term(self): return torch.distributions.kl.kl_divergence( torch.distributions.Normal(self.mu, nn.functional.softplus(self.rho)), torch.distributions.Normal(self.prior_mu, self.prior_sigma), ).sum()
[docs] def forward(self): epsilon = torch.randn(*self.shape, device=self.mu.device) return self.mu + nn.functional.softplus(self.rho.clamp(min=-30)) * epsilon
[docs] class BayesianLinear(nn.Module): """Fully connected layer with Bayesian parameters. Uses variational inference with reparameterization to sample weights and biases from posterior distributions during forward passes. Args: in_features: Size of input features out_features: Size of output features prior_mu: Mean of the prior normal distribution for weights prior_sigma: Standard deviation of the prior normal distribution for weights """ def __init__(self, in_features, out_features, prior_mu=0, prior_sigma=0.1): super().__init__() self.in_features = in_features self.out_features = out_features self.weight = BayesianParameter( (in_features, out_features), prior_mu, prior_sigma ) self.bias = BayesianParameter((out_features,), 0, prior_sigma)
[docs] def forward(self, x): weight = self.weight() bias = self.bias() return x @ weight + bias
[docs] class CRPSEnsemble(nn.Module): """Ensemble of networks for CRPS computation. Wraps multiple networks and runs forward passes through all of them, concatenating the outputs along `concat_dim` for CRPS calculation. Args: networks: List of networks to include in the ensemble """ def __init__(self, networks, concat_dim=-2): super().__init__() self.networks = nn.ModuleList(networks) self.concat_dim = concat_dim
[docs] def forward(self, *args, **kwargs): # trickery to handle a variable number of outputs outputs = [network(*args, **kwargs) for network in self.networks] outputs = zip(*(o if isinstance(o, (tuple, list)) else (o,) for o in outputs)) outputs_cat = [ torch.cat([o.unsqueeze(self.concat_dim) for o in out], dim=self.concat_dim) for out in outputs ] if len(outputs_cat)==1: return outputs_cat[0] else: return tuple(outputs_cat)