Skip to content

Commit

Permalink
adding dcgan
Browse files Browse the repository at this point in the history
  • Loading branch information
dendisuhubdy committed May 14, 2019
1 parent b918379 commit 269c433
Show file tree
Hide file tree
Showing 6 changed files with 314 additions and 2 deletions.
88 changes: 88 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,5 @@ before_install:
script:
- mkdir build
- cd build
- cmake -DCMAKE_CXX_COMPILER=$COMPILER .. && make
- cmake -DCMAKE_CXX_COMPILER=$COMPILER ..
- make -j8
2 changes: 1 addition & 1 deletion include/mnist.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include <vector>

struct Options {
std::string data_root{ "../data/mnist" };
std::string data_root{ "./data/mnist" };
int32_t batch_size{ 64 };
int32_t epochs{ 10 };
double lr{ 0.01 };
Expand Down
28 changes: 28 additions & 0 deletions scripts/display_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from __future__ import print_function
from __future__ import unicode_literals

import argparse

import matplotlib.pyplot as plt
import torch


parser = argparse.ArgumentParser()
parser.add_argument("-i", "--sample-file", required=True)
parser.add_argument("-o", "--out-file", default="out.png")
parser.add_argument("-d", "--dimension", type=int, default=3)
options = parser.parse_args()

module = torch.jit.load(options.sample_file)
images = list(module.parameters())[0]

for index in range(options.dimension * options.dimension):
image = images[index].detach().cpu().reshape(28, 28).mul(255).to(torch.uint8)
array = image.numpy()
axis = plt.subplot(options.dimension, options.dimension, 1 + index)
plt.imshow(array, cmap="gray")
axis.get_xaxis().set_visible(False)
axis.get_yaxis().set_visible(False)

plt.savefig(options.out_file)
print("Saved ", options.out_file)
8 changes: 8 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,11 @@ add_executable(mnist
target_compile_features(mnist PUBLIC cxx_range_for)

target_link_libraries(mnist ${TORCH_LIBRARIES})

set_property(TARGET mnist PROPERTY CXX_STANDARD 11)

add_executable(dcgan dcgan.cpp)

target_link_libraries(dcgan "${TORCH_LIBRARIES}")

set_property(TARGET dcgan PROPERTY CXX_STANDARD 11)
187 changes: 187 additions & 0 deletions src/dcgan.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
#include <torch/torch.h>

#include <cmath>
#include <cstdio>
#include <iostream>

// The size of the noise vector fed to the generator.
const int64_t kNoiseSize = 100;

// The batch size for training.
const int64_t kBatchSize = 64;

// The number of epochs to train.
const int64_t kNumberOfEpochs = 30;

// Where to find the MNIST dataset.
const char* kDataFolder = "./data/mnist";

// After how many batches to create a new checkpoint periodically.
const int64_t kCheckpointEvery = 200;

// How many images to sample at every checkpoint.
const int64_t kNumberOfSamplesPerCheckpoint = 10;

// Set to `true` to restore models and optimizers from previously saved
// checkpoints.
const bool kRestoreFromCheckpoint = false;

// After how many batches to log a new update with the loss value.
const int64_t kLogInterval = 10;

using namespace torch;

int main(int argc, const char* argv[]) {
torch::manual_seed(1);

// Create the device we pass around based on whether CUDA is available.
torch::Device device(torch::kCPU);
if (torch::cuda::is_available()) {
std::cout << "CUDA is available! Training on GPU." << std::endl;
device = torch::Device(torch::kCUDA);
}

nn::Sequential generator(
// Layer 1
nn::Conv2d(nn::Conv2dOptions(kNoiseSize, 256, 4)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(256),
nn::Functional(torch::relu),
// Layer 2
nn::Conv2d(nn::Conv2dOptions(256, 128, 3)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(128),
nn::Functional(torch::relu),
// Layer 3
nn::Conv2d(nn::Conv2dOptions(128, 64, 4)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::BatchNorm(64),
nn::Functional(torch::relu),
// Layer 4
nn::Conv2d(nn::Conv2dOptions(64, 1, 4)
.stride(2)
.padding(1)
.with_bias(false)
.transposed(true)),
nn::Functional(torch::tanh));
generator->to(device);

nn::Sequential discriminator(
// Layer 1
nn::Conv2d(
nn::Conv2dOptions(1, 64, 4).stride(2).padding(1).with_bias(false)),
nn::Functional(torch::leaky_relu, 0.2),
// Layer 2
nn::Conv2d(
nn::Conv2dOptions(64, 128, 4).stride(2).padding(1).with_bias(false)),
nn::BatchNorm(128),
nn::Functional(torch::leaky_relu, 0.2),
// Layer 3
nn::Conv2d(
nn::Conv2dOptions(128, 256, 4).stride(2).padding(1).with_bias(false)),
nn::BatchNorm(256),
nn::Functional(torch::leaky_relu, 0.2),
// Layer 4
nn::Conv2d(
nn::Conv2dOptions(256, 1, 3).stride(1).padding(0).with_bias(false)),
nn::Functional(torch::sigmoid));
discriminator->to(device);

// Assume the MNIST dataset is available under `kDataFolder`;
auto dataset = torch::data::datasets::MNIST(kDataFolder)
.map(torch::data::transforms::Normalize<>(0.5, 0.5))
.map(torch::data::transforms::Stack<>());
const int64_t batches_per_epoch =
std::ceil(dataset.size().value() / static_cast<double>(kBatchSize));

auto data_loader = torch::data::make_data_loader(
std::move(dataset),
torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

torch::optim::Adam generator_optimizer(
generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
discriminator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));

if (kRestoreFromCheckpoint) {
torch::load(generator, "generator-checkpoint.pt");
torch::load(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::load(discriminator, "discriminator-checkpoint.pt");
torch::load(
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
}

int64_t checkpoint_counter = 1;
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
int64_t batch_index = 0;
for (torch::data::Example<>& batch : *data_loader) {
// Train discriminator with real images.
discriminator->zero_grad();
torch::Tensor real_images = batch.data.to(device);
torch::Tensor real_labels =
torch::empty(batch.data.size(0), device).uniform_(0.8, 1.0);
torch::Tensor real_output = discriminator->forward(real_images);
torch::Tensor d_loss_real =
torch::binary_cross_entropy(real_output, real_labels);
d_loss_real.backward();

// Train discriminator with fake images.
torch::Tensor noise =
torch::randn({batch.data.size(0), kNoiseSize, 1, 1}, device);
torch::Tensor fake_images = generator->forward(noise);
torch::Tensor fake_labels = torch::zeros(batch.data.size(0), device);
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
torch::Tensor d_loss_fake =
torch::binary_cross_entropy(fake_output, fake_labels);
d_loss_fake.backward();

torch::Tensor d_loss = d_loss_real + d_loss_fake;
discriminator_optimizer.step();

// Train generator.
generator->zero_grad();
fake_labels.fill_(1);
fake_output = discriminator->forward(fake_images);
torch::Tensor g_loss =
torch::binary_cross_entropy(fake_output, fake_labels);
g_loss.backward();
generator_optimizer.step();

if (batch_index % kLogInterval == 0) {
std::printf(
"\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
epoch,
kNumberOfEpochs,
++batch_index,
batches_per_epoch,
d_loss.item<float>(),
g_loss.item<float>());
}

if (batch_index % kCheckpointEvery == 0) {
// Checkpoint the model and optimizer state.
torch::save(generator, "generator-checkpoint.pt");
torch::save(generator_optimizer, "generator-optimizer-checkpoint.pt");
torch::save(discriminator, "discriminator-checkpoint.pt");
torch::save(
discriminator_optimizer, "discriminator-optimizer-checkpoint.pt");
// Sample the generator and save the images.
torch::Tensor samples = generator->forward(torch::randn(
{kNumberOfSamplesPerCheckpoint, kNoiseSize, 1, 1}, device));
torch::save(
(samples + 1.0) / 2.0,
torch::str("dcgan-sample-", checkpoint_counter, ".pt"));
std::cout << "\n-> checkpoint " << ++checkpoint_counter << '\n';
}
}
}

std::cout << "Training complete!" << std::endl;
}

0 comments on commit 269c433

Please sign in to comment.