forked from tobegit3hub/tensorflow_template_application
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add script to generate and print boston housing dataset
- Loading branch information
1 parent
9fdc628
commit 3c87f7c
Showing
4 changed files
with
78 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/usr/bin/env python | ||
|
||
import tensorflow as tf | ||
import os | ||
|
||
|
||
def generate_tfrecords(input_filename, output_filename): | ||
print("Start to convert {} to {}".format(input_filename, output_filename)) | ||
writer = tf.python_io.TFRecordWriter(output_filename) | ||
|
||
index = 0 | ||
for line in open(input_filename, "r"): | ||
index += 1 | ||
|
||
# Ignore the first line | ||
if index == 1: | ||
continue | ||
|
||
data = line.split(",") | ||
label = float(data[14]) | ||
features = [float(i) for i in data[1:14]] | ||
|
||
example = tf.train.Example(features=tf.train.Features(feature={ | ||
"label": | ||
tf.train.Feature(float_list=tf.train.FloatList(value=[label])), | ||
"features": | ||
tf.train.Feature(float_list=tf.train.FloatList(value=features)), | ||
})) | ||
writer.write(example.SerializeToString()) | ||
|
||
writer.close() | ||
print("Successfully convert {} to {}".format(input_filename, | ||
output_filename)) | ||
|
||
|
||
def main(): | ||
current_path = os.getcwd() | ||
for filename in os.listdir(current_path): | ||
if filename.startswith("") and filename.endswith(".csv"): | ||
generate_tfrecords(filename, filename + ".tfrecords") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env python | ||
|
||
import tensorflow as tf | ||
import os | ||
|
||
|
||
def print_tfrecords(input_filename): | ||
max_print_number = 100 | ||
current_print_number = 0 | ||
|
||
for serialized_example in tf.python_io.tf_record_iterator(input_filename): | ||
# Get serialized example from file | ||
example = tf.train.Example() | ||
example.ParseFromString(serialized_example) | ||
label = example.features.feature["label"].float_list.value | ||
features = example.features.feature["features"].float_list.value | ||
print("Number: {}, label: {}, features: {}".format(current_print_number, | ||
label, features)) | ||
|
||
# Return when reaching max print number | ||
current_print_number += 1 | ||
if current_print_number > max_print_number: | ||
exit() | ||
|
||
|
||
def main(): | ||
current_path = os.getcwd() | ||
tfrecords_file_name = "train.csv.tfrecords" | ||
input_filename = os.path.join(current_path, tfrecords_file_name) | ||
print_tfrecords(input_filename) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file was deleted.
Oops, something went wrong.
Binary file not shown.