torchnaut package

Submodules

torchnaut.crps module

class torchnaut.crps.EpsilonSampler(n_dim)[source]

Bases: Module

Layer that adds random normal samples to enable probabilistic predictions.

This layer transforms input tensors by concatenating random samples from a standard normal distribution. The number of samples can be controlled globally using the context manager interface or per-call using the n_samples parameter.

Parameters:

n_dim (int) – Number of random dimensions to add

Example

>>> sampler = EpsilonSampler(16)
>>> # Default number of samples (100)
>>> out = sampler(x)  # Shape: [batch, 100, features+16]
>>>
>>> # Override samples for a specific call
>>> out = sampler(x, n_samples=1000)  # Shape: [batch, 1000, features+16]
>>>
>>> # Use context manager to temporarily change default samples
>>> with EpsilonSampler.n_samples(500):
...     out = sampler(x)  # Shape: [batch, 500, features+16]
>>> out = sampler(x)  # Back to default 100 samples

forward(x, n_samples=None)[source]

Forward pass adding random normal samples.

Parameters:
  • x (torch.Tensor) – Input tensor

  • n_samples (int, optional) – Override number of samples for this call. If None, uses the current default value.
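The shape change described above can be mimicked without PyTorch. The sketch below is only an illustration of the documented output shape, using nested lists; the helper name add_epsilon is hypothetical, and repeating each input row once per sample is an assumption inferred from the shapes in the class example, not a description of the layer's internals.

```python
import random

def add_epsilon(x, n_dim, n_samples, rng=None):
    """Map [batch][features] to [batch][n_samples][features + n_dim].

    Each input row is repeated n_samples times and every copy gets its own
    n_dim standard normal draws appended (assumed behaviour, see lead-in).
    """
    rng = rng or random.Random()
    return [
        [row + [rng.gauss(0.0, 1.0) for _ in range(n_dim)] for _ in range(n_samples)]
        for row in x
    ]

x = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]  # batch of 2, 3 features each
out = add_epsilon(x, n_dim=16, n_samples=100, rng=random.Random(0))
shape = (len(out), len(out[0]), len(out[0][0]))
print(shape)
```

The leading entries of every sample are the original features; only the appended noise differs between samples of the same batch element.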

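classmethod n_samples(n_samples, force=False)[source]

Context manager that temporarily changes the default number of samples, as in the class example above.

classmethod get_n_samples(default=None)[source]

torchnaut.crps.crps_loss(yps, y)[source]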

Calculates the Continuous Ranked Probability Score (CRPS) loss.

Parameters:
  • yps – Tensor of predicted samples [batch x num_samples]

  • y – Target values [batch x 1]

Returns:

CRPS loss value per batch element
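For orientation, a common sample-based estimator of the CRPS is E|X - y| - 0.5 E|X - X'|, where X and X' are independent draws from the prediction. The dependency-free sketch below evaluates that formula for a single batch element; the function name is hypothetical, and whether crps_loss uses exactly this normalization is an assumption, since the source does not state it.

```python
def crps_from_samples(samples, y):
    """Sample-based CRPS estimate for one target value y."""
    n = len(samples)
    # Mean absolute error between each sample and the target.
    term_obs = sum(abs(s - y) for s in samples) / n
    # Mean absolute difference over all sample pairs (diagonal included).
    term_spread = sum(abs(a - b) for a in samples for b in samples) / (n * n)
    return term_obs - 0.5 * term_spread

value = crps_from_samples([0.0, 1.0, 2.0], 1.0)
print(value)
```

When all samples coincide, the spread term vanishes and the score reduces to the absolute error.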

torchnaut.crps.crps_loss_weighted(yps, w, y)[source]

Calculates the weighted Continuous Ranked Probability Score (CRPS) loss.

Parameters:
  • yps – Tensor of predicted samples [batch x num_samples]

  • w – Sample weights [batch x num_samples]

  • y – Target values [batch x 1]

Returns:

Weighted CRPS loss value per batch element
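With sample weights, the same estimator can be written with weighted means. The sketch below is a plain-Python illustration for one batch element; it normalizes the weights to sum to 1, which is an assumption about how w is interpreted, and the function name is hypothetical.

```python
def crps_weighted_from_samples(samples, weights, y):
    """Weighted sample-based CRPS estimate for one target value y."""
    total = sum(weights)
    w = [wi / total for wi in weights]  # normalize so the weights sum to 1
    term_obs = sum(wi * abs(s - y) for wi, s in zip(w, samples))
    term_spread = sum(
        wi * wj * abs(a - b)
        for wi, a in zip(w, samples)
        for wj, b in zip(w, samples)
    )
    return term_obs - 0.5 * term_spread

equal = crps_weighted_from_samples([0.0, 1.0, 2.0], [1.0, 1.0, 1.0], 1.0)
point = crps_weighted_from_samples([0.0, 1.0, 2.0], [1.0, 0.0, 0.0], 1.0)
print(equal, point)
```

Equal weights recover the unweighted estimate, and putting all weight on one sample gives the absolute error of that sample.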

torchnaut.crps.crps_loss_mv(yps, y)[source]

Calculates the multivariate CRPS (Energy Score) loss.

Parameters:
  • yps – Tensor of predicted samples [batch x num_samples x dims]

  • y – Target values [batch x dims]

Returns:

Multivariate CRPS (Energy Score) loss value per batch element
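The Energy Score generalizes the CRPS by replacing absolute differences with Euclidean norms: E||X - y|| - 0.5 E||X - X'||. The sketch below computes it for one batch element in plain Python; the function name is hypothetical and the exact normalization used by crps_loss_mv is not stated in the source, so treat this as an illustration of the quantity only.

```python
import math

def energy_score(samples, y):
    """Sample-based Energy Score for one target vector y."""
    n = len(samples)
    # Mean Euclidean distance between each sample and the target.
    term_obs = sum(math.dist(s, y) for s in samples) / n
    # Mean Euclidean distance over all sample pairs.
    term_spread = sum(math.dist(a, b) for a in samples for b in samples) / (n * n)
    return term_obs - 0.5 * term_spread

score = energy_score([[0.0, 0.0], [3.0, 4.0]], [0.0, 0.0])
print(score)
```

With dims = 1 the Euclidean norm is the absolute value, so the expression reduces to the univariate estimator.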

torchnaut.crps.crps_loss_mv_weighted(yps, w, y)[source]

Calculates the weighted multivariate CRPS (Energy Score) loss.

Parameters:
  • yps – Tensor of predicted samples [batch x num_samples x dims]

  • w – Sample weights [batch x num_samples]

  • y – Target values [batch x dims]

Returns:

Weighted multivariate CRPS (Energy Score) loss value per batch element

torchnaut.epistemic module

torchnaut.epistemic.get_kl_term(net)[source]

Calculate total KL divergence term for all Bayesian layers in network.

Parameters:

net – Neural network containing Bayesian layers

Returns:

Sum of KL divergence terms from all Bayesian layers

class torchnaut.epistemic.BayesianParameter(shape, prior_mu, prior_sigma)[source]

Bases: Module

Bayesian parameter that uses the reparameterization trick for sampling.

Maintains a variational posterior distribution over the parameter values and computes KL divergence against a specified prior distribution.

Parameters:
  • shape – Shape of the parameter tensor

  • prior_mu – Mean of the prior normal distribution

  • prior_sigma – Standard deviation of the prior normal distribution
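As a rough illustration of the two ingredients named above, the sketch below shows reparameterized sampling and the closed-form KL divergence between two univariate normal distributions. It assumes a normal variational posterior with parameters mu_q and sigma_q, which the source does not state explicitly; both function names are hypothetical.

```python
import math
import random

def sample_reparameterized(mu, sigma, rng):
    """Draw mu + sigma * eps with eps ~ N(0, 1).

    The randomness lives in eps, so the result is a differentiable
    function of mu and sigma.
    """
    eps = rng.gauss(0.0, 1.0)
    return mu + sigma * eps

def kl_normal(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ) in closed form."""
    return (
        math.log(sigma_p / sigma_q)
        + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
        - 0.5
    )

draw = sample_reparameterized(0.0, 0.1, random.Random(0))
print(draw, kl_normal(1.0, 1.0, 0.0, 1.0))
```

The divergence is zero when posterior and prior coincide and grows as the posterior mean moves away from the prior mean.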

get_kl_term()[source]

Return the KL divergence between the variational posterior and the prior.

forward()[source]

Sample the parameter from the variational posterior using the reparameterization trick.

class torchnaut.epistemic.BayesianLinear(in_features, out_features, prior_mu=0, prior_sigma=0.1)[source]

Bases: Module

Fully connected layer with Bayesian parameters.

Uses variational inference with reparameterization to sample weights and biases from posterior distributions during forward passes.

Parameters:
  • in_features – Size of input features

  • out_features – Size of output features

  • prior_mu – Mean of the prior normal distribution for weights

  • prior_sigma – Standard deviation of the prior normal distribution for weights

forward(x)[source]

Apply the linear layer to x using weights and biases sampled from their posterior distributions.

class torchnaut.epistemic.CRPSEnsemble(networks, concat_dim=-2)[source]

Bases: Module

Ensemble of networks for CRPS computation.

Wraps multiple networks and runs forward passes through all of them, concatenating the outputs along concat_dim for CRPS calculation.

Parameters:
  • networks – List of networks to include in the ensemble

  • concat_dim – Dimension along which the network outputs are concatenated (default: -2)

forward(*args, **kwargs)[source]

Run a forward pass through every network in the ensemble and concatenate the outputs along concat_dim.

torchnaut.kde module

torchnaut.kde.nll_gpu(pred_samples, y)[source]

Calculate negative log likelihood using adaptive kernel density estimation.

Uses Silverman’s rule for global bandwidth and adaptive local bandwidth scaling.

Parameters:
  • pred_samples – Predicted samples [batch x num_samples]

  • y – Target values [batch x 1]

Returns:

Negative log likelihood values per batch element
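To make the idea concrete, the sketch below computes a Gaussian KDE negative log likelihood for one batch element in plain Python. It is a simplification: it uses one common form of Silverman's rule (1.06 * std * n^(-1/5)) as a single global bandwidth and omits the adaptive local bandwidth scaling mentioned above. The function names are hypothetical and the exact bandwidth formula used by nll_gpu is an assumption.

```python
import math
import statistics

def silverman_bandwidth(samples):
    """One common form of Silverman's rule of thumb (global bandwidth)."""
    n = len(samples)
    return 1.06 * statistics.stdev(samples) * n ** (-1.0 / 5.0)

def kde_nll(samples, y):
    """Negative log density of y under a fixed-bandwidth Gaussian KDE."""
    h = silverman_bandwidth(samples)
    # Log of each Gaussian kernel evaluated at y.
    log_terms = [
        -0.5 * ((y - s) / h) ** 2 - math.log(h) - 0.5 * math.log(2.0 * math.pi)
        for s in samples
    ]
    # Log-sum-exp keeps the computation stable for targets far from the samples.
    top = max(log_terms)
    log_density = top + math.log(sum(math.exp(t - top) for t in log_terms))
    log_density -= math.log(len(samples))
    return -log_density

samples = [-1.0, 0.0, 1.0]
print(kde_nll(samples, 0.0), kde_nll(samples, 5.0))
```

Targets near the bulk of the samples receive a lower negative log likelihood than targets far away from them.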

torchnaut.kde.nll_gpu_weighted(pred_samples, pred_weights, y, max_pilot_samples=None, batch_size=16)[source]

Calculate negative log likelihood using weighted adaptive kernel density estimation.

Uses Silverman’s rule for global bandwidth and adaptive local bandwidth scaling.

Parameters:
  • pred_samples – Predicted samples [batch x num_samples]

  • pred_weights – Sample weights [batch x num_samples]

  • y – Target values [batch x 1]

  • max_pilot_samples – Maximum number of pilot samples for bandwidth estimation

  • batch_size – Batch size for memory-efficient computation

Returns:

Negative log likelihood values per batch element

torchnaut.kde.get_kde(pred_samples, pred_weights=None, max_pilot_samples=None, batch_size=16)[source]

Create weighted kernel density estimation distributions.

Uses Silverman’s rule for global bandwidth and adaptive local bandwidth scaling. Returns mixture distributions that can be used for sampling and likelihood computation.

Parameters:
  • pred_samples – Predicted samples [batch x num_samples]

  • pred_weights – Sample weights [batch x num_samples]. If None, equal weights are used.

  • max_pilot_samples – Maximum number of pilot samples for bandwidth estimation

  • batch_size – Batch size for memory-efficient computation

Returns:

List of MixtureSameFamily distributions for each batch element

torchnaut.mdn module

class torchnaut.mdn.MDN(n_components)[source]

Bases: Module

Univariate Mixture Density Network utility class.

Expected network output shape: [batch x num_components * 3]

get_dist(p)[source]

Convert network output to mixture distribution object.

Parameters:

p – Input tensor [batch x num_components * 3] containing (mu, sigma, pi)

Returns:

PyTorch mixture distribution object

log_likelihood(p, y, min_log_proba=-inf)[source]

Calculate log likelihood of mixture density network output.

Parameters:
  • p – Transformed output tensor

  • y – Target values

  • min_log_proba – Minimum log probability for clamping

Returns:

Log likelihood values per batch element

expected_value(p)[source]

Calculate expected value of the mixture distribution.

Parameters:

p – Transformed output tensor

Returns:

Expected value per batch element
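The two quantities above have simple closed forms for a Gaussian mixture. The plain-Python sketch below evaluates them for one batch element, taking the component means, standard deviations, and mixture weights as separate lists. The function names are hypothetical, the way the class unpacks the network output into (mu, sigma, pi) is not shown, and treating min_log_proba as a lower clamp is an assumption based on its description.

```python
import math

def mixture_log_likelihood(mu, sigma, pi, y, min_log_proba=-math.inf):
    """Log density of y under a univariate Gaussian mixture."""
    log_terms = [
        math.log(w) - math.log(s) - 0.5 * math.log(2.0 * math.pi)
        - 0.5 * ((y - m) / s) ** 2
        for m, s, w in zip(mu, sigma, pi)
    ]
    # Log-sum-exp over components for numerical stability.
    top = max(log_terms)
    log_p = top + math.log(sum(math.exp(t - top) for t in log_terms))
    return max(log_p, min_log_proba)  # clamp from below (assumed semantics)

def mixture_mean(mu, pi):
    """Expected value of the mixture: weighted average of component means."""
    return sum(w * m for m, w in zip(mu, pi))

ll = mixture_log_likelihood([0.0], [1.0], [1.0], 0.0)
mean = mixture_mean([0.0, 4.0], [0.25, 0.75])
print(ll, mean)
```

With a single component, the log likelihood is the ordinary normal log density.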

sample(p, n=100)[source]

Draw samples from the mixture distribution.

Parameters:
  • p – Transformed output tensor

  • n – Number of samples to draw

Returns:

n samples from the mixture distribution

inverse_transform(p, labelscaler)[source]

Inverse transform the output tensor.

Parameters:
  • p – Transformed output tensor

  • labelscaler (LabelScaler) – LabelScaler object for inverse transformation

Returns:

Inverse transformed output tensor

class torchnaut.mdn.MDNMV(n_components, target_dim)[source]

Bases: Module

Multivariate Mixture Density Network utility class.

Expected network output shape: [batch x network_output_dim], where network_output_dim is an attribute of the MDNMV instance

get_dist(p)[source]

Convert network output to mixture distribution object. Also handles all necessary transformations (activations and clamping).

Parameters:

p – Network output tensor containing mixture parameters

Returns:

PyTorch mixture distribution object

log_likelihood(p, y, min_log_proba=-inf)[source]

Calculate log likelihood of mixture density network output.

Parameters:
  • p – Network output tensor

  • y – Target values

  • min_log_proba – Minimum log probability for clamping

Returns:

Log likelihood values per batch element

expected_value(p)[source]

Calculate expected value of the mixture distribution.

Parameters:

p – Output tensor

Returns:

Expected value per batch element

sample(p, n=100)[source]

Draw samples from the mixture distribution.

Parameters:
  • p – Output tensor

  • n – Number of samples to draw

Returns:

n samples from the distribution [batch x n x dims]

torchnaut.utils module

torchnaut.utils.get_batch_ixs(ref_tensor, batch_size=16, permute=False)[source]

Generate batch indices for mini-batch processing.

Parameters:
  • ref_tensor – Reference tensor to determine total size

  • batch_size – Size of each batch

  • permute – Whether to randomly permute indices

Returns:

List of index tensors for each batch
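The batching pattern can be sketched with plain lists. The helper below is hypothetical and takes the number of items directly instead of a reference tensor; it only illustrates splitting indices into consecutive chunks with an optional shuffle, and makes no claim about how get_batch_ixs is implemented.

```python
import random

def batch_indices(n_items, batch_size=16, permute=False, rng=None):
    """Split range(n_items) into chunks of at most batch_size indices."""
    ixs = list(range(n_items))
    if permute:
        # Shuffle in place before chunking so batches are random.
        (rng or random.Random()).shuffle(ixs)
    return [ixs[i:i + batch_size] for i in range(0, n_items, batch_size)]

batches = batch_indices(10, batch_size=4)
print(batches)
```

The last batch is shorter whenever the total size is not a multiple of batch_size.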

class torchnaut.utils.LabelScaler[source]

Bases: object

A scaler for preprocessing label data with optional PCA transformation.

This class provides functionality to scale and transform label data by:

  1. Centering the data by subtracting the mean
  2. Optionally applying PCA dimensionality reduction
  3. Scaling the data to unit variance

The transformation can be reversed using the inverse_transform method.

fit_transform(arr, pca_dims=-1, pca_whiten=True)[source]

Fit the scaler to the data and transform it.

Parameters:
  • arr (numpy.ndarray) – Input array of shape [dataset_size, *feature_dims]

  • pca_dims (int) – Number of PCA components. If -1, no PCA is applied

  • pca_whiten (bool) – Whether to apply whitening in PCA transformation

Returns:

Transformed array of shape:
  • [dataset_size, prod(feature_dims)] if pca_dims=-1

  • [dataset_size, pca_dims] if PCA is applied

Return type:

numpy.ndarray

transform(arr)[source]

Transform new data using the fitted scaler.

Parameters:

arr (numpy.ndarray) – Input array of shape [dataset_size, *feature_dims]. Must match the dimensions of the data used in fit_transform.

Returns:

Transformed array of shape:
  • [dataset_size, prod(feature_dims)] if pca_dims=-1

  • [dataset_size, pca_dims] if PCA is applied

Return type:

numpy.ndarray

inverse_transform(arr_scaled)[source]

Inverse transform scaled data back to the original space.

Parameters:

arr_scaled (numpy.ndarray or torch.Tensor) – Scaled input array of shape [*batch_dims, n_features] where n_features matches the output dimension of transform()

Returns:

Array in original space with shape [*batch_dims, *feature_dims], where feature_dims matches the original input dimensions

Return type:

numpy.ndarray or torch.Tensor

Example

If the original data had shape [1000, 32, 32]:
  • Inputs of shape [64, 1024] are mapped to [64, 32, 32]

  • Inputs of shape [64, 20, 1024] are mapped to [64, 20, 32, 32]
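The center-and-scale round trip can be illustrated without NumPy. The sketch below skips the optional PCA step and standardizes each feature column separately; per-column scaling and all function names are assumptions made for the illustration, not a description of LabelScaler's internals.

```python
import statistics

def fit_standardizer(rows):
    """Compute per-column mean and (population) standard deviation."""
    cols = list(zip(*rows))
    means = [statistics.fmean(c) for c in cols]
    stds = [statistics.pstdev(c) for c in cols]
    return means, stds

def scale(rows, means, stds):
    """Center each column, then scale it to unit variance."""
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]

def unscale(rows, means, stds):
    """Reverse scale(): multiply by the std, then add the mean back."""
    return [[v * s + m for v, m, s in zip(row, means, stds)] for row in rows]

data = [[1.0, 10.0], [3.0, 30.0]]
means, stds = fit_standardizer(data)
scaled = scale(data, means, stds)
restored = unscale(scaled, means, stds)
print(scaled, restored)
```

Applying the inverse to the scaled values recovers the original labels.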

torchnaut.utils.calculate_pit_cdf(preds, y, weights=None)[source]

Calculate the Probability Integral Transform (PIT) and its CDF.

Parameters:
  • preds (numpy.ndarray) – Model predictions of shape [num_predictions, num_samples]

  • y (numpy.ndarray) – Ground truth values of shape [num_predictions]

  • weights (numpy.ndarray, optional) – Weights for each prediction of shape [num_predictions, num_samples]. Defaults to None.

Returns:

Contains:
  • numpy.ndarray: Reference percentiles (linspace from 0 to 1)

  • numpy.ndarray: Empirical CDF of the PIT values

Return type:

tuple
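As a rough, unweighted illustration of the computation, the PIT value of one prediction is the fraction of its samples that do not exceed the observed value, and the returned curve is the empirical CDF of those PIT values on a grid from 0 to 1. The plain-Python sketch below uses hypothetical function names, and the use of <= rather than < at ties is an assumption.

```python
def pit_values(preds, y):
    """Fraction of predicted samples <= the observed value, per prediction."""
    return [
        sum(1 for s in row if s <= target) / len(row)
        for row, target in zip(preds, y)
    ]

def empirical_cdf(values, grid):
    """Share of values that are <= each grid point."""
    n = len(values)
    return [sum(1 for v in values if v <= p) / n for p in grid]

preds = [[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0]]
pit = pit_values(preds, [0.5, 2.5])
grid = [i / 4 for i in range(5)]  # reference percentiles 0, 0.25, ..., 1
cdf = empirical_cdf(pit, grid)
print(pit, cdf)
```

For a well-calibrated model the PIT values are approximately uniform, so the empirical CDF stays close to the diagonal given by the reference percentiles.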

Module contents