Skip to content

Commit 845028b

Browse files
authored
Merge pull request #44 from r7rohan/patch-1
Autoencoder updates
2 parents 5105d47 + 7ee8a89 commit 845028b

File tree

3 files changed

+73
-50
lines changed

3 files changed

+73
-50
lines changed

Autoencoder/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55

66
## Features
77

8+
- `Autoencoder is densenet layered with batchnorm layers and multiple skip connections`
9+
810
- `Select the structure for the DenseNet and see the performance of the model. `
911

1012
- `Sample autoencoded MNIST Digits can be seen`
1113

1214
- `2D Visualize the encoded space of the autoencoder, see the decoded digit for the corresponding latent point`
1315

14-
- `Autoencode your digit drawing`
16+
- `Autoencode your own digit drawing`
1517

1618
## Installation and execution
1719

Autoencoder/index.html

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,25 +49,28 @@
4949
}
5050
</style>
5151
</head>
52-
52+
5353
<body style="text-align:center;background-color:white;font-size:3vh;top:0%; font-family:Georgia;margin-top:0%;margin-left:0%;">
5454
<p style="background-color:#111111E9;font-size:3vw;top:0%;font-family:Times;margin-top:0%;padding-bottom:1%;padding-top:1%;position:fixed;width:100%;color:#EEEEFF;">TensorFlow.js: MNIST Autoencoder</p>
5555
<p style="background-color:#7777AA;font-size:3vw;top:0%;font-family:Times;margin-top:0%;padding-bottom:1%;padding-top:1%;visibility: hidden;width:100%;margin-bottom:0%;">TensorFlow.js: MNIST Autoencoder</p>
5656
<div id="body" style="text-align:center;height:100%;width:55%;margin-left:22%;background-color:white;padding-left:1vw;padding-top:1vh;">
5757
<br>
58-
59-
58+
59+
6060
<section class='title-area' >
61-
<p style="text-align:left;color:#111111;margin-top:0%;">Train a model to autoencode handwritten digits from the MNIST database using the tf.layers
62-
api.
63-
<br>
64-
This examples lets you train a MNIST Autoencoder using a Fully Connected Neural Network (also known as a DenseNet).<br><br>
61+
<p style="text-align:left;color:#111111;margin-top:0%;">
62+
This examples lets you train a MNIST Autoencoder using a Fully Connected Neural Network (also known as a DenseNet) in written in Tfjs<br><br>
6563
You can select the structure for the DenseNet and see the performance of the model.
6664
<br>The MNIST dataset is used as training data.
65+
<br>
66+
<br>
67+
Set latent space dimension to 2 for 2d Exploration of the latent space. Otherwise set it high for accurate autoencoding
68+
<br>
69+
Visualization scale determines the scale of 2d pane
6770
</p>
6871
</section>
69-
70-
72+
73+
7174
<div style="width:100%;height:5px;background-color:#EFEFEF;"></div>
7275
<section style="text-align:center;">
7376
<p class='section-head' >
@@ -76,7 +79,7 @@ <h1 style="background-color:#EFEFFF;padding-top:0.5%;padding-bottom:1%;">Trainin
7679
<div style="font-family:Times;width:20vw;background-color:#EFEFEF;text-align:left;padding-left:3%;padding-top:3%;padding-bottom:2%;display:inline-block;">
7780
<div>
7881
<label>N hidden layers in encoder and decoder</label>
79-
<input id="n_layers" value="2">
82+
<input id="n_layers" value="3">
8083
</div>
8184
<div>
8285
<label>Output dimension of each layer</label>
@@ -88,29 +91,29 @@ <h1 style="background-color:#EFEFFF;padding-top:0.5%;padding-bottom:1%;">Trainin
8891
</div>
8992
<button id="Create">Create model</button>
9093
</div>
91-
94+
9295
<div style="font-family:Times;width:20vw;background-color:#EFEFEF;text-align:left;padding-left:3%;padding-top:1.3%;padding-bottom:2%;display:inline-block;">
9396
<div>
9497
<label># Batch size:</label>
9598
<input id="batchsize" value="300">
9699
</div>
97100
<div>
98101
<label># LearnRate:</label>
99-
<input id="lr" value="0.3">
102+
<input id="lr" value="0.1">
100103
</div>
101104
<div>
102105
<label># Training epochs:</label>
103106
<input id="train-epochs" value="1">
104107
</div>
105108
<div>
106109
<label># Visualization scale</label>
107-
<input id="vis" value="50">
110+
<input id="vis" value="0.1">
108111
</div>
109112
<button id="train">Train Model</button>
110113
</div>
111114
</section>
112-
113-
115+
116+
114117
<br>
115118
<div style="width:100%;height:5px;background-color:#EFEFEF;"></div>
116119
</div>
@@ -124,11 +127,11 @@ <h2>This will show the examples of autoencoder once it its trained</h2>
124127
</div>
125128
</div>
126129
<br><br><br>
127-
128-
130+
131+
129132
<div style="width:100%;background-color:white;height:15px;"></div>
130133
<div style="text-align:center;">
131-
<h2>This is for 2d plot visualization of latent space of autoencoder.<br> If your latent space dimension is set to 2D<br></h2>
134+
<h2>This is for 2d plot visualization of latent space of autoencoder.<br> Drag in the 2d Pane below slowly<br></h2>
132135
<div id="cn" style="display:none;margin-left:35%;text-align:center;">
133136
<canvas id="mot" style="height:80px;width:80px;display:block;margin-left:20%;border:solid 3px black;"></canvas>
134137
<br>
@@ -146,8 +149,8 @@ <h2>This is for autoencoding your drawing on the canvas<br></h2>
146149
<button id="clear" style="display:inline;">Clear</button>
147150
</div>
148151
</div>
149-
150-
152+
153+
151154
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
152155
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-vis@1.0.2/dist/tfjs-vis.umd.min.js"></script>
153156
<script src='https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.2.7/tf.min.js'></script>

Autoencoder/index.js

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,52 +8,63 @@
88
// for arbitrary data though. It's worth a look :)
99
import {IMAGE_H, IMAGE_W, MnistData} from './datas.js';
1010

11-
// This is a helper class for drawing loss graphs and MNIST images to the
12-
// window. For the purposes of understanding the machine learning bits, you can
13-
// largely ignore it
1411
import * as ui from './ui.js';
1512

1613

17-
function createConvModel(n_layers,n_units,hidden) {
18-
14+
function createConvModel(n_layers,n_units,hidden) { //resnet-densenet-batchnorm
1915
this.latent_dim = Number(hidden); //final dimension of hidden layer
2016
this.n_layers = Number(n_layers); //how many hidden layers in encoder and decoder
2117
this.n_units = Number(n_units); //output dimension of each layer
2218
this.img_shape = [28,28];
2319
this.img_units = this.img_shape[0] * this.img_shape[1];
2420
// build the encoder
21+
2522
var i = tf.input({shape: this.img_shape});
2623
var h = tf.layers.flatten().apply(i);
27-
28-
for (var j=0; j<this.n_layers; j++) {
24+
h=tf.layers.batchNormalization(-1).apply(h);
25+
h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h);
26+
for (var j=0; j<this.n_layers-1; j++) {
27+
var tm=h;
28+
const addLayer = tf.layers.add();
2929
var h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h); //n hidden
30+
h=addLayer.apply([tm,h]);
31+
h=tf.layers.batchNormalization(0).apply(h);
3032
}
3133

32-
var o = tf.layers.dense({units: this.latent_dim}).apply(h); //1 final
34+
var o = tf.layers.dense({units: this.latent_dim}).apply(h);
35+
//1 final
3336
this.encoder = tf.model({inputs: i, outputs: o});
3437

3538
// build the decoder
3639
var i = h = tf.input({shape: this.latent_dim});
37-
for (var j=0; j<this.n_layers; j++) { //n hidden
40+
h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h);
41+
for (var j=0; j<this.n_layers-1; j++) {
42+
var tm=h;
43+
const addLayer = tf.layers.add(); //n hidden
3844
var h = tf.layers.dense({units: this.n_units, activation:'relu'}).apply(h);
45+
h=addLayer.apply([tm,h]);
3946
}
40-
var o = tf.layers.dense({units: this.img_units}).apply(h) ; //1 final
47+
48+
var o = tf.layers.dense({units: this.img_units}).apply(h); //1 final
4149
var o = tf.layers.reshape({targetShape: this.img_shape}).apply(o);
4250
this.decoder = tf.model({inputs: i, outputs: o});
4351

4452
// stack the autoencoder
4553
var i = tf.input({shape: this.img_shape});
4654
var z = this.encoder.apply(i); //z is hidden code
47-
4855
var o = this.decoder.apply(z);
4956
this.auto = tf.model({inputs: i, outputs: o});
5057

5158
}
59+
60+
5261
let epochs=0,trainEpochs,batch;
5362
var trainData;
5463
var testData;
5564
var b;var model;
5665

66+
67+
5768
async function train(model) {
5869

5970
const e=document.getElementById('batchsize');
@@ -84,8 +95,6 @@ await showPredictions(model,epochs); //Triv
8495

8596
}
8697

87-
88-
8998
async function showPredictions(model,epochs) { //Trivial Samples of autoencoder
9099
const testExamples = 10;
91100
const examples = data.getTestData(testExamples);
@@ -106,14 +115,15 @@ async function run(){
106115
testData = data.getTestData();
107116
}
108117

118+
document.getElementById('vis').oninput=function(){vis=Number(document.getElementById('vis').value);console.log(vis);};
109119

110120
async function load() {
111121
var ele=document.getElementById('barc');
112122
ele.style.display="none";
113123
const n_units=document.getElementById('n_units').value;
114124
const n_layers=document.getElementById('n_layers').value;
115125
const hidden=document.getElementById('hidden').value;
116-
model = new createConvModel(n_layers,n_units,hidden);
126+
model = new createConvModel(n_layers,n_units,hidden); //load model
117127
const elem=document.getElementById('new')
118128
elem.innerHTML="Model Created!!!"
119129
epochs=0;
@@ -122,13 +132,15 @@ async function load() {
122132

123133
load();
124134

135+
136+
125137
async function runtrain(){
126138
var ele=document.getElementById('barc');
127139
ele.style.display="block";
128140
var elem=document.getElementById('new');
129141
elem.innerHTML="";
130142
b=0;
131-
await train(model);
143+
await train(model); //start training
132144
vis=Number(document.getElementById('vis').value);
133145
}
134146

@@ -151,7 +163,7 @@ function normaltensor(prediction){
151163
prediction= prediction.sub(inputMin).div(inputMax.sub(inputMin));
152164
return prediction;}
153165
function normal(prediction){
154-
const inputMax = prediction.max();
166+
const inputMax = prediction.max(); //normailization
155167
const inputMin = prediction.min();
156168
prediction= prediction.sub(inputMin).div(inputMax.sub(inputMin));
157169
return prediction;
@@ -163,22 +175,27 @@ const canvas=document.getElementById('celeba-scene');
163175
const mot=document.getElementById('mot');
164176
var cont=mot.getContext('2d');
165177

178+
179+
180+
181+
182+
183+
184+
185+
186+
166187
function sample(obj) { //plotting
167188
obj.x = (obj.x) * vis;
168189
obj.y = (obj.y) * vis;
169190
// convert 10, 50 into a vector
170191
var y = tf.tensor2d([[obj.x, obj.y]]);
171-
// sample from region 10, 50 in latent space
172192

173193
var prediction = model.decoder.predict(y).dataSync();
174-
175-
//scaling
194+
//scaling
176195
prediction=normaltensor(prediction);
177196
prediction=prediction.reshape([28,28]);
178197

179-
prediction=prediction.mul(255).toInt();
180-
181-
198+
prediction=prediction.mul(255).toInt(); //for2dplot
182199
// log the prediction to the browser console
183200
tf.browser.toPixels(prediction, canvas);
184201
}
@@ -190,7 +207,7 @@ cont.fillRect(0,0,mot.width,mot.height);
190207
mot.addEventListener('mousemove', function(e) {
191208
mouse.x = (e.pageX - this.offsetLeft)*3.43;
192209
mouse.y = (e.pageY - this.offsetTop)*1.9;
193-
}, false);
210+
}, false); //mouse movement for 2dplot
194211

195212
mot.addEventListener('mousedown', function(e) {
196213
mot.addEventListener('mousemove', on, false);
@@ -209,11 +226,6 @@ var on= function() {
209226
};
210227

211228

212-
213-
214-
215-
216-
217229
function plot2d(){
218230
load();
219231
const decision=Number(document.getElementById("hidden").value);
@@ -241,6 +253,12 @@ document.addEventListener('DOMContentLoaded',plot2d);
241253

242254

243255

256+
257+
258+
259+
260+
261+
244262
const canv=document.getElementById('canv');
245263
const outcanv=document.getElementById('outcanv');
246264
var ct = outcanv.getContext('2d');
@@ -250,7 +268,7 @@ var ctx = canv.getContext('2d');
250268
function clear(){
251269
ctx.clearRect(0, 0, canv.width, canv.height);
252270
ctx.fillStyle = "black";
253-
ctx.fillRect(0, 0, canv.width, canv.height);
271+
ctx.fillRect(0, 0, canv.width, canv.height); //for canvas autoencoding
254272
ct.clearRect(0, 0, outcanv.width, outcanv.height);
255273
ct.fillStyle = "#DDDDDD";
256274
ct.fillRect(0, 0, outcanv.width, outcanv.height);

0 commit comments

Comments
 (0)