Open
Labels: bug (Something isn't working), model (Related to model training or definition, not generic infra)
Description
What happened?
Target/source correspondence is currently specified with integers that index the list entries in the model_input and target_input sections of the config. This is brittle: source/target inputs can be disabled, which effectively changes the index of every entry after them. We should therefore use the names that the source/target config already defines. For example, the relevant part of config_physical_jepa.yml should be:
losses : {
  "physical": {
    type: LossPhysical,
    weight: 0.5,
    loss_fcts: {
      "mse": {
        weight: 0.8,
        target_source_correspondence: { "target_physical" : { "input_jepa" : "complement"} },
      },
    },
    target_and_aux_calc: "Physical",
  },
  "student-teacher": {
    enabled: True,
    type: LossLatentSSLStudentTeacher,
    weight: 0.5,
    loss_fcts : {
      "JEPA": {
        'weight': 8, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
        "num_blocks": 6, "num_heads": 12, "with_qk_lnorm": True, "intermediate_dim": 768,
        "dropout_rate": 0.1,
        target_source_correspondence: {"target_jepa" : { "input_jepa" : "subset"} },
      },
    },
    target_and_aux_calc: { "EMATeacher" :
      { ema_ramp_up_ratio : 0.09,
        ema_halflife_in_thousands: 1e-3,
        model_param_overrides : {
          training_config: { losses: { student-teacher: { loss_fcts: { JEPA: { head: identity } } } } }
        },
      }
    }
  }
}
model_input: {
"input_physical" : {
# masking strategy: "random", "forecast"
masking_strategy: "random",
num_samples: 1,
num_steps_input: 1,
masking_strategy_config : {
diffusion_rn : True,
rate : 0.6,
rate_sampling: False
},
},
"input_jepa" : {
# masking strategy: "random", "forecast"
masking_strategy: "random",
num_samples: 1,
num_steps_input: 1,
masking_strategy_config : {
diffusion_rn : True,
rate : 0.6,
rate_sampling: False
},
},
}
target_input: {
"target_physical" : {
masking_strategy: "random",
num_samples: 1,
masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False },
},
"target_jepa" : {
masking_strategy: "healpix",
num_samples: 1,
masking_strategy_config : { rate : 0.2, hl_mask: 0, rate_sampling: False },
},
}
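The Python sketch below makes the proposal concrete. It shows how name-based `target_source_correspondence` entries could be resolved to indices at runtime, so that indices never appear in the config. It is not the project's implementation:

- The helper names `enabled_names` and `resolve_correspondence` are hypothetical.
- The per-entry `enabled` flag used to disable an input is an assumption about how inputs get switched off.
- The relation strings ("complement", "subset") are passed through unchanged.

```python
from typing import Dict, List, Tuple

# Assumed shapes: a section maps input names to their settings, and a
# correspondence maps {target_name: {source_name: relation}}.
Section = Dict[str, dict]
Correspondence = Dict[str, Dict[str, str]]


def enabled_names(section: Section) -> List[str]:
    """Names of the enabled entries, in config order.

    Assumption (hypothetical key): an entry is disabled via `enabled: False`;
    entries without the key are treated as enabled.
    """
    return [name for name, cfg in section.items() if cfg.get("enabled", True)]


def resolve_correspondence(
    correspondence: Correspondence, model_input: Section, target_input: Section
) -> List[Tuple[int, int, str]]:
    """Translate name-based pairs into (target_idx, source_idx, relation).

    Indices are positions among the enabled entries and are computed here,
    at runtime, so the config only ever refers to names.
    """
    targets = enabled_names(target_input)
    sources = enabled_names(model_input)
    resolved = []
    for target, source_map in correspondence.items():
        # YAML parses a bare `0:` key as an int; reject leftover index-style configs.
        if not isinstance(target, str):
            raise TypeError(f"targets must be referenced by name, got {target!r}")
        if target not in targets:
            raise KeyError(f"target {target!r} is missing or disabled in target_input")
        for source, relation in source_map.items():
            if not isinstance(source, str):
                raise TypeError(f"sources must be referenced by name, got {source!r}")
            if source not in sources:
                raise KeyError(f"source {source!r} is missing or disabled in model_input")
            resolved.append((targets.index(target), sources.index(source), relation))
    return resolved


# Names taken from the config above; the per-entry settings are omitted.
model_input = {"input_physical": {}, "input_jepa": {}}
target_input = {"target_physical": {}, "target_jepa": {}}
jepa = {"target_jepa": {"input_jepa": "subset"}}

print(resolve_correspondence(jepa, model_input, target_input))  # [(1, 1, 'subset')]

# Disabling another input shifts positions, but the name-based config is unchanged.
model_input["input_physical"]["enabled"] = False
print(resolve_correspondence(jepa, model_input, target_input))  # [(1, 0, 'subset')]
```

The indices are derived from the enabled entries on every call. So in this sketch, disabling `input_physical` moves the resolved source index from 1 to 0 without any edit to the loss config. Referencing a disabled or misspelled input fails with an explicit error instead of silently pairing the wrong streams.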
What are the steps to reproduce the bug?
No response
Hedgedoc link to logs and more information. This ticket is public, do not attach files directly.
No response
Status: Todo