UNet( time_embedding = Chain( SinusoidalPositionEmbedding(100 => 64), Dense(64 => 64, gelu), # 4_160 parameters Dense(64 => 64), # 4_160 parameters ), chain = ConditionalChain( init = Conv((3, 3), 1 => 16, pad=1), # 160 parameters down_1 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( Dense(64 => 16), # 1_040 parameters ), conv = Conv((3, 3), 16 => 16, pad=1), # 2_320 parameters norm = GroupNorm(16, 8), # 32 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 16 => 16, pad=1), # 2_320 parameters GroupNorm(16, 8), # 32 parameters NNlib.swish, ), skip_transform = identity, ), skip_1 = ConditionalSkipConnection( ConditionalChain( downsample_1 = Conv((4, 4), 16 => 16, pad=1, stride=2), # 4_112 parameters down_2 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( NNlib.swish, Dense(64 => 16), # 1_040 parameters ), conv = Conv((3, 3), 16 => 16, pad=1), # 2_320 parameters norm = GroupNorm(16, 8), # 32 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 16 => 16, pad=1), # 2_320 parameters GroupNorm(16, 8), # 32 parameters NNlib.swish, ), skip_transform = identity, ), skip_2 = ConditionalSkipConnection( ConditionalChain( downsample_2 = Conv((4, 4), 16 => 32, pad=1, stride=2), # 8_224 parameters down_3 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( NNlib.swish, Dense(64 => 32), # 2_080 parameters ), conv = Conv((3, 3), 32 => 32, pad=1), # 9_248 parameters norm = GroupNorm(32, 8), # 64 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 32 => 32, pad=1), # 9_248 parameters GroupNorm(32, 8), # 64 parameters NNlib.swish, ), skip_transform = identity, ), skip_3 = ConditionalSkipConnection( ConditionalChain( down_4 = Conv((3, 3), 32 => 48, pad=1), # 13_872 parameters middle_1 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( NNlib.swish, Dense(64 => 48), # 3_120 parameters ), conv = Conv((3, 3), 48 => 48, pad=1), # 20_784 parameters norm = GroupNorm(48, 8), # 96 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 48 => 48, pad=1), # 20_784 parameters GroupNorm(48, 8), # 96 parameters NNlib.swish, ), skip_transform = identity, ), middle_attention = SkipConnection( MultiheadAttention( nhead = 4, to_qkv = Conv((3, 3), 48 => 144, pad=1, bias=false), # 62_208 parameters to_out = Conv((3, 3), 48 => 48, pad=1), # 20_784 parameters ), +, ), middle_2 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( NNlib.swish, Dense(64 => 48), # 3_120 parameters ), conv = Conv((3, 3), 48 => 48, pad=1), # 20_784 parameters norm = GroupNorm(48, 8), # 96 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 48 => 48, pad=1), # 20_784 parameters GroupNorm(48, 8), # 96 parameters NNlib.swish, ), skip_transform = identity, ), ), DenoisingDiffusion.cat_on_channel_dim, ), up_3 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( NNlib.swish, Dense(64 => 48), # 3_120 parameters ), conv = Conv((3, 3), 80 => 48, pad=1), # 34_608 parameters norm = GroupNorm(48, 8), # 96 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 48 => 48, pad=1), # 20_784 parameters GroupNorm(48, 8), # 96 parameters NNlib.swish, ), skip_transform = Conv((3, 3), 80 => 48, pad=1), # 34_608 parameters ), upsample_3 = Chain( Upsample(:nearest, scale = (2, 2)), Conv((3, 3), 48 => 32, pad=1), # 13_856 parameters ), ), DenoisingDiffusion.cat_on_channel_dim, ), up_2 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( NNlib.swish, Dense(64 => 32), # 2_080 parameters ), conv = Conv((3, 3), 48 => 32, pad=1), # 13_856 parameters norm = GroupNorm(32, 8), # 64 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 32 => 32, pad=1), # 9_248 parameters GroupNorm(32, 8), # 64 parameters NNlib.swish, ), skip_transform = Conv((3, 3), 48 => 32, pad=1), # 13_856 parameters ), upsample_2 = Chain( Upsample(:nearest, scale = (2, 2)), Conv((3, 3), 32 => 16, pad=1), # 4_624 parameters ), ), DenoisingDiffusion.cat_on_channel_dim, ), up_1 = ResBlock( in_layers = ConvEmbed( embed_layers = Chain( NNlib.swish, Dense(64 => 16), # 1_040 parameters ), conv = Conv((3, 3), 32 => 16, pad=1), # 4_624 parameters norm = GroupNorm(16, 8), # 32 parameters activation = NNlib.swish, ), out_layers = Chain( Conv((3, 3), 16 => 16, pad=1), # 2_320 parameters GroupNorm(16, 8), # 32 parameters NNlib.swish, ), skip_transform = Conv((3, 3), 32 => 16, pad=1), # 4_624 parameters ), final = Conv((3, 3), 16 => 1, pad=1), # 145 parameters ), ) # Total: 107 trainable arrays, 403_409 parameters, # plus 1 non-trainable, 6_400 parameters, summarysize 1.592 MiB.