Skip to content

Commit

Permalink
Add scripts to generate and print the Boston housing dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
tobegit3hub committed Sep 13, 2017
1 parent 9fdc628 commit 3c87f7c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 174 deletions.
44 changes: 44 additions & 0 deletions data/boston_housing/generate_csv_tfrecords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python

import tensorflow as tf
import os


def generate_tfrecords(input_filename, output_filename):
    """Convert a comma-separated data file into a TFRecords file.

    The first line of the input is treated as a header and skipped, as are
    blank lines. For every remaining line, columns 1 through 13 are stored
    as the ``features`` float list and column 14 as the single-element
    ``label`` float list. Column 0 is not used (presumably a row ID --
    confirm against the CSV).

    Args:
        input_filename: Path of the CSV file to read.
        output_filename: Path of the TFRecords file to write.

    Raises:
        IndexError: If a data row has fewer than 15 columns.
        ValueError: If a used column cannot be parsed as a float.
    """
    print("Start to convert {} to {}".format(input_filename, output_filename))
    writer = tf.python_io.TFRecordWriter(output_filename)

    # try/finally guarantees the writer is closed even when a row fails to
    # parse; the ``with`` block does the same for the input file handle.
    try:
        with open(input_filename, "r") as csv_file:
            for line_number, line in enumerate(csv_file, start=1):
                # Ignore the header line and blank lines (e.g. a trailing
                # newline at the end of the file).
                if line_number == 1 or not line.strip():
                    continue

                data = line.split(",")
                label = float(data[14])
                features = [float(value) for value in data[1:14]]

                example = tf.train.Example(features=tf.train.Features(feature={
                    "label":
                    tf.train.Feature(
                        float_list=tf.train.FloatList(value=[label])),
                    "features":
                    tf.train.Feature(
                        float_list=tf.train.FloatList(value=features)),
                }))
                writer.write(example.SerializeToString())
    finally:
        writer.close()

    print("Successfully convert {} to {}".format(input_filename,
                                                 output_filename))


def main():
    """Convert every ``.csv`` file in the current working directory.

    Each ``<name>.csv`` is written next to its source as
    ``<name>.csv.tfrecords``.
    """
    current_path = os.getcwd()
    for filename in os.listdir(current_path):
        # The original also tested ``filename.startswith("")``, which is
        # true for every string; only the extension check filters anything.
        if filename.endswith(".csv"):
            generate_tfrecords(filename, filename + ".tfrecords")


# Run the conversion only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
34 changes: 34 additions & 0 deletions data/boston_housing/print_csv_tfrecords.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/usr/bin/env python

import tensorflow as tf
import os


def print_tfrecords(input_filename, max_print_number=100):
    """Print the leading records of a TFRecords file.

    Each record is parsed as a ``tf.train.Example`` and its ``label`` and
    ``features`` float lists are printed together with a zero-based record
    number.

    Args:
        input_filename: Path of the TFRecords file to read.
        max_print_number: Maximum number of records to print. Defaults to
            100, the limit that used to be hard-coded.
    """
    records = tf.python_io.tf_record_iterator(input_filename)
    for record_number, serialized_example in enumerate(records):
        # Stop after exactly ``max_print_number`` records. The original
        # compared with ``>`` after incrementing, so it printed one record
        # too many, and it called ``exit()``, which ended the whole process
        # instead of just returning from this function.
        if record_number >= max_print_number:
            return

        # Get serialized example from file
        example = tf.train.Example()
        example.ParseFromString(serialized_example)
        label = example.features.feature["label"].float_list.value
        features = example.features.feature["features"].float_list.value
        print("Number: {}, label: {}, features: {}".format(record_number,
                                                           label, features))


def main():
    """Print the records of ``train.csv.tfrecords`` in the working directory."""
    print_tfrecords(os.path.join(os.getcwd(), "train.csv.tfrecords"))


# Print the records only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
174 changes: 0 additions & 174 deletions data/boston_housing/test.csv

This file was deleted.

Binary file added data/boston_housing/train.csv.tfrecords
Binary file not shown.

0 comments on commit 3c87f7c

Please sign in to comment.