Skip to content

Commit

Permalink
Update train_image.py
Browse files Browse the repository at this point in the history
  • Loading branch information
longmingsheng authored Dec 23, 2018
1 parent b14e2f8 commit 0995248
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pytorch/train_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def train(config):
features = torch.cat((features_source, features_target), dim=0)
outputs = torch.cat((outputs_source, outputs_target), dim=0)
softmax_out = nn.Softmax(dim=1)(outputs)
if config['method'] == 'CDAN-E':
if config['method'] == 'CDAN+E':
entropy = loss.Entropy(softmax_out)
transfer_loss = loss.CDAN([features, softmax_out], ad_net, entropy, network.calc_coeff(i), random_layer)
elif config['method'] == 'CDAN':
Expand All @@ -195,7 +195,7 @@ def train(config):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Conditional Domain Adversarial Network')
parser.add_argument('method', type=str, default='CDAN-E', choices=['CDAN', 'CDAN-E', 'DANN'])
parser.add_argument('method', type=str, default='CDAN+E', choices=['CDAN', 'CDAN+E', 'DANN'])
parser.add_argument('--gpu_id', type=str, nargs='?', default='0', help="device id to run")
parser.add_argument('--net', type=str, default='ResNet50', choices=["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152", "VGG11", "VGG13", "VGG16", "VGG19", "VGG11BN", "VGG13BN", "VGG16BN", "VGG19BN", "AlexNet"])
parser.add_argument('--dset', type=str, default='office', choices=['office', 'image-clef', 'visda', 'office-home'], help="The dataset or source dataset used")
Expand Down

0 comments on commit 0995248

Please sign in to comment.