mcd.nn.MCDNet
- class mcd.nn.MCDNet(x_dim, get_score_gamma_t, d_t=4, d_h=64, depth=3, act=<CompiledFunction of <function silu>>, *, key)
Bases: Module
Residual MLP network similar to the one used in the experiments of the paper Score-Based Diffusion meets Annealed Importance Sampling. The network output is parametrized as \(c_{\theta}(t, x) + (1 + s_{\theta}(t, x)) \nabla \log \gamma_t(x)\).
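The parametrization above combines two learned heads with the score of the annealed target. The sketch below writes that combination out in NumPy (not JAX, so it runs standalone). It folds in the `const_scale` factor that multiplies \(c_{\theta}\). The function name `mcd_output` is hypothetical. The sketch also assumes that \(s_{\theta}\) acts elementwise with the same shape as x.

```python
import numpy as np

def mcd_output(c, s, score, const_scale):
    """Hypothetical helper: const_scale * c_theta + (1 + s_theta) * grad log gamma_t."""
    # s is assumed to act elementwise on the score (same shape as x).
    return const_scale * c + (1.0 + s) * score

c = np.array([1.0, 2.0])       # c_theta(t, x)
s = np.array([0.5, -1.0])      # s_theta(t, x)
score = np.array([2.0, 4.0])   # grad log gamma_t(x)
print(mcd_output(c, s, score, const_scale=0.1))  # [3.1 0.2]
```

When `const_scale` and \(s_{\theta}\) are both zero, the output reduces to the score \(\nabla \log \gamma_t(x)\) alone.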
- __init__(x_dim, get_score_gamma_t, d_t=4, d_h=64, depth=3, act=<CompiledFunction of <function silu>>, *, key)
Initializes the network so that its output equals the score \(\nabla \log \gamma_t(x)\), i.e. the standard AIS backward kernel.
- Parameters
  - x_dim (int) – dimensionality of x.
  - get_score_gamma_t (Callable[[Array, Array], Array]) – the score \(\nabla \log \gamma_t(x)\), used to initialize the network's output to the standard AIS backward kernel.
  - d_t (int) – dimensionality of the t embedding.
  - d_h (int) – dimensionality of the x embedding.
  - depth (int) – number of residual MLP blocks.
  - act (Callable[[Array], Array]) – activation function.
  - key (PRNGKeyArray) – PRNG key used to initialize the layers.
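The `get_score_gamma_t` argument must return \(\nabla \log \gamma_t(x)\). For illustration, the sketch below assumes a geometric annealing path \(\gamma_t \propto \mathcal{N}(0, I)^{1-t}\,\mathcal{N}(\mu, \sigma^2 I)^{t}\). The path is an assumption; this page does not specify one. Under it, the score is \((1-t)(-x) + t\,(-(x-\mu)/\sigma^2)\). The factory `make_geometric_score`, the Gaussian endpoints, and the `(t, x)` argument order are all hypothetical. The code is written in NumPy for a standalone run.

```python
import numpy as np

def make_geometric_score(mu, sigma):
    """Hypothetical score of gamma_t ∝ N(0, I)^(1-t) * N(mu, sigma^2 I)^t."""
    mu = np.asarray(mu, dtype=float)

    def get_score_gamma_t(t, x):
        # Assumed argument order (t, x), mirroring the notation gamma_t(x).
        score_0 = -x                       # grad log N(x; 0, I)
        score_1 = -(x - mu) / sigma**2     # grad log N(x; mu, sigma^2 I)
        return (1.0 - t) * score_0 + t * score_1

    return get_score_gamma_t

score_fn = make_geometric_score(mu=[2.0, -1.0], sigma=0.5)
x = np.array([1.0, 1.0])
print(score_fn(0.0, x))  # [-1. -1.]
print(score_fn(1.0, x))  # [ 4. -8.]
```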
- const_scale: Array
Constant multiplying the \(c_{\theta}(t, x)\) network; initialized to zero.
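The toy NumPy model below sketches one way to obtain the initialization described on this page. A zero `const_scale` gates the \(c_{\theta}\) head, and a zero-initialized \(s_{\theta}\) head makes \(1 + s_{\theta} = 1\). Together these make the untrained output equal to \(\nabla \log \gamma_t(x)\). Several parts are assumptions and do not describe the library's actual implementation:

- The class `ToyMCDNet` and the helper `time_embedding` are hypothetical names.
- The layer layout is assumed: a sinusoidal time embedding, an input projection, residual blocks \(h \leftarrow h + \mathrm{act}(Wh + b)\), and linear heads.
- The zero initialization of \(s_{\theta}\) is assumed.
- An integer `seed` stands in for the JAX `key`.

```python
import numpy as np

def silu(z):
    # SiLU activation, z * sigmoid(z); stands in for the JAX default.
    return z / (1.0 + np.exp(-z))

def time_embedding(t, d_t):
    # Hypothetical sinusoidal embedding of scalar t into d_t features (d_t even).
    freqs = np.arange(1, d_t // 2 + 1)
    return np.concatenate([np.sin(np.pi * freqs * t), np.cos(np.pi * freqs * t)])

class ToyMCDNet:
    """Hypothetical NumPy stand-in for MCDNet; the layer layout is assumed."""

    def __init__(self, x_dim, get_score_gamma_t, d_t=4, d_h=64, depth=3,
                 act=silu, seed=0):
        rng = np.random.default_rng(seed)  # replaces the JAX PRNG key
        self.get_score_gamma_t = get_score_gamma_t
        self.d_t, self.act = d_t, act
        # Input projection of [t embedding, x] to the hidden width d_h.
        self.w_in = rng.normal(size=(d_h, d_t + x_dim)) / np.sqrt(d_t + x_dim)
        # depth residual blocks, each computing h <- h + act(W h + b).
        self.blocks = [(rng.normal(size=(d_h, d_h)) / np.sqrt(d_h), np.zeros(d_h))
                       for _ in range(depth)]
        # c_theta head: randomly initialized but gated by const_scale = 0.
        self.w_c = rng.normal(size=(x_dim, d_h)) / np.sqrt(d_h)
        self.const_scale = 0.0
        # s_theta head: zero-initialized so that 1 + s_theta = 1 at first.
        self.w_s = np.zeros((x_dim, d_h))

    def __call__(self, t, x):
        h = self.w_in @ np.concatenate([time_embedding(t, self.d_t), x])
        for w, b in self.blocks:
            h = h + self.act(w @ h + b)
        c = self.w_c @ h
        s = self.w_s @ h
        return self.const_scale * c + (1.0 + s) * self.get_score_gamma_t(t, x)

# Usage: with a standard Gaussian score, the untrained output equals the score.
score = lambda t, x: -x
net = ToyMCDNet(x_dim=2, get_score_gamma_t=score, d_h=16)
x = np.array([0.3, -1.2])
print(np.allclose(net(0.5, x), score(0.5, x)))  # True
```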