target/source correspondence with integers is brittle #1677

@clessig

Description

What happened?

Target/source correspondence is currently specified with integers that address the list entries in the model_input and target_input sections of the config. This is brittle: for example, source/target inputs can be disabled, which effectively shifts their indices. We should instead use the names that are already present in the source/target config. For example, the relevant part of config_physical_jepa.yml should look like the excerpt below:


  losses : {
    "physical": {
        type: LossPhysical,
        weight: 0.5,
        loss_fcts: {
          "mse": {
            weight: 0.8,
            target_source_correspondence: { "target_physical" : { "input_jepa" : "complement"} },
          },
        },
        target_and_aux_calc: "Physical",
    },
    "student-teacher": {
        enabled: True,
        type: LossLatentSSLStudentTeacher,
        weight: 0.5,
        loss_fcts : {
          "JEPA": {
            "weight": 8, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
            "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768, 
            "dropout_rate": 0.1,
            target_source_correspondence: {"target_jepa" : { "input_jepa" : "subset"} },
          },
        },
        target_and_aux_calc: { "EMATeacher" : 
          { ema_ramp_up_ratio : 0.09,
            ema_halflife_in_thousands: 1e-3,
            model_param_overrides : { 
              training_config: { losses: { "student-teacher": { loss_fcts: { JEPA: { head: identity } } } } }
            },
          }
        }
    }
  }
  
  model_input: {
    "input_physical" : {
      # masking strategy: "random", "forecast"
      masking_strategy: "random",
      num_samples: 1,
      num_steps_input: 1,
      masking_strategy_config : { 
        diffusion_rn : True, 
        rate : 0.6,
        rate_sampling: False
      },
    },
    "input_jepa" : {
      # masking strategy: "random", "forecast"
      masking_strategy: "random",
      num_samples: 1,
      num_steps_input: 1,
      masking_strategy_config : { 
        diffusion_rn : True, 
        rate : 0.6,
        rate_sampling: False
      },
    },
  }

  target_input: {
    "target_physical" : {
      masking_strategy: "random",
      num_samples: 1,
      masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False },
    },
    "target_jepa" : {
      masking_strategy: "healpix",
      num_samples: 1,
      masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False },
    },
  }
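
To make the intended behaviour explicit, here is a minimal sketch in Python of how a name-based target_source_correspondence could be resolved against the model_input and target_input sections. The function and argument names are hypothetical (not the actual codebase API): the point is only that unknown or disabled names fail loudly, instead of silently pointing at the wrong entry the way a shifted integer index does.

  # Minimal sketch, not the actual implementation: names below are hypothetical
  # and only illustrate resolving a name-based target_source_correspondence.

  def resolve_correspondence(correspondence: dict, model_input: dict, target_input: dict) -> dict:
      """Map {target_name: {source_name: mode}} onto the enabled config entries.

      Raises KeyError for names that do not exist or are disabled, rather than
      silently resolving to the wrong entry as a shifted integer index would.
      """
      # Assumption: an entry may carry an optional "enabled" flag; absent means enabled.
      enabled_sources = {k: v for k, v in model_input.items() if v.get("enabled", True)}
      enabled_targets = {k: v for k, v in target_input.items() if v.get("enabled", True)}

      resolved = {}
      for target_name, sources in correspondence.items():
          if target_name not in enabled_targets:
              raise KeyError(
                  f"unknown or disabled target '{target_name}'; "
                  f"available targets: {sorted(enabled_targets)}"
              )
          for source_name, mode in sources.items():
              if source_name not in enabled_sources:
                  raise KeyError(
                      f"unknown or disabled source '{source_name}'; "
                      f"available sources: {sorted(enabled_sources)}"
                  )
              resolved[(target_name, source_name)] = mode  # e.g. "complement" or "subset"
      return resolved


  if __name__ == "__main__":
      # Toy stand-ins for the model_input / target_input sections above.
      model_input = {
          "input_physical": {"masking_strategy": "random"},
          "input_jepa": {"masking_strategy": "random"},
      }
      target_input = {
          "target_physical": {"masking_strategy": "random"},
          "target_jepa": {"masking_strategy": "healpix"},
      }
      # From the "mse" loss above: target_physical is matched against the
      # complement of what input_jepa saw.
      print(resolve_correspondence(
          {"target_physical": {"input_jepa": "complement"}}, model_input, target_input
      ))

With this, disabling input_physical (or reordering the sections) does not change what "input_jepa" refers to, which is the robustness the name-based scheme is meant to buy.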

What are the steps to reproduce the bug?

No response

Hedgedoc link to logs and more information. This ticket is public, do not attach files directly.

No response

Labels: bug (Something isn't working), model (Related to model training or definition (not generic infra))