Appendix: Architecture Choices

What do we use for the model?

We need a mapping \(f_\theta(\mathbf{x}_t, t) \in \mathbb{R}^d\). For images, a natural choice is a U-Net: a series of ResNet blocks arranged into a downsampling path followed by an upsampling path, with skip connections between blocks at matching resolutions. Ho et al. (2020) make two modifications to this basic idea:

  1. To promote global sharing of features (rather than relying only on local convolutions), they add self-attention layers (a minimal sketch follows this list).
  2. To embed \(t\), they take the corresponding sinusoidal positional embedding (Vaswani et al. 2017) and pass it through a feed-forward network, then use the result to modify the biases (and optionally the scales) of intermediate activations in each ResNet block, right after each normalization layer (BatchNorm in the implementation below). The embedding itself is also sketched below.
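
A self-attention layer here attends across all spatial positions of a feature map. Below is a minimal single-head sketch using 1x1 convolutions as the projections; this is our own simplified illustration, not the exact module from the paper's code.

import torch
from torch import nn

class SelfAttention2d(nn.Module):
    # Single-head self-attention over the spatial positions of a feature map.
    def __init__(self, c):
        super().__init__()
        self.qkv = nn.Conv2d(c, 3 * c, 1)   # joint query/key/value projection
        self.proj = nn.Conv2d(c, c, 1)      # output projection

    def forward(self, x):
        B, C, H, W = x.shape
        q, k, v = self.qkv(x).reshape(B, 3, C, H * W).unbind(1)
        # Attention weights between all pairs of spatial positions: (B, HW, HW).
        attn = torch.softmax(q.transpose(1, 2) @ k / C ** 0.5, dim=-1)
        out = (v @ attn.transpose(1, 2)).reshape(B, C, H, W)
        return x + self.proj(out)           # residual connection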

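For concreteness, here is a minimal sketch of the sinusoidal timestep embedding, assuming \(t\) is a 1-D batch of integer timesteps and the embedding width is even; the name timestep_embedding is ours, not the paper's.

import math
import torch

def timestep_embedding(t, dim):
    # Sinusoidal embedding of integer timesteps (Vaswani et al. 2017):
    # sines and cosines at geometrically spaced frequencies. Assumes dim is even.
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = t.float()[:, None] * freqs[None, :]                    # (B, dim // 2)
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)  # (B, dim)
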
We provide a simple implementation of a U-Net modified in this way [here]. Below is the code for the ResNet block, which we found helpful for understanding how the time embedding enters the network.

import torch
from torch import nn
import torch.nn.functional as F


def unsqueeze_as(t, x):
    # Append singleton dims to `t` so it broadcasts against `x`,
    # e.g. (B, C) -> (B, C, 1, 1) when `x` is an image batch.
    return t.view(*t.shape, *([1] * (x.dim() - t.dim())))


class BasicBlock(nn.Module):

    def __init__(self, in_c, out_c, time_c):
        super().__init__()
        # Two 3x3 convolutions, each followed by BatchNorm.
        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        # Maps the time embedding to a per-channel bias.
        self.mlp_time = nn.Sequential(
            nn.Linear(time_c, time_c),
            nn.ReLU(),
            nn.Linear(time_c, out_c),
        )
        # Project the residual branch if the channel counts differ.
        if in_c == out_c:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, 1, bias=False),
                nn.BatchNorm2d(out_c),
            )

    def forward(self, x, t):
        out = self.conv1(x)
        out = self.bn1(out)
        # Inject the time embedding as a per-channel bias, right after BatchNorm.
        out = F.relu(out + unsqueeze_as(self.mlp_time(t), x))
        out = self.conv2(out)
        out = self.bn2(out)
        # Standard residual connection.
        out = F.relu(out + self.shortcut(x))
        return out
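
As a quick shape check (the sizes here are arbitrary, and timestep_embedding is the sketch from earlier):

# Run one block on a batch of 8 RGB images at 32x32, conditioned on a
# 128-dimensional time embedding.
block = BasicBlock(in_c=3, out_c=64, time_c=128)
x = torch.randn(8, 3, 32, 32)
t_emb = timestep_embedding(torch.randint(0, 1000, (8,)), 128)
print(block(x, t_emb).shape)  # torch.Size([8, 64, 32, 32])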