diff --git a/retrain.py b/retrain.py index 29593d7..9d7cd53 100644 --- a/retrain.py +++ b/retrain.py @@ -84,21 +84,7 @@ from tensorflow.python.util import compat FLAGS = None - -# Input and output file flags. - -# Details of the training configuration. - -# File-system cache locations. - -# Controls the distortions used during training. - -# These are all parameters that are tied to the particular model architecture -# we're using for Inception v3. These include things like tensor names and their -# sizes. If you want to adapt this script to work with another model, you will -# need to update these to reflect the values in the network you're using. -# pylint: disable=line-too-long -DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' +current_dir_path = os.path.dirname(os.path.realpath(__file__)) # pylint: enable=line-too-long BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0' BOTTLENECK_TENSOR_SIZE = 2048 @@ -284,37 +270,6 @@ def run_bottleneck_on_image(sess, image_data, image_data_tensor, bottleneck_values = np.squeeze(bottleneck_values) return bottleneck_values - -def maybe_download_and_extract(): - """Download and extract model tar file. - - If the pretrained model we're using doesn't already exist, this function - downloads it from the TensorFlow.org website and unpacks it into a directory. - """ - dest_directory = FLAGS.model_dir - print (dest_directory) - if not os.path.exists(dest_directory): - os.makedirs(dest_directory) - filename = DATA_URL.split('/')[-1] - filepath = os.path.join(dest_directory, filename) - print ("filepath is:" + filepath) - if not os.path.exists(filepath): - print ("file not found" + filepath) - def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % - (filename, - float(count * block_size) / float(total_size) * 100.0)) - sys.stdout.flush() - - filepath, _ = urllib.request.urlretrieve(DATA_URL, - filepath, - _progress) - print() - statinfo = os.stat(filepath) - print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') - tarfile.open(filepath, 'r:gz').extractall(dest_directory) - - def ensure_dir_exists(dir_name): """Makes sure the folder exists on disk. @@ -749,8 +704,6 @@ def main(_): tf.gfile.DeleteRecursively(FLAGS.summaries_dir) tf.gfile.MakeDirs(FLAGS.summaries_dir) - # Set up the pre-trained graph. - maybe_download_and_extract() graph, bottleneck_tensor, jpeg_data_tensor, resized_image_tensor = ( create_inception_graph())