Skip to content

Commit d6f8e4c

Browse files
Add more complete hosting example, python 3 changes
1 parent 3d5bdc8 commit d6f8e4c

File tree

5 files changed

+118
-12
lines changed

5 files changed

+118
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ models
55
terraform.tfstate
66
terraform.tfstate.backup
77
terraform.tfvars
8+
.idea

pix2pix.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,13 +581,17 @@ def main():
581581
input = tf.placeholder(tf.string, shape=[1])
582582
input_data = tf.decode_base64(input[0])
583583
input_image = tf.image.decode_png(input_data)
584+
584585
# remove alpha channel if present
585-
input_image = input_image[:,:,:3]
586+
input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4), lambda: input_image[:,:,:3], lambda: input_image)
587+
# convert grayscale to RGB
588+
input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1), lambda: tf.image.grayscale_to_rgb(input_image), lambda: input_image)
589+
586590
input_image = tf.image.convert_image_dtype(input_image, dtype=tf.float32)
587591
input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
588592
batch_input = tf.expand_dims(input_image, axis=0)
589593

590-
with tf.variable_scope("generator") as scope:
594+
with tf.variable_scope("generator"):
591595
batch_output = deprocess(create_generator(preprocess(batch_input), 3))
592596

593597
output_image = tf.image.convert_image_dtype(batch_output, dtype=tf.uint8)[0]

server/README.md

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,107 @@ cp terraform.tfvars.example terraform.tfvars
8888
python ../tools/dockrun.py terraform plan
8989
python ../tools/dockrun.py terraform apply
9090
```
91+
92+
## Full training + exporting + hosting commands
93+
94+
Tested with Python 3.6, Tensorflow 1.0.0, Docker, gcloud, and Terraform (https://www.terraform.io/downloads.html)
95+
96+
```sh
97+
git clone https://github.com/affinelayer/pix2pix-tensorflow.git
98+
cd pix2pix-tensorflow
99+
100+
# get some images (only 2 for testing)
101+
mkdir source
102+
curl -o source/cat1.jpg https://farm5.staticflickr.com/4032/4394955222_eea73818d9_o.jpg
103+
curl -o source/cat2.jpg http://wallpapercave.com/wp/ePMeSmp.jpg
104+
105+
# resize source images
106+
python tools/process.py \
107+
--input_dir source \
108+
--operation resize \
109+
--output_dir resized
110+
111+
# create edges from resized images (uses docker container since compiling the dependencies is annoying)
112+
python tools/dockrun.py python tools/process.py \
113+
--input_dir resized \
114+
--operation edges \
115+
--output_dir edges
116+
117+
# combine resized with edges
118+
python tools/process.py \
119+
--input_dir edges \
120+
--b_dir resized \
121+
--operation combine \
122+
--output_dir combined
123+
124+
# train on images (only 1 epoch for testing)
125+
python pix2pix.py \
126+
--mode train \
127+
--output_dir train \
128+
--max_epochs 1 \
129+
--input_dir combined \
130+
--which_direction AtoB
131+
132+
# export model (creates a version of the model that works with the server in server/serve.py as well as google hosted tensorflow)
133+
python pix2pix.py \
134+
--mode export \
135+
--output_dir server/models/edges2cats_AtoB \
136+
--checkpoint train
137+
138+
# process image locally using exported model
139+
python server/tools/process-local.py \
140+
--model_dir server/models/edges2cats_AtoB \
141+
--input_file edges/cat1.png \
142+
--output_file output.png
143+
144+
# serve model locally
145+
cd server
146+
python serve.py --port 8000 --local_models_dir models
147+
148+
# open http://localhost:8000 in a browser, and scroll to the bottom, you should be able to process an edges2cat image and get a bunch of noise as output
149+
150+
# serve model remotely
151+
152+
export GOOGLE_PROJECT=<project name>
153+
154+
# build image
155+
# make sure models are in a directory called "models" in the current directory
156+
docker build --rm --tag us.gcr.io/$GOOGLE_PROJECT/pix2pix-server .
157+
158+
# test image locally
159+
docker run --publish 8000:8000 --rm --name server us.gcr.io/$GOOGLE_PROJECT/pix2pix-server python -u serve.py \
160+
--port 8000 \
161+
--local_models_dir models
162+
163+
# run this while the above server is running
164+
python tools/process-remote.py \
165+
--input_file static/edges2cats-input.png \
166+
--url http://localhost:8000/edges2cats_AtoB \
167+
--output_file output.png
168+
169+
# publish image to private google container repository
170+
python tools/upload-image.py --project $GOOGLE_PROJECT --version v1
171+
172+
# create a google cloud server
173+
cp terraform.tfvars.example terraform.tfvars
174+
# edit terraform.tfvars to put your cloud info in there
175+
# get the service-account.json from the google cloud console
176+
# make sure GCE is enabled on your account as well
177+
python ../tools/dockrun.py terraform plan
178+
python ../tools/dockrun.py terraform apply
179+
180+
# get name of server
181+
gcloud compute instance-groups list-instances pix2pix-manager
182+
# ssh to server
183+
gcloud compute ssh <name of instance here>
184+
# look at the logs (it can take a while to load the docker image)
185+
sudo journalctl -f -u pix2pix
186+
# if you have never exposed an http server on GCP before, you may need to add this firewall rule
187+
gcloud compute firewall-rules create http-server --allow=tcp:80 --target-tags http-server
188+
# get ip address of load balancer
189+
gcloud compute forwarding-rules list
190+
# open that in the browser, should see the same page you saw locally
191+
192+
# to destroy the GCP resources, use this
193+
terraform destroy
194+
```

server/serve.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import print_function
44

55
import socket
6-
import urlparse
76
import time
87
import argparse
98
import base64
@@ -99,7 +98,7 @@ def do_GET(self):
9998
self.send_response(200)
10099
self.send_header("Content-Type", "text/html")
101100
self.end_headers()
102-
with open("static/index.html") as f:
101+
with open("static/index.html", "rb") as f:
103102
self.wfile.write(f.read())
104103
return
105104

@@ -117,7 +116,7 @@ def do_GET(self):
117116
else:
118117
self.send_header("Content-Type", "application/octet-stream")
119118
self.end_headers()
120-
with open("static/" + path) as f:
119+
with open("static/" + path, "rb") as f:
121120
self.wfile.write(f.read())
122121

123122

@@ -154,7 +153,7 @@ def do_POST(self):
154153

155154
variants = models[name] # "cloud" and "local" are the two possible variants
156155

157-
content_len = int(self.headers.getheader("content-length", 0))
156+
content_len = int(self.headers.get("content-length", "0"))
158157
if content_len > 1 * 1024 * 1024:
159158
raise Exception("post body too large")
160159
input_data = self.rfile.read(content_len)
@@ -192,9 +191,9 @@ def do_POST(self):
192191
raise Exception("too many requests")
193192

194193
# add any missing padding
195-
output_b64data += "=" * (-len(output_b64data) % 4)
194+
output_b64data += b"=" * (-len(output_b64data) % 4)
196195
output_data = base64.urlsafe_b64decode(output_b64data)
197-
if output_data.startswith("\x89PNG"):
196+
if output_data.startswith(b"\x89PNG"):
198197
headers["content-type"] = "image/png"
199198
else:
200199
headers["content-type"] = "image/jpeg"
@@ -207,7 +206,7 @@ def do_POST(self):
207206
body = "server error"
208207

209208
self.send_response(status)
210-
for key, value in headers.iteritems():
209+
for key, value in headers.items():
211210
self.send_header(key, value)
212211
self.end_headers()
213212
self.wfile.write(body)
@@ -273,7 +272,7 @@ def main():
273272
project_id = a.project
274273
else:
275274
credentials = oauth2client.service_account.ServiceAccountCredentials.from_json_keyfile_name(a.credentials, scopes)
276-
with open(a.credentials) as f:
275+
with open(a.credentials, "r") as f:
277276
project_id = json.loads(f.read())["project_id"]
278277

279278
# due to what appears to be a bug, we cannot get the discovery document when specifying an http client

tools/dockrun.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ def main():
101101
"PYTHONUNBUFFERED=x",
102102
"--env",
103103
"CUDA_CACHE_PATH=/host/tmp/cuda-cache",
104-
"--env",
105-
"HOME=/host" + os.environ["HOME"],
106104
]
107105

108106
if a.port is not None:

0 commit comments

Comments
 (0)