mcd.nn.MCDNet

class mcd.nn.MCDNet(x_dim, get_score_gamma_t, d_t=4, d_h=64, depth=3, act=jax.nn.silu, *, key)

Bases: Module

Residual MLP network similar to the one used in the experiments of Score-Based Diffusion meets Annealed Importance Sampling. It is parametrized as \(c_{\theta}(t, x) + (1 + s_{\theta}(t, x)) \nabla \log \gamma_t(x)\), where \(\nabla \log \gamma_t(x)\) is the score of the annealed density \(\gamma_t\).
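As a minimal sketch of this parametrization, the hypothetical helper mcd_output below (not part of mcd.nn) combines precomputed head outputs c and s with the score. It assumes, based on the const_scale and score_scale attributes documented further down, that each head is multiplied by a learned scalar before being combined; the actual module may organize this differently.

```python
import jax.numpy as jnp

def mcd_output(c, s, score, const_scale, score_scale):
    # Hypothetical helper (not part of mcd.nn): combines the two network heads
    # with the score of gamma_t, assuming the scalar scales multiply c and s.
    return const_scale * c + (1.0 + score_scale * s) * score

# Arbitrary example values for a 3-dimensional x.
c = jnp.array([0.5, -1.0, 2.0])     # stand-in for c_theta(t, x)
s = jnp.array([0.1, 0.2, -0.3])     # stand-in for s_theta(t, x)
score = jnp.array([1.0, 2.0, 3.0])  # stand-in for grad log gamma_t(x)

# Zero scales (the documented initial value): the heads are ignored.
out0 = mcd_output(c, s, score, jnp.zeros(()), jnp.zeros(()))

# Unit scales: c + (1 + s) * score.
out1 = mcd_output(c, s, score, jnp.ones(()), jnp.ones(()))
```

In this sketch, out0 equals score exactly, so a model whose scales start at zero begins from the plain score-based (AIS) kernel, and the learned heads only contribute once the scales move away from zero.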

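__call__(t, x)

Evaluates the network at time \(t\) and state \(x\), returning \(c_{\theta}(t, x) + (1 + s_{\theta}(t, x)) \nabla \log \gamma_t(x)\).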

Return type

Array

__init__(x_dim, get_score_gamma_t, d_t=4, d_h=64, depth=3, act=jax.nn.silu, *, key)

Initializes the network so that its output equals the score \(\nabla \log \gamma_t(x)\) (both const_scale and score_scale start at zero).

Parameters
  • x_dim (int) – dimensionality of x.

  • get_score_gamma_t (Callable[[Array, Array], Array]) – function returning the score \(\nabla \log \gamma_t(x)\), which is used to initialize the network’s output to the standard AIS backward kernel.

  • d_t (int) – dimensionality of the t embedding.

  • d_h (int) – dimensionality of the x embedding.

  • depth (int) – number of residual MLP blocks.

  • act (Callable[[Array], Array]) – activation function.

  • key (PRNGKeyArray) – PRNG key used to initialize layers.
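For illustration, a get_score_gamma_t for a geometric annealing path \(\log \gamma_t(x) = (1 - t) \log \pi_0(x) + t \log \pi_1(x)\) can be built with jax.grad. The Gaussian prior \(\pi_0\) and target \(\pi_1\), the constants MU and SIGMA, the argument order (t, x), and t running from 0 to 1 are all assumptions made for this sketch, not requirements stated by mcd.nn.

```python
import jax
import jax.numpy as jnp

MU = jnp.array([2.0, -1.0])  # hypothetical target mean
SIGMA = 0.5                  # hypothetical target standard deviation

def log_prior(x):
    # Unnormalized log density of N(0, I).
    return -0.5 * jnp.sum(x ** 2)

def log_target(x):
    # Unnormalized log density of N(MU, SIGMA^2 I).
    return -0.5 * jnp.sum((x - MU) ** 2) / SIGMA ** 2

def log_gamma_t(t, x):
    # Geometric annealing path between prior (t = 0) and target (t = 1).
    return (1.0 - t) * log_prior(x) + t * log_target(x)

# Differentiate with respect to x (argnums=1): a callable (t, x) -> Array.
get_score_gamma_t = jax.grad(log_gamma_t, argnums=1)

x = jnp.array([0.5, 0.5])
score_mid = get_score_gamma_t(0.5, x)
```

Because log_gamma_t returns a scalar, jax.grad applies directly; normalizing constants drop out of the score, so unnormalized densities suffice.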

act: Callable[[Array], Array]
const_scale: Array

constant multiplying the \(c_{\theta}(t, x)\) network, initialized to zero.

d_h: int
d_t: int
depth: int
final_layer: Linear
get_score_gamma_t: Callable[[Array, Array], Array]
res_layers: List[MCDNetResBlock]
score_scale: Array

constant multiplying the \(s_{\theta}(t, x)\) network, initialized to zero.

t_emb: Linear
x_dim: int
x_emb: Linear
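The attributes above suggest how the pieces fit together, but the exact wiring is not documented here. The following plain-JAX sketch is one plausible arrangement under stated assumptions: the helpers init_params and forward are hypothetical, and feeding the time embedding into every residual block and splitting the final layer's output into \(c_{\theta}\) and \(s_{\theta}\) halves are guesses for illustration, not the mcd.nn implementation. What it does reproduce faithfully is the documented behaviour that, with const_scale and score_scale at zero, the output equals get_score_gamma_t(t, x).

```python
import jax
import jax.numpy as jnp

def init_linear(key, d_in, d_out):
    # Hypothetical dense layer: weights ~ N(0, 1/d_in), zero bias.
    w = jax.random.normal(key, (d_out, d_in)) / jnp.sqrt(d_in)
    return {"w": w, "b": jnp.zeros(d_out)}

def linear(p, x):
    return p["w"] @ x + p["b"]

def init_params(key, x_dim, d_t=4, d_h=64, depth=3):
    keys = jax.random.split(key, depth + 3)
    return {
        "t_emb": init_linear(keys[0], 1, d_t),
        "x_emb": init_linear(keys[1], x_dim, d_h),
        # Assumed: each residual block sees [h, t_embedding].
        "res_layers": [init_linear(k, d_h + d_t, d_h) for k in keys[2:2 + depth]],
        # Assumed: final layer emits c (x_dim values) followed by s (x_dim values).
        "final_layer": init_linear(keys[-1], d_h, 2 * x_dim),
        # Both scales start at zero, as documented for const_scale/score_scale.
        "const_scale": jnp.zeros(()),
        "score_scale": jnp.zeros(()),
    }

def forward(params, get_score_gamma_t, t, x, act=jax.nn.silu):
    t_e = act(linear(params["t_emb"], jnp.atleast_1d(t)))  # time embedding
    h = linear(params["x_emb"], x)                          # state embedding
    for p in params["res_layers"]:
        h = h + act(linear(p, jnp.concatenate([h, t_e])))  # residual update
    c, s = jnp.split(linear(params["final_layer"], h), 2)
    score = get_score_gamma_t(t, x)
    return params["const_scale"] * c + (1.0 + params["score_scale"] * s) * score

params = init_params(jax.random.PRNGKey(0), x_dim=3)
x = jnp.array([1.0, -2.0, 0.5])
# The score of N(0, I), i.e. -x, stands in for get_score_gamma_t here.
out = forward(params, lambda t, x: -x, 0.3, x)
```

Zero-initializing the scales rather than the final layer keeps the hidden layers randomly initialized while still starting exactly at the score-based kernel.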