Clay 1.5 segmentation training with terratorch not working #970

@hkristen

I finally found the time to take a closer look at what was discussed in #592 with romeokienzler, as I was trying to move my segmentation training from Clay 1.0 to Clay 1.5 with terratorch.

The problem
The current Clay1_5ModelFactory returns the MAE pretraining losses (loss, reconstruction_loss, representation_loss) instead of a ModelOutput with predictions. Looking at the code:

  • Line 548 in ClayMAE.forward(): returns (loss, reconstruction_loss, representation_loss)
  • Line 89 in Clay1_5ModelFactory: just wraps ClayMAE in ModelWrapper, doesn't use PixelWiseModel or ScalarOutputModel

This breaks downstream tasks because the model is doing MAE pretraining (reconstruction), not feature extraction for segmentation/classification.
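
To make the mismatch concrete, here is a toy sketch in plain PyTorch. It uses no Clay or terratorch code: `ToyMAE`, `ToySegmenter`, and `is_pixelwise_prediction` are hypothetical stand-ins invented for illustration. It only shows that a forward pass returning a tuple of scalar losses gives a segmentation loss nothing to compare against the mask, while a per-pixel logit tensor does.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyMAE(nn.Module):
    """Hypothetical stand-in: forward returns pretraining losses, not predictions."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.net = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        reconstruction_loss = F.mse_loss(self.net(x), x)
        representation_loss = torch.zeros(())  # placeholder term for the sketch
        loss = reconstruction_loss + representation_loss
        return loss, reconstruction_loss, representation_loss


class ToySegmenter(nn.Module):
    """Hypothetical stand-in: forward returns per-pixel class logits."""

    def __init__(self, channels: int = 4, num_classes: int = 3):
        super().__init__()
        self.head = nn.Conv2d(channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.head(x)  # (B, num_classes, H, W)


def is_pixelwise_prediction(output, batch: int, height: int, width: int) -> bool:
    """True only for a (batch, C, height, width) tensor."""
    return (
        isinstance(output, torch.Tensor)
        and output.ndim == 4
        and output.shape[0] == batch
        and tuple(output.shape[-2:]) == (height, width)
    )


x = torch.randn(2, 4, 8, 8)
mask = torch.randint(0, 3, (2, 8, 8))

mae_out = ToyMAE()(x)        # tuple of three scalar losses
seg_out = ToySegmenter()(x)  # logits that a segmentation loss can consume
seg_loss = F.cross_entropy(seg_out, mask)
```

A check like `is_pixelwise_prediction` is not part of terratorch; it is just a quick way to confirm what a factory-built model actually returns before starting a training run.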

The fix
I created an adapter that uses only the ClayMAE encoder for feature extraction (skipping the masking and loss computation). Then I wrap it in the standard PixelWiseModel pipeline, just like Clay 1.0 does.

The key insight is that for downstream tasks we don't want the full MAE forward pass - we just need the encoder to extract features:

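import torch
from einops import repeat
from torch import nn

class ClayMAEBackbone(nn.Module):
    def forward(self, x):
        # Excerpt: __init__ (self.clayma, self.dim) and the preparation of
        # waves, time, latlon and gsd are omitted here.
        B = x.shape[0]

        # Use encoder only (no masking for downstream tasks)
        patches, waves_encoded = self.clayma.encoder.to_patch_embed(x, waves)
        patches = self.clayma.encoder.add_encodings(patches, time, latlon, gsd)

        # Add CLS token and pass through transformer
        cls_tokens = repeat(self.clayma.encoder.cls_token, "1 1 D -> B 1 D", B=B)
        patches_with_cls = torch.cat((cls_tokens, patches), dim=1)
        encoded = self.clayma.encoder.transformer(patches_with_cls)

        # Remove CLS, then move the embedding axis ahead of the token axis
        # before unflattening, so each token keeps its own grid position
        grid_size = int(patches.shape[1] ** 0.5)  # assumes a square patch grid
        features = encoded[:, 1:, :].transpose(1, 2).reshape(B, self.dim, grid_size, grid_size)
        return [features]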
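
The last step of the excerpt above turns a token sequence into a spatial feature map. The following is a minimal standalone sketch of that step, assuming the tokens have shape (B, 1 + H*W, D), a square patch grid, and row-major token order. `tokens_to_feature_map` is a hypothetical helper, not a Clay or terratorch function. The embedding axis is moved ahead of the token axis before unflattening. A plain reshape from (B, H*W, D) to (B, D, H, W) would mix values from different tokens into the same spatial position.

```python
import math

import torch


def tokens_to_feature_map(tokens: torch.Tensor, has_cls: bool = True) -> torch.Tensor:
    """Convert (B, [1 +] H*W, D) transformer tokens into a (B, D, H, W) map.

    Assumes a square patch grid with tokens stored in row-major order.
    """
    if has_cls:
        tokens = tokens[:, 1:, :]  # drop the CLS token
    batch, length, dim = tokens.shape
    grid = math.isqrt(length)
    if grid * grid != length:
        raise ValueError(f"{length} tokens do not form a square grid")
    # (B, L, D) -> (B, D, L) -> (B, D, H, W)
    return tokens.transpose(1, 2).reshape(batch, dim, grid, grid)


# Toy input: token i carries the value i in every embedding channel.
patch_tokens = torch.arange(4, dtype=torch.float32).view(1, 4, 1).repeat(1, 1, 3)
cls_token = torch.full((1, 1, 3), -1.0)
tokens = torch.cat((cls_token, patch_tokens), dim=1)  # (1, 5, 3)

feature_map = tokens_to_feature_map(tokens)  # (1, 3, 2, 2)
```

In this toy input, the token at grid position (row 1, column 0) is token 2, so every channel of `feature_map[0, :, 1, 0]` holds the value 2.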

Then the factory wraps this in PixelWiseModel which handles the decoder and head properly.
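
As a rough picture of what that wrapping amounts to, here is a sketch in plain PyTorch. It is not terratorch's PixelWiseModel. `ToyBackbone` and `ToyPixelWise` are hypothetical stand-ins, and the 1x1 convolution head with bilinear upsampling is an assumption chosen only to keep the example short. The one point it mirrors from the adapter above is the contract: the backbone returns a list of feature maps, and the wrapper turns the last one into per-pixel logits at the input resolution.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyBackbone(nn.Module):
    """Hypothetical encoder stand-in returning a list with one feature map."""

    def __init__(self, in_channels: int = 4, dim: int = 16, patch_size: int = 8):
        super().__init__()
        # Non-overlapping patches: (B, C, H, W) -> (B, dim, H/p, W/p)
        self.patchify = nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        return [self.patchify(x)]


class ToyPixelWise(nn.Module):
    """Hypothetical wrapper: backbone features -> head -> upsampled logits."""

    def __init__(self, backbone: nn.Module, dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Conv2d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        features = self.backbone(x)[-1]  # use the last (here: only) feature map
        logits = self.head(features)     # (B, num_classes, H/p, W/p)
        # Bring the logits back to the input resolution
        return F.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)


model = ToyPixelWise(ToyBackbone(), dim=16, num_classes=3)
logits = model(torch.randn(2, 4, 64, 64))  # (2, 3, 64, 64)
```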

I've validated this with full training runs - it works for semantic segmentation (~7 it/s, 313M params).

Happy to share the full implementation or help with a PR if you're interested.
