What do we use for the model?
We need a function \(f_\theta\) that maps \((\mathbf{x}_t, t)\) to \(\mathbb{R}^d\), the same dimension as \(\mathbf{x}_t\). For images, a natural choice is a U-Net. A U-Net is a series of ResNet blocks arranged into a sequence of downsampling stages followed by upsampling stages. Ho et al. (2020) make two modifications to this basic idea:
We provide a simple implementation of a U-Net modified in this way [here]. Below is the code for the ResNet block, which I found helpful for understanding.
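import torch.nn as nn
import torch.nn.functional as F

# unsqueeze_as is a helper defined elsewhere in the implementation (not shown here).


class BasicBlock(nn.Module):
    def __init__(self, in_c, out_c, time_c):
        super().__init__()
        # Two 3x3 convolutions (stride 1, padding 1) that preserve spatial size.
        self.conv1 = nn.Conv2d(in_c, out_c, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        # Projects the time embedding (size time_c) to one value per output channel.
        self.mlp_time = nn.Sequential(
            nn.Linear(time_c, time_c),
            nn.ReLU(),
            nn.Linear(time_c, out_c),
        )
        # Residual path: identity if the channel counts match, else a 1x1 convolution.
        if in_c == out_c:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_c, out_c, 1, 1, bias=False),
                nn.BatchNorm2d(out_c),
            )

    def forward(self, x, t):
        out = self.conv1(x)
        out = self.bn1(out)
        # Add the projected time embedding to the feature map before the activation.
        out = F.relu(out + unsqueeze_as(self.mlp_time(t), x))
        out = self.conv2(out)
        out = self.bn2(out)
        # Residual connection.
        out = F.relu(out + self.shortcut(x))
        return out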
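
The block calls `unsqueeze_as`, which is not defined in the code shown here. The following is a minimal sketch of what such a helper might look like, assuming its job is to append trailing singleton dimensions to the time embedding so that it broadcasts against the feature map. The body is a guess based on how the helper is called, not the original implementation, and the tensor shapes are illustrative.

```python
import torch


def unsqueeze_as(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Append trailing singleton dims to `t` until it has as many dims as `x`.

    Hypothetical reconstruction of the helper used in BasicBlock.forward.
    """
    extra_dims = x.dim() - t.dim()
    if extra_dims < 0:
        raise ValueError("t has more dimensions than x")
    return t.reshape(*t.shape, *([1] * extra_dims))


# Feature map of shape (batch, channels, height, width).
features = torch.zeros(8, 64, 32, 32)
# Output of mlp_time: one value per sample and channel, shape (batch, channels).
time_emb = torch.ones(8, 64)

time_emb_4d = unsqueeze_as(time_emb, features)  # shape (8, 64, 1, 1)
shifted = features + time_emb_4d  # broadcasts over height and width
```

Under this assumption, the time embedding acts as a per-channel bias: every spatial location in a given channel is shifted by the same amount, and that amount depends on the timestep.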