Skip to content

Commit

Permalink
Generalized Label Retrieval
Browse files Browse the repository at this point in the history
Generic function to retrieve labels from images while scraping.
  • Loading branch information
hughhan1 committed Apr 27, 2017
1 parent 03b9298 commit 43074d3
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ thumbnails/*.png

# Data Files
*.h5
*.csv

# JSON Files
json/*.json
Expand Down
71 changes: 40 additions & 31 deletions moma.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import json
import os
import socket
import sys
import csv

Expand All @@ -19,7 +20,7 @@
artworks_file = 'json/artworks.json'

pad = 6 #padding to ensure filenames are 6 characters, for sorting in utils.py
max_imgs = 600 #TODO: pass as argument
max_imgs = 10000 #TODO: pass as argument


def make_soup(url):
Expand Down Expand Up @@ -57,15 +58,15 @@ def get_image(url, filename):
if image_link == "":
print('error : no image found at .' % (url))
else:


try:
request = urlopen(image_link, timeout=5) #timeout for getting image
with open(os.path.join(image_dir, filename), 'wb') as f:
request = urlopen(image_link, timeout=5)
with open(os.path.join(image_dir, filename), 'w') as f:
f.write(request.read())
except:
print('success : %s downloaded to %s directory.' % (filename, image_dir))
except AttributeError:
print('error : %s could not be downloaded to %s directory.' % (filename, image_dir))
except socket.timeout:
print('error : %s could not be downloaded to %s directory.' % (filename, image_dir))

return image_link


Expand All @@ -92,23 +93,31 @@ def get_thumbnail(url, filename):
print('error : %s could not be downloaded to %s directory.' % (filename, thumb_dir))
return url

def write_labels(labels_map, filename):
    """
    Writes a tab-separated file of labels and reports how often each occurs.

    For every distinct label a "<label> : <count>" line is written to stderr.
    Then one "objectID<TAB>label" row per entry of `labels_map` is written
    to `filename`, replacing any existing contents.

    Args:
        labels_map : a dictionary with key objectID and value label name
        filename : the filename to which the output should be written
    """
    # Imported locally so the function does not rely on a change to the
    # module-level import block.
    from collections import Counter

    # Tally every label in a single pass instead of rescanning all values
    # once per distinct label.
    label_counts = Counter(labels_map.values())
    for label, count in label_counts.items():
        sys.stderr.write("%s : %d\n" % (label, count))

    # Write the results to a file, structured as objectID\tlabel.
    # items() is available on both Python 2 and Python 3, whereas
    # iteritems() exists only on Python 2.
    with open(filename, "w") as output:
        writer = csv.writer(output, delimiter='\t', lineterminator='\n')
        for object_id, label in labels_map.items():
            writer.writerow([object_id, label])


def get_images(artworks_filename):
Expand All @@ -122,25 +131,25 @@ def get_images(artworks_filename):
artworks_file = open(artworks_filename)
artworks_data = json.load(artworks_file)

nation_labels = []
classification_labels = {}
i = 0
for artwork in artworks_data[:max_imgs]:
url = artwork['URL']
object_id = artwork['ObjectID']
nation = artwork['Nationality']

print(str(object_id))
if url is not None and nation and object_id != 209 and object_id != 304 and object_id != 345 and object_id != 361:
if i % 500 != 0:
i += 1
continue
else:
i += 1

check = ["American", "British", "Italian"] #TODO: pass as command line
#ensure that nationality is known; values to remove: ["", "Nationality unknown", "Nationality Unknown"]
if nation[0] in check:
get_image(url, str(object_id).zfill(pad) + '.jpg')
nation_labels.append(nation[0]) #choose first nationality, sometimes repeated
url = artwork['URL']
object_id = artwork['ObjectID']
classification = artwork['Classification']

print(str(object_id) + " : " + nation[0])
if url is not None:
get_image(url, str(object_id).zfill(pad) + '.jpg')
classification_labels[object_id] = classification

write_labels(nation_labels)
write_labels(classification_labels, "foo.csv")
artworks_file.close()


Expand Down

0 comments on commit 43074d3

Please sign in to comment.