redbnn.bayesian_inference package
Submodules
redbnn.bayesian_inference.hmc module
- redbnn.bayesian_inference.hmc.forward(redbnn, inputs, n_samples, sample_idxs=None, softmax=True)
Forward pass of the inputs through the network using the chosen number of samples.
- Parameters
inputs (torch.Tensor) – Input images.
n_samples (int) – Number of samples drawn during the evaluation.
sample_idxs (list, optional) – Random seeds used for drawing samples. If sample_idxs is None, it defaults to the range of integers from 0 to the maximum number of samples.
softmax (bool, optional) – If True, applies the softmax function to each output tensor.
- Returns
Output predictions
- Return type
(torch.Tensor)
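The multi-sample evaluation described above can be sketched in plain Python. This is a minimal illustration, not redbnn's implementation: `toy_logits` is a hypothetical stand-in for one stochastic pass through the network, and treating each entry of `sample_idxs` as a seed for `random.Random`, with a default of `range(n_samples)`, is an assumption based on the parameter description. Note that one weight draw is shared by all inputs within a pass, as it would be for a network.

```python
import math
import random

def softmax_vector(logits):
    # Numerically stable softmax: subtract the largest logit before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_logits(x, w):
    # Hypothetical two-class "network" with a single scalar weight w.
    return [w * x, -w * x]

def multi_sample_forward(inputs, n_samples, sample_idxs=None, softmax=True):
    # Assumption: the default seeds are 0, 1, ..., n_samples - 1.
    if sample_idxs is None:
        sample_idxs = list(range(n_samples))
    outputs = []
    for idx in sample_idxs:
        rng = random.Random(idx)  # one seed -> one reproducible weight draw
        w = rng.gauss(1.0, 0.1)  # hypothetical posterior draw, shared by all inputs
        logits = [toy_logits(x, w) for x in inputs]
        outputs.append([softmax_vector(l) for l in logits] if softmax else logits)
    # Nested lists indexed as [sample][input][class].
    return outputs

preds = multi_sample_forward([0.5, -2.0], n_samples=4)
```

Because every sample is tied to a seed, passing `sample_idxs=[3]` reproduces exactly the fourth sample of a default call with `n_samples=4`.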
- redbnn.bayesian_inference.hmc.load(redbnn, savedir, filename, hmc_samples)
Loads the learned parameters.
- Parameters
savedir (str) – Directory containing the saved parameters.
filename (str) – Filename.
hmc_samples (int) – Number of samples drawn during HMC inference; needed for loading models trained with HMC.
- redbnn.bayesian_inference.hmc.model(redbnn, x_data, y_data)
Stochastic function that implements the generative process and is conditioned on the observations.
- Parameters
x_data (torch.Tensor) – Observed data points.
y_data (torch.Tensor) – Labels of the observed data.
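For readers less familiar with Pyro models, the sketch below writes out a generative process of this kind for a tiny linear softmax classifier: weights drawn from a standard normal prior, labels drawn from a categorical likelihood. The prior, the linear classifier and all function names are illustrative assumptions; redbnn applies the idea to the Bayesian layers of a network. Conditioning on `(x_data, y_data)` corresponds to evaluating the log joint density as a function of the weights, which is the unnormalized log posterior an MCMC method such as HMC explores.

```python
import math
import random

def log_softmax(logits):
    # Stable log-softmax via the log-sum-exp trick.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def logits_for(weights, x):
    # Linear classifier: one weight row per class.
    return [sum(w_d * x_d for w_d, x_d in zip(row, x)) for row in weights]

def sample_generative(x_data, n_classes, rng):
    # Generative process: weights ~ N(0, 1), then y_n ~ Categorical(softmax(W x_n)).
    dim = len(x_data[0])
    weights = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_classes)]
    labels = []
    for x in x_data:
        probs = [math.exp(lp) for lp in log_softmax(logits_for(weights, x))]
        labels.append(rng.choices(range(n_classes), weights=probs)[0])
    return weights, labels

def log_joint(weights, x_data, y_data):
    # Conditioning on observed labels: log prior + log likelihood of y_data.
    log_prior = sum(-0.5 * w ** 2 - 0.5 * math.log(2 * math.pi) for row in weights for w in row)
    log_lik = sum(log_softmax(logits_for(weights, x))[y] for x, y in zip(x_data, y_data))
    return log_prior + log_lik
```

With all weights at zero, every class is equally likely, so each observation contributes log(1/K) to the likelihood.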
- redbnn.bayesian_inference.hmc.save(redbnn, savedir, filename, hmc_samples)
Saves the learned parameters as torch.Tensor objects on the CPU.
- Parameters
savedir (str) – Output directory.
filename (str) – Filename.
hmc_samples (int) – Number of samples drawn during HMC inference; needed for saving models trained with HMC.
- redbnn.bayesian_inference.hmc.to(device)
Sends the Pyro parameters to the chosen device.
- Parameters
device (str) – Name of the chosen device.
- redbnn.bayesian_inference.hmc.train(redbnn, dataloaders, device, n_samples, warmup, is_inception)
Freezes the deterministic parameters and infers the Bayesian parameters using Hamiltonian Monte Carlo.
- Parameters
dataloaders (dict) – Dictionary containing the training and validation PyTorch DataLoaders.
device (str) – Device chosen for training.
n_samples (int) – Number of Hamiltonian Monte Carlo samples.
warmup (int) – Number of Hamiltonian Monte Carlo warmup samples.
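To make the roles of `n_samples` and `warmup` concrete, here is a minimal HMC sampler for a one-parameter conjugate model (prior theta ~ N(0, 1), observations y_i ~ N(theta, 1)), whose exact posterior is N(sum(y) / (n + 1), 1 / (n + 1)). It is a sketch under simplifying assumptions, not redbnn's training loop: the step size and trajectory length are fixed, and warmup draws are simply discarded, whereas Pyro's HMC and NUTS kernels can also use the warmup phase to adapt the step size and mass matrix.

```python
import math
import random

def log_joint(theta, ys):
    # Unnormalized log posterior: N(0, 1) prior times N(theta, 1) likelihood.
    return -0.5 * theta ** 2 - 0.5 * sum((y - theta) ** 2 for y in ys)

def grad_log_joint(theta, ys):
    return -theta + sum(y - theta for y in ys)

def hmc(ys, n_samples, warmup, step_size=0.05, n_leapfrog=10, seed=0):
    rng = random.Random(seed)
    theta = 0.0
    samples = []
    for it in range(warmup + n_samples):
        momentum = rng.gauss(0.0, 1.0)  # fresh momentum each iteration
        th, p = theta, momentum
        # Leapfrog integration of Hamiltonian dynamics.
        p += 0.5 * step_size * grad_log_joint(th, ys)
        for step in range(n_leapfrog):
            th += step_size * p
            if step < n_leapfrog - 1:
                p += step_size * grad_log_joint(th, ys)
        p += 0.5 * step_size * grad_log_joint(th, ys)
        # Metropolis correction with Hamiltonian H = -log_joint + p^2 / 2.
        h_current = -log_joint(theta, ys) + 0.5 * momentum ** 2
        h_proposed = -log_joint(th, ys) + 0.5 * p ** 2
        if rng.random() < math.exp(min(0.0, h_current - h_proposed)):
            theta = th
        if it >= warmup:  # keep only post-warmup draws
            samples.append(theta)
    return samples

ys = [0.8, 1.2, 1.0, 0.9, 1.1, 1.3, 0.7, 1.0, 1.05, 0.95]
draws = hmc(ys, n_samples=2000, warmup=200)
mean = sum(draws) / len(draws)
```

Here sum(y) = 10 and n = 10, so the sample mean should land close to 10/11 and the sample variance close to 1/11.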
redbnn.bayesian_inference.svi module
- redbnn.bayesian_inference.svi.forward(redbnn, inputs, n_samples, sample_idxs=None, softmax=True)
Forward pass of the inputs through the network using the chosen number of samples.
- Parameters
inputs (torch.Tensor) – Input images.
n_samples (int) – Number of samples drawn during the evaluation.
sample_idxs (list, optional) – Random seeds used for drawing samples. If sample_idxs is None, it defaults to the range of integers from 0 to the maximum number of samples.
softmax (bool, optional) – If True, applies the softmax function to each output tensor.
- Returns
Output predictions
- Return type
(torch.Tensor)
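When `softmax` is True, each sample yields one probability vector per input. A common way to turn the samples into a single prediction is to average them (a Monte Carlo estimate of the posterior predictive) and take the argmax, with the entropy of the averaged distribution as an uncertainty score. The helper below sketches that reduction on nested lists indexed as `[sample][input][class]`; the layout and the name `predictive_summary` are assumptions, since the return value is only described as a `torch.Tensor` of output predictions.

```python
import math

def predictive_summary(per_sample_probs):
    n_samples = len(per_sample_probs)
    n_inputs = len(per_sample_probs[0])
    n_classes = len(per_sample_probs[0][0])
    # Average class probabilities over samples, separately for each input.
    mean_probs = [
        [sum(per_sample_probs[s][i][c] for s in range(n_samples)) / n_samples
         for c in range(n_classes)]
        for i in range(n_inputs)
    ]
    predicted = [max(range(n_classes), key=lambda c: p[c]) for p in mean_probs]
    # Entropy of the averaged distribution: higher means a less certain prediction.
    entropy = [-sum(q * math.log(q) for q in p if q > 0) for p in mean_probs]
    return mean_probs, predicted, entropy

samples = [
    [[0.9, 0.1], [0.4, 0.6]],  # sample 0
    [[0.7, 0.3], [0.2, 0.8]],  # sample 1
]
mean_probs, predicted, entropy = predictive_summary(samples)
```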
- redbnn.bayesian_inference.svi.guide(redbnn, x_data, y_data=None)
Variational distribution (guide) approximating the posterior over the Bayesian parameters.
- Parameters
x_data (torch.Tensor) – Input data points.
y_data (torch.Tensor, optional) – Labels of the input data.
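A common choice of variational distribution for Bayesian layers is a mean-field Gaussian: each weight has its own location and a scale kept positive through a softplus, and samples are drawn with the reparameterization trick so that, in an autodiff framework, gradients reach the variational parameters. The sketch below shows that sampling step and the matching log density in plain Python; whether redbnn's guide uses this exact parameterization is an assumption, and `locs` and `raw_scales` are hypothetical names.

```python
import math
import random

def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def sample_guide(locs, raw_scales, rng):
    # Draw w from q(w) = prod_i N(loc_i, softplus(raw_i)^2) and return log q(w).
    weights, log_q = [], 0.0
    for loc, raw in zip(locs, raw_scales):
        scale = softplus(raw)
        eps = rng.gauss(0.0, 1.0)
        w = loc + scale * eps  # reparameterization: w is a deterministic function of eps
        weights.append(w)
        # Since (w - loc) / scale == eps, the Gaussian log density simplifies to this.
        log_q += -0.5 * eps ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)
    return weights, log_q
```

A very negative raw scale collapses the guide onto its locations, which is a quick sanity check of the parameterization.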
- redbnn.bayesian_inference.svi.load(redbnn, savedir, filename)
Loads the learned parameters.
- Parameters
savedir (str) – Directory containing the saved parameters.
filename (str) – Filename.
- redbnn.bayesian_inference.svi.model(redbnn, x_data, y_data)
Stochastic function that implements the generative process and is conditioned on the observations.
- Parameters
x_data (torch.Tensor) – Observed data points.
y_data (torch.Tensor) – Labels of the observed data.
- redbnn.bayesian_inference.svi.save(redbnn, savedir, filename)
Saves the learned parameters on the CPU.
- Parameters
savedir (str) – Output directory.
filename (str) – Filename.
- redbnn.bayesian_inference.svi.to(device)
Sends the Pyro parameters to the chosen device.
- Parameters
device (str) – Name of the chosen device.
- redbnn.bayesian_inference.svi.train(redbnn, dataloaders, device, num_iters, is_inception, lr=0.01)
Freezes the deterministic parameters and infers the Bayesian parameters using Stochastic Variational Inference (SVI).
- Parameters
dataloaders (dict) – Dictionary containing the training and validation PyTorch DataLoaders.
device (str) – Device chosen for training.
num_iters (int) – Number of iterations for Stochastic Variational Inference.
lr (float, optional) – Learning rate for SVI.
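The optimization controlled by `num_iters` and `lr` can be illustrated on a one-parameter conjugate toy problem: prior theta ~ N(0, 1), observations y_i ~ N(theta, 1), and a Gaussian guide q(theta) = N(mu, sigma^2). The sketch maximizes a Monte Carlo estimate of the ELBO by plain gradient ascent with reparameterized gradients. It is not redbnn's implementation, and Pyro's SVI would normally pair the ELBO loss with an optimizer such as Adam; the plain update here is a simplification. Because the model is conjugate, the result can be checked against the exact posterior N(sum(y) / (n + 1), 1 / (n + 1)).

```python
import math
import random

def svi_fit(ys, num_iters, lr=0.01, n_mc=10, seed=0):
    rng = random.Random(seed)
    n, total = len(ys), sum(ys)
    mu, log_sigma = 0.0, 0.0  # variational parameters of q(theta) = N(mu, sigma^2)
    for _ in range(num_iters):
        sigma = math.exp(log_sigma)
        grad_mu, grad_log_sigma = 0.0, 0.0
        for _ in range(n_mc):
            eps = rng.gauss(0.0, 1.0)
            theta = mu + sigma * eps  # reparameterization trick
            # d/dtheta log p(y, theta) for prior N(0, 1) and likelihood N(theta, 1)
            dlogp = total - (n + 1) * theta
            grad_mu += dlogp / n_mc
            grad_log_sigma += dlogp * sigma * eps / n_mc
        grad_log_sigma += 1.0  # gradient of the Gaussian entropy term log(sigma) + const
        mu += lr * grad_mu  # gradient ascent on the ELBO
        log_sigma += lr * grad_log_sigma
    return mu, math.exp(log_sigma)

ys = [0.8, 1.2, 1.0, 0.9, 1.1, 1.3, 0.7, 1.0, 1.05, 0.95]
mu, sigma = svi_fit(ys, num_iters=2000)
```

With sum(y) = 10 and n = 10, mu should approach 10/11 and sigma should approach 1/sqrt(11); averaging `n_mc` draws per step reduces the noise of each gradient estimate.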