Skip to content

Commit

Permalink
removed dependency to download models
Browse files Browse the repository at this point in the history
  • Loading branch information
akshaytc committed May 8, 2017
1 parent c8f9e63 commit 188ce45
Showing 1 changed file with 1 addition and 48 deletions.
49 changes: 1 addition & 48 deletions retrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,7 @@
from tensorflow.python.util import compat

FLAGS = None

# Input and output file flags.

# Details of the training configuration.

# File-system cache locations.

# Controls the distortions used during training.

# These are all parameters that are tied to the particular model architecture
# we're using for Inception v3. These include things like tensor names and their
# sizes. If you want to adapt this script to work with another model, you will
# need to update these to reflect the values in the network you're using.
# pylint: disable=line-too-long
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
current_dir_path = os.path.dirname(os.path.realpath(__file__))
# pylint: enable=line-too-long
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
BOTTLENECK_TENSOR_SIZE = 2048
Expand Down Expand Up @@ -284,37 +270,6 @@ def run_bottleneck_on_image(sess, image_data, image_data_tensor,
bottleneck_values = np.squeeze(bottleneck_values)
return bottleneck_values


def maybe_download_and_extract():
"""Download and extract model tar file.
If the pretrained model we're using doesn't already exist, this function
downloads it from the TensorFlow.org website and unpacks it into a directory.
"""
dest_directory = FLAGS.model_dir
print (dest_directory)
if not os.path.exists(dest_directory):
os.makedirs(dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
print ("filepath is:" + filepath)
if not os.path.exists(filepath):
print ("file not found" + filepath)
def _progress(count, block_size, total_size):
sys.stdout.write('\r>> Downloading %s %.1f%%' %
(filename,
float(count * block_size) / float(total_size) * 100.0))
sys.stdout.flush()

filepath, _ = urllib.request.urlretrieve(DATA_URL,
filepath,
_progress)
print()
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
tarfile.open(filepath, 'r:gz').extractall(dest_directory)


def ensure_dir_exists(dir_name):
"""Makes sure the folder exists on disk.
Expand Down Expand Up @@ -749,8 +704,6 @@ def main(_):
tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
tf.gfile.MakeDirs(FLAGS.summaries_dir)

# Set up the pre-trained graph.
maybe_download_and_extract()
graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = (
create_inception_graph())

Expand Down

0 comments on commit 188ce45

Please sign in to comment.