Skip to content

Commit

Permalink
Update SettingC-SO2-Spread-B.py
Browse files Browse the repository at this point in the history
Correct steerer
  • Loading branch information
georg-bn authored Jun 30, 2024
1 parent a51da82 commit 9c79b60
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions experiments/SettingC-SO2-Spread-B.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,45 +76,51 @@ def train(detector_weights):

generator = torch.block_diag(
torch.zeros(
[NUM_PROTOTYPES // 2,
NUM_PROTOTYPES // 2],
[NUM_PROTOTYPES - 12 * (NUM_PROTOTYPES // 14),
NUM_PROTOTYPES - 12 * (NUM_PROTOTYPES // 14)],
device='cuda',
),
*(
torch.tensor([[0., 1],
[-1, 0]],
device='cuda')
for _ in range(NUM_PROTOTYPES // 8)
device='cuda',
)
for _ in range(NUM_PROTOTYPES // 14)
),
*(
torch.tensor([[0., 2],
[-2, 0]],
device='cuda')
for _ in range(NUM_PROTOTYPES // 16)
device='cuda',
)
for _ in range(NUM_PROTOTYPES // 14)
),
*(
torch.tensor([[0., 3],
[-3, 0]],
device='cuda')
for _ in range(NUM_PROTOTYPES // 32)
device='cuda',
)
for _ in range(NUM_PROTOTYPES // 14)
),
*(
torch.tensor([[0., 4],
[-4, 0]],
device='cuda')
for _ in range(NUM_PROTOTYPES // 64)
device='cuda',
)
for _ in range(NUM_PROTOTYPES // 14)
),
*(
torch.tensor([[0., 5],
[-5, 0]],
device='cuda')
for _ in range(NUM_PROTOTYPES // 128)
device='cuda',
)
for _ in range(NUM_PROTOTYPES // 14)
),
*(
torch.tensor([[0., 6],
[-6, 0]],
device='cuda')
for _ in range(NUM_PROTOTYPES // 128)
device='cuda',
)
for _ in range(NUM_PROTOTYPES // 14)
),
)
steerer = ContinuousSteerer(generator)
Expand Down

0 comments on commit 9c79b60

Please sign in to comment.