I finally found the time to have a closer look at what was discussed in #592 with @romeokienzler, as I was trying to move my segmentation training from Clay 1.0 to Clay 1.5 with terratorch.
The problem
The current Clay1_5ModelFactory returns the MAE pretraining losses (loss, reconstruction_loss, representation_loss) instead of a ModelOutput with predictions. Looking at the code:
- Line 548 in ClayMAE.forward(): returns (loss, reconstruction_loss, representation_loss)
- Line 89 in Clay1_5ModelFactory: just wraps ClayMAE in ModelWrapper, and doesn't use PixelWiseModel or ScalarOutputModel
This breaks downstream tasks because the model is doing MAE pretraining (reconstruction), not feature extraction for segmentation/classification.
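To make the mismatch concrete, here is a tiny sketch of the two return shapes (the ModelOutput import path and fields follow what I see in terratorch, but may differ across versions):

```python
import torch
from terratorch.models.model import ModelOutput  # path/fields may vary by version

# What the wrapped ClayMAE currently returns from forward(): a plain loss tuple
mae_out = (torch.tensor(0.5), torch.tensor(0.3), torch.tensor(0.2))

# What terratorch tasks expect instead: a ModelOutput carrying predictions,
# e.g. segmentation logits of shape (B, num_classes, H, W)
task_out = ModelOutput(output=torch.randn(2, 4, 256, 256))
```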
The fix
I created an adapter that uses only the ClayMAE encoder for feature extraction (skipping the masking and loss computation). Then I wrap it in the standard PixelWiseModel pipeline, just like Clay 1.0 does.
The key insight is that for downstream tasks we don't want the full MAE forward pass - we just need the encoder to extract features:
```python
import torch
import torch.nn as nn
from einops import repeat

class ClayMAEBackbone(nn.Module):
    def __init__(self, clay_mae, dim):
        super().__init__()
        self.clay_mae = clay_mae  # pretrained ClayMAE; only its encoder is used
        self.dim = dim            # encoder embedding dimension

    def forward(self, x, time, latlon, waves, gsd):
        B = x.shape[0]
        # Encoder only (no masking); time/latlon/waves/gsd are the metadata Clay expects
        patches, waves_encoded = self.clay_mae.encoder.to_patch_embed(x, waves)
        patches = self.clay_mae.encoder.add_encodings(patches, time, latlon, gsd)
        # Add CLS token and pass through the transformer
        cls_tokens = repeat(self.clay_mae.encoder.cls_token, "1 1 D -> B 1 D", B=B)
        patches_with_cls = torch.cat((cls_tokens, patches), dim=1)
        encoded = self.clay_mae.encoder.transformer(patches_with_cls)
        # Remove CLS, then transpose and reshape the tokens to spatial features
        tokens = encoded[:, 1:, :]
        grid_size = int(tokens.shape[1] ** 0.5)
        features = tokens.transpose(1, 2).reshape(B, self.dim, grid_size, grid_size)
        return [features]
```

Then the factory wraps this in PixelWiseModel, which handles the decoder and head properly.
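For reference, the factory side then looks roughly like this; the PixelWiseModel constructor below is simplified from the Clay 1.0 pattern rather than its exact signature, and the placeholders stand in for whatever the factory actually builds from its config:

```python
import torch.nn as nn
from terratorch.models.pixel_wise_model import PixelWiseModel  # import path may vary

# Hedged sketch: argument names are simplified; the real factory constructs a
# proper decoder and head instead of these placeholders.
backbone = ClayMAEBackbone(clay_mae, dim=1024)  # 1024: assumed Clay 1.5 encoder width
model = PixelWiseModel(
    task="segmentation",
    encoder=backbone,
    decoder=nn.Identity(),           # placeholder decoder
    head_kwargs={"num_classes": 4},  # illustrative
)
```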
I've validated this with full training runs - it works for semantic segmentation (~7 it/s, 313M params).
Happy to share the full implementation or help with a PR if you're interested.