Python formatting, and gitignore additions. (#326)
- Run black and isort on Python files.
- Move Spark config to example file.
- Update gitignore for 7a61f0e additions.
ruebot authored and ianmilligan1 committed Jul 18, 2019
1 parent f35d54e commit bd5ef14
Showing 10 changed files with 115 additions and 59 deletions.
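Note: the formatting pass itself is not captured in the diff below. A minimal sketch of how it could be reproduced follows; the paths, flags, and tool versions here are assumptions, not the exact commands used for this commit.

# Hypothetical reproduction of the black/isort pass; exact invocations and
# tool versions for this commit are not recorded in the diff.
import subprocess

subprocess.run(["black", "src/main/python"], check=True)
subprocess.run(["isort", "src/main/python"], check=True)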
5 changes: 5 additions & 0 deletions .gitignore
@@ -13,3 +13,8 @@ workbench.xmi
build
derby.log
metastore_db
__pycache__/
src/main/python/tf/model.zip
src/main/python/tf/util/spark.conf
src/main/python/tf/model/graph/
src/main/python/tf/model/category/
3 changes: 1 addition & 2 deletions src/main/python/aut/__init__.py
@@ -1,5 +1,4 @@
from aut.common import WebArchive
from aut.udfs import extract_domain

__all__ = ['WebArchive', 'extract_domain']

__all__ = ["WebArchive", "extract_domain"]
2 changes: 1 addition & 1 deletion src/main/python/aut/common.py
@@ -1,5 +1,6 @@
from pyspark.sql import DataFrame


class WebArchive:
def __init__(self, sc, sqlContext, path):
self.sc = sc
@@ -12,4 +13,3 @@ def pages(self):

def links(self):
return DataFrame(self.loader.extractHyperlinks(self.path), self.sqlContext)
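For orientation, the WebArchive wrapper above is used roughly as follows. This is a hedged sketch only: it assumes Spark was started with the aut fatjar available (see tf/util/init.py in this commit), and the archive path is a placeholder.

# Illustrative sketch; sc/sql_context setup and the warc path are assumptions.
from pyspark import SparkContext, SQLContext
from aut.common import WebArchive

sc = SparkContext(appName="aut-sketch")
sql_context = SQLContext(sc)

archive = WebArchive(sc, sql_context, "/path/to/warcs")
pages_df = archive.pages()   # DataFrame of extracted pages
links_df = archive.links()   # DataFrame of extracted hyperlinks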

10 changes: 6 additions & 4 deletions src/main/python/aut/udfs.py
@@ -1,11 +1,13 @@
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType


def extract_domain_func(url):
url = url.replace('http://', '').replace('https://', '')
if '/' in url:
return url.split('/')[0].replace('www.', '')
url = url.replace("http://", "").replace("https://", "")
if "/" in url:
return url.split("/")[0].replace("www.", "")
else:
return url.replace('www.', '')
return url.replace("www.", "")


extract_domain = udf(extract_domain_func, StringType())
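As a quick illustration of the UDF above (the example URLs are invented, and the aut Python package is assumed to be on the path):

# Small self-contained check of extract_domain; URLs are illustrative only.
from pyspark.sql import SparkSession

from aut.udfs import extract_domain

spark = SparkSession.builder.appName("extract-domain-example").getOrCreate()
df = spark.createDataFrame(
    [("http://www.example.com/page.html",), ("https://archive.org/about",)],
    ["url"],
)
df.withColumn("domain", extract_domain("url")).show(truncate=False)
# expected domains: example.com, archive.org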
16 changes: 10 additions & 6 deletions src/main/python/tf/detect.py
@@ -1,13 +1,15 @@
import os
import sys
from util.init import *

from pyspark.sql import DataFrame

from model.object_detection import *
from util.init import *

PYAUT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PYAUT_DIR)

from aut.common import WebArchive
from pyspark.sql import DataFrame


if __name__ == "__main__":
# initialization
@@ -23,11 +25,13 @@
arc = WebArchive(sc, sql_context, args.web_archive)
df = DataFrame(arc.loader.extractImages(arc.path), sql_context)
filter_size = tuple(args.filter_size)
print("height >= %d and width >= %d"%filter_size)
preprocessed = df.filter("height >= %d and width >= %d"%filter_size)
print("height >= %d and width >= %d" % filter_size)
preprocessed = df.filter("height >= %d and width >= %d" % filter_size)

# detection
model_broadcast = detector.broadcast()
detect_udf = detector.get_detect_udf(model_broadcast)
res = preprocessed.select("url", detect_udf(col("bytes")).alias("prediction"), "bytes")
res = preprocessed.select(
"url", detect_udf(col("bytes")).alias("prediction"), "bytes"
)
res.write.json(args.output_path)
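For clarity, the size filter built a few lines above is just a SQL expression string that DataFrame.filter accepts. A small self-contained illustration with the default --filter_size of 640 640:

# Default --filter_size is [640, 640]; the tuple is substituted into the
# expression passed to DataFrame.filter as a string.
filter_size = (640, 640)
expr = "height >= %d and width >= %d" % filter_size
print(expr)  # height >= 640 and width >= 640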
16 changes: 11 additions & 5 deletions src/main/python/tf/extract_images.py
@@ -1,13 +1,19 @@
import numpy as np
import argparse

import numpy as np

from model.object_detection import SSDExtractor


def get_args():
parser = argparse.ArgumentParser(description='Extracting images from model output.')
parser.add_argument('--res_dir', help='Path of result (model output) directory.')
parser.add_argument('--output_dir', help='Path of extracted image file output directory.')
parser.add_argument('--threshold', type=float, help='Threshold of detection confidence scores.')
parser = argparse.ArgumentParser(description="Extracting images from model output.")
parser.add_argument("--res_dir", help="Path of result (model output) directory.")
parser.add_argument(
"--output_dir", help="Path of extracted image file output directory."
)
parser.add_argument(
"--threshold", type=float, help="Threshold of detection confidence scores."
)
return parser.parse_args()


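The rest of extract_images.py falls outside the changed lines shown here; the following is a hedged sketch of how the parser above is presumably wired to SSDExtractor (defined in model/object_detection.py below), not the file's actual main block.

# Assumed wiring only; argument names match get_args() above.
args = get_args()
extractor = SSDExtractor(res_dir=args.res_dir, output_dir=args.output_dir)
# "all" keeps every category in the label map; threshold drops low-confidence
# detections (see SSDExtractor._extract_and_save below).
extractor.extract_and_save(class_ids="all", threshold=args.threshold)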
50 changes: 28 additions & 22 deletions src/main/python/tf/model/object_detection.py
@@ -17,23 +17,21 @@ def __init__(self, res_dir, output_dir):
self.res_dir = res_dir
self.output_dir = output_dir


def _extract_and_save(self, rec, class_ids, threshold):
raise NotImplementedError("Please overwrite this method.")


def extract_and_save(self, class_ids, threshold):
if class_ids == "all":
class_ids = list(self.cate_dict.keys())

for idx in class_ids:
cls = self.cate_dict[idx]
check_dir(self.output_dir + "/%s/"%cls, create=True)
check_dir(self.output_dir + "/%s/" % cls, create=True)

for fname in os.listdir(self.res_dir):
if fname.startswith("part-"):
print("Extracting:", self.res_dir+"/"+fname)
with open(self.res_dir+"/"+fname) as f:
print("Extracting:", self.res_dir + "/" + fname)
with open(self.res_dir + "/" + fname) as f:
for line in f:
rec = json.loads(line)
self._extract_and_save(rec, class_ids, threshold)
@@ -43,47 +41,56 @@ class SSD:
def __init__(self, sc, sql_context, args):
self.sc = sc
self.sql_context = sql_context
self.category = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt"%PKG_DIR)
self.checkpoint = "%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb"%PKG_DIR
self.category = load_cate_dict_from_pbtxt(
"%s/category/mscoco_label_map.pbtxt" % PKG_DIR
)
self.checkpoint = (
"%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb" % PKG_DIR
)
self.args = args
with tf.io.gfile.GFile(self.checkpoint, 'rb') as f:
with tf.io.gfile.GFile(self.checkpoint, "rb") as f:
model_params = f.read()
self.model_params = model_params


def broadcast(self):
return self.sc.broadcast(self.model_params)


def get_detect_udf(self, model_broadcast):
def batch_proc(bytes_batch):
with tf.Graph().as_default() as g:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_broadcast.value)
tf.import_graph_def(graph_def, name='')
image_tensor = g.get_tensor_by_name('image_tensor:0')
detection_scores = g.get_tensor_by_name('detection_scores:0')
detection_classes = g.get_tensor_by_name('detection_classes:0')
tf.import_graph_def(graph_def, name="")
image_tensor = g.get_tensor_by_name("image_tensor:0")
detection_scores = g.get_tensor_by_name("detection_scores:0")
detection_classes = g.get_tensor_by_name("detection_classes:0")

with tf.Session().as_default() as sess:
result = []
image_size = (640, 640)
images = np.array([img2np(b, image_size) for b in bytes_batch])
res = sess.run([detection_scores, detection_classes], feed_dict={image_tensor: images})
res = sess.run(
[detection_scores, detection_classes],
feed_dict={image_tensor: images},
)
for i in range(res[0].shape[0]):
result.append([res[0][i], res[1][i]])
return pd.Series(result)
return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(batch_proc)

return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(
batch_proc
)


class SSDExtractor(ImageExtractor):
def __init__(self, res_dir, output_dir):
super().__init__(res_dir, output_dir)
self.cate_dict = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt"%PKG_DIR)

self.cate_dict = load_cate_dict_from_pbtxt(
"%s/category/mscoco_label_map.pbtxt" % PKG_DIR
)

def _extract_and_save(self, rec, class_ids, threshold):
pred = rec['prediction']
pred = rec["prediction"]
scores = np.array(pred[0])
classes = np.array(pred[1])
valid_classes = np.unique(classes[scores >= threshold])
@@ -102,8 +109,7 @@ def _extract_and_save(self, rec, class_ids, threshold):
cls = self.cate_dict[cls_idx]
try:
img = str2img(rec["bytes"])
img.save(self.output_dir+ "/%s/"%cls + url_parse(rec["url"]))
img.save(self.output_dir + "/%s/" % cls + url_parse(rec["url"]))
except:
fname = self.output_dir+ "/%s/"%cls + url_parse(rec["url"])
fname = self.output_dir + "/%s/" % cls + url_parse(rec["url"])
print("Failing to save:", fname)

5 changes: 2 additions & 3 deletions src/main/python/tf/model/preprocess.py
@@ -7,7 +7,7 @@


def str2img(byte_str):
return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, 'utf-8'))))
return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, "utf-8"))))


def img2np(byte_str, resize=None):
@@ -22,7 +22,7 @@ def img2np(byte_str, resize=None):
if len(img_shape) == 2:
img = np.stack([img, img, img], axis=-1)
elif img_shape[-1] >= 3:
img = img[:,:,:3]
img = img[:, :, :3]

return img

@@ -58,4 +58,3 @@ def load_cate_dict_from_pbtxt(path, key="id", value="display_name"):
cur_cate = re.findall(r'"(.*?)"', entry[1])[0]
cate_dict[cur_key] = cur_cate
return cate_dict
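Since only the tail of load_cate_dict_from_pbtxt is visible here, the following is a hedged approximation of the id-to-display_name mapping it is expected to build. The pbtxt entry is an assumed example in the TensorFlow label-map style, not copied from the repository, and the integer key type is an assumption.

# Standalone approximation for illustration; the real function reads
# category/mscoco_label_map.pbtxt and pairs key="id" with value="display_name".
import re

pbtxt = """
item {
  name: "/m/01g317"
  id: 1
  display_name: "person"
}
"""

cate_dict = {}
for m in re.finditer(r'id:\s*(\d+).*?display_name:\s*"(.*?)"', pbtxt, re.S):
    cate_dict[int(m.group(1))] = m.group(2)

print(cate_dict)  # {1: 'person'}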

67 changes: 51 additions & 16 deletions src/main/python/tf/util/init.py
@@ -1,14 +1,15 @@
import argparse
import os
import re
import zipfile

from pyspark import SparkConf, SparkContext, SQLContext
import re
import os


def init_spark(master, aut_jar):
conf = SparkConf()
conf.set("spark.jars", aut_jar)
conf_path = os.path.dirname(os.path.abspath(__file__))+"/spark.conf"
conf_path = os.path.dirname(os.path.abspath(__file__)) + "/spark.conf"
conf_dict = read_conf(conf_path)
for item, value in conf_dict.items():
conf.set(item, value)
@@ -18,29 +19,63 @@ def init_spark(master, aut_jar):


def get_args():
parser = argparse.ArgumentParser(description='PySpark for Web Archive Image Retrieval.')
parser.add_argument('--web_archive', help='Path to warcs.', default='/tuna1/scratch/nruest/geocites/warcs')
parser.add_argument('--aut_jar', help='Path to compiled aut jar.', default='aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar')
parser.add_argument('--spark', help='Path to Apache Spark.', default='spark-2.3.2-bin-hadoop2.7/bin')
parser.add_argument('--master', help='Apache Spark master IP address and port.', default='spark://127.0.1.1:7077')
parser.add_argument('--img_model', help='Model for image processing.', default='ssd')
parser.add_argument('--filter_size', nargs='+', type=int, help='Filter out images smaller than filter_size', default=[640, 640])
parser.add_argument('--output_path', help='Path to image model output.', default='warc_res')
parser = argparse.ArgumentParser(
description="PySpark for Web Archive Image Retrieval."
)
parser.add_argument(
"--web_archive",
help="Path to warcs.",
default="/tuna1/scratch/nruest/geocites/warcs",
)
parser.add_argument(
"--aut_jar",
help="Path to compiled aut jar.",
default="aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar",
)
parser.add_argument(
"--spark", help="Path to Apache Spark.", default="spark-2.3.2-bin-hadoop2.7/bin"
)
parser.add_argument(
"--master",
help="Apache Spark master IP address and port.",
default="spark://127.0.1.1:7077",
)
parser.add_argument(
"--img_model", help="Model for image processing.", default="ssd"
)
parser.add_argument(
"--filter_size",
nargs="+",
type=int,
help="Filter out images smaller than filter_size",
default=[640, 640],
)
parser.add_argument(
"--output_path", help="Path to image model output.", default="warc_res"
)
return parser.parse_args()


def zip_model_module(PYAUT_DIR):
zip = zipfile.ZipFile(os.path.join(PYAUT_DIR, "tf", "model.zip"), "w")
zip.write(os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"), os.path.join("model", "__init__.py"))
zip.write(os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"), os.path.join("model", "object_detection.py"))
zip.write(os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"), os.path.join("model", "preprocess.py"))
zip.write(
os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"),
os.path.join("model", "__init__.py"),
)
zip.write(
os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"),
os.path.join("model", "object_detection.py"),
)
zip.write(
os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"),
os.path.join("model", "preprocess.py"),
)


def read_conf(conf_path):
conf_dict = {}
with open(conf_path) as f:
for line in f:
conf = re.findall(r'\S+', line.strip())
conf = re.findall(r"\S+", line.strip())
conf_dict[conf[0]] = conf[1]
return conf_dict
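Since the commit moves the Spark config to an example file (see the rename noted below), here is a hedged sketch of the whitespace-separated format read_conf expects; the property names and values are placeholders, not the project's actual example file.

# Placeholder config contents; read_conf splits each line on whitespace and
# maps the first token to the second.
sample = """spark.executor.memory    8g
spark.driver.memory      4g
spark.executor.cores     4"""

conf_dict = {}
for line in sample.splitlines():
    tokens = line.split()
    if len(tokens) >= 2:
        conf_dict[tokens[0]] = tokens[1]

print(conf_dict)
# {'spark.executor.memory': '8g', 'spark.driver.memory': '4g', 'spark.executor.cores': '4'}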

File renamed without changes.
