Python formatting, and gitignore additions. #326

Merged
merged 4 commits into from Jul 18, 2019
5 changes: 5 additions & 0 deletions .gitignore
@@ -13,3 +13,8 @@ workbench.xmi
build
derby.log
metastore_db
__pycache__/
src/main/python/tf/model.zip
src/main/python/tf/util/spark.conf
src/main/python/tf/model/graph/
src/main/python/tf/model/category/
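
model.zip is produced locally by zip_model_module in util/init.py (further down in this PR), so ignoring it is safe. A hedged sanity check of the generated archive, assuming it has already been built at the repository root:

import zipfile

# zip_model_module writes exactly these three entries:
# model/__init__.py, model/object_detection.py, model/preprocess.py
with zipfile.ZipFile("src/main/python/tf/model.zip") as zf:
    print(zf.namelist())
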
3 changes: 1 addition & 2 deletions src/main/python/aut/__init__.py
@@ -1,5 +1,4 @@
from aut.common import WebArchive
from aut.udfs import extract_domain

__all__ = ['WebArchive', 'extract_domain']

__all__ = ["WebArchive", "extract_domain"]
2 changes: 1 addition & 1 deletion src/main/python/aut/common.py
@@ -1,5 +1,6 @@
from pyspark.sql import DataFrame


class WebArchive:
def __init__(self, sc, sqlContext, path):
self.sc = sc
@@ -12,4 +13,3 @@ def pages(self):

def links(self):
return DataFrame(self.loader.extractHyperlinks(self.path), self.sqlContext)

10 changes: 6 additions & 4 deletions src/main/python/aut/udfs.py
@@ -1,11 +1,13 @@
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType


def extract_domain_func(url):
url = url.replace('http://', '').replace('https://', '')
if '/' in url:
return url.split('/')[0].replace('www.', '')
url = url.replace("http://", "").replace("https://", "")
if "/" in url:
return url.split("/")[0].replace("www.", "")
else:
return url.replace('www.', '')
return url.replace("www.", "")


extract_domain = udf(extract_domain_func, StringType())
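
For reference, extract_domain_func strips the scheme, any path, and a leading "www.", so extract_domain_func("https://www.example.com/about") returns "example.com". A hedged usage sketch of the registered UDF together with WebArchive (paths and the fatjar name are placeholders; it assumes src/main/python is on PYTHONPATH and the aut fatjar is on the Spark classpath):

from pyspark import SparkConf, SparkContext, SQLContext
from pyspark.sql.functions import col

from aut import WebArchive, extract_domain

# Placeholder paths: point spark.jars at the compiled aut fatjar and
# the archive path at a directory of WARC files.
conf = SparkConf().set("spark.jars", "/path/to/aut-fatjar.jar")
sc = SparkContext(master="local[*]", conf=conf)
sql_context = SQLContext(sc)

arc = WebArchive(sc, sql_context, "/path/to/warcs")
arc.pages().select(extract_domain(col("url")).alias("domain")).show()
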
16 changes: 10 additions & 6 deletions src/main/python/tf/detect.py
@@ -1,13 +1,15 @@
import os
import sys
from util.init import *

from pyspark.sql import DataFrame

from model.object_detection import *
from util.init import *

PYAUT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PYAUT_DIR)

from aut.common import WebArchive
from pyspark.sql import DataFrame


if __name__ == "__main__":
# initialization
@@ -23,11 +25,13 @@
arc = WebArchive(sc, sql_context, args.web_archive)
df = DataFrame(arc.loader.extractImages(arc.path), sql_context)
filter_size = tuple(args.filter_size)
print("height >= %d and width >= %d"%filter_size)
preprocessed = df.filter("height >= %d and width >= %d"%filter_size)
print("height >= %d and width >= %d" % filter_size)
preprocessed = df.filter("height >= %d and width >= %d" % filter_size)

# detection
model_broadcast = detector.broadcast()
detect_udf = detector.get_detect_udf(model_broadcast)
res = preprocessed.select("url", detect_udf(col("bytes")).alias("prediction"), "bytes")
res = preprocessed.select(
"url", detect_udf(col("bytes")).alias("prediction"), "bytes"
)
res.write.json(args.output_path)
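
As a worked example of the filter expression built above: with the default --filter_size of [640, 640], %-formatting against the tuple yields the SQL string passed to DataFrame.filter.

filter_size = (640, 640)
expr = "height >= %d and width >= %d" % filter_size
assert expr == "height >= 640 and width >= 640"
# preprocessed = df.filter(expr) keeps only images at least 640x640
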
16 changes: 11 additions & 5 deletions src/main/python/tf/extract_images.py
@@ -1,13 +1,19 @@
import numpy as np
import argparse

import numpy as np

from model.object_detection import SSDExtractor


def get_args():
parser = argparse.ArgumentParser(description='Extracting images from model output.')
parser.add_argument('--res_dir', help='Path of result (model output) directory.')
parser.add_argument('--output_dir', help='Path of extracted image file output directory.')
parser.add_argument('--threshold', type=float, help='Threshold of detection confidence scores.')
parser = argparse.ArgumentParser(description="Extracting images from model output.")
parser.add_argument("--res_dir", help="Path of result (model output) directory.")
parser.add_argument(
"--output_dir", help="Path of extracted image file output directory."
)
parser.add_argument(
"--threshold", type=float, help="Threshold of detection confidence scores."
)
return parser.parse_args()


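
The __main__ body of this script is collapsed in the diff; a hypothetical usage sketch tying these arguments to SSDExtractor (an illustration based on the classes shown in this PR, not the actual collapsed code):

# Assumes the names defined in extract_images.py above.
args = get_args()
# Reads every "part-*" JSON file under --res_dir and saves images whose
# detections score at or above --threshold into per-class subdirectories
# of --output_dir.
extractor = SSDExtractor(res_dir=args.res_dir, output_dir=args.output_dir)
extractor.extract_and_save(class_ids="all", threshold=args.threshold)
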
50 changes: 28 additions & 22 deletions src/main/python/tf/model/object_detection.py
@@ -17,23 +17,21 @@ def __init__(self, res_dir, output_dir):
self.res_dir = res_dir
self.output_dir = output_dir


def _extract_and_save(self, rec, class_ids, threshold):
raise NotImplementedError("Please overwrite this method.")


def extract_and_save(self, class_ids, threshold):
if class_ids == "all":
class_ids = list(self.cate_dict.keys())

for idx in class_ids:
cls = self.cate_dict[idx]
check_dir(self.output_dir + "/%s/"%cls, create=True)
check_dir(self.output_dir + "/%s/" % cls, create=True)

for fname in os.listdir(self.res_dir):
if fname.startswith("part-"):
print("Extracting:", self.res_dir+"/"+fname)
with open(self.res_dir+"/"+fname) as f:
print("Extracting:", self.res_dir + "/" + fname)
with open(self.res_dir + "/" + fname) as f:
for line in f:
rec = json.loads(line)
self._extract_and_save(rec, class_ids, threshold)
@@ -43,47 +41,56 @@ class SSD:
def __init__(self, sc, sql_context, args):
self.sc = sc
self.sql_context = sql_context
self.category = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt"%PKG_DIR)
self.checkpoint = "%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb"%PKG_DIR
self.category = load_cate_dict_from_pbtxt(
"%s/category/mscoco_label_map.pbtxt" % PKG_DIR
)
self.checkpoint = (
"%s/graph/ssd_mobilenet_v1_fpn_640x640/frozen_inference_graph.pb" % PKG_DIR
)
self.args = args
with tf.io.gfile.GFile(self.checkpoint, 'rb') as f:
with tf.io.gfile.GFile(self.checkpoint, "rb") as f:
model_params = f.read()
self.model_params = model_params


def broadcast(self):
return self.sc.broadcast(self.model_params)


def get_detect_udf(self, model_broadcast):
def batch_proc(bytes_batch):
with tf.Graph().as_default() as g:
graph_def = tf.GraphDef()
graph_def.ParseFromString(model_broadcast.value)
tf.import_graph_def(graph_def, name='')
image_tensor = g.get_tensor_by_name('image_tensor:0')
detection_scores = g.get_tensor_by_name('detection_scores:0')
detection_classes = g.get_tensor_by_name('detection_classes:0')
tf.import_graph_def(graph_def, name="")
image_tensor = g.get_tensor_by_name("image_tensor:0")
detection_scores = g.get_tensor_by_name("detection_scores:0")
detection_classes = g.get_tensor_by_name("detection_classes:0")

with tf.Session().as_default() as sess:
result = []
image_size = (640, 640)
images = np.array([img2np(b, image_size) for b in bytes_batch])
res = sess.run([detection_scores, detection_classes], feed_dict={image_tensor: images})
res = sess.run(
[detection_scores, detection_classes],
feed_dict={image_tensor: images},
)
for i in range(res[0].shape[0]):
result.append([res[0][i], res[1][i]])
return pd.Series(result)
return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(batch_proc)

return pandas_udf(ArrayType(ArrayType(FloatType())), PandasUDFType.SCALAR)(
batch_proc
)


class SSDExtractor(ImageExtractor):
def __init__(self, res_dir, output_dir):
super().__init__(res_dir, output_dir)
self.cate_dict = load_cate_dict_from_pbtxt("%s/category/mscoco_label_map.pbtxt"%PKG_DIR)

self.cate_dict = load_cate_dict_from_pbtxt(
"%s/category/mscoco_label_map.pbtxt" % PKG_DIR
)

def _extract_and_save(self, rec, class_ids, threshold):
pred = rec['prediction']
pred = rec["prediction"]
scores = np.array(pred[0])
classes = np.array(pred[1])
valid_classes = np.unique(classes[scores >= threshold])
@@ -102,8 +109,7 @@ def _extract_and_save(self, rec, class_ids, threshold):
cls = self.cate_dict[cls_idx]
try:
img = str2img(rec["bytes"])
img.save(self.output_dir+ "/%s/"%cls + url_parse(rec["url"]))
img.save(self.output_dir + "/%s/" % cls + url_parse(rec["url"]))
except:
fname = self.output_dir+ "/%s/"%cls + url_parse(rec["url"])
fname = self.output_dir + "/%s/" % cls + url_parse(rec["url"])
print("Failing to save:", fname)

5 changes: 2 additions & 3 deletions src/main/python/tf/model/preprocess.py
@@ -7,7 +7,7 @@


def str2img(byte_str):
return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, 'utf-8'))))
return Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, "utf-8"))))


def img2np(byte_str, resize=None):
@@ -22,7 +22,7 @@ def img2np(byte_str, resize=None):
if len(img_shape) == 2:
img = np.stack([img, img, img], axis=-1)
elif img_shape[-1] >= 3:
img = img[:,:,:3]
img = img[:, :, :3]

return img

@@ -58,4 +58,3 @@ def load_cate_dict_from_pbtxt(path, key="id", value="display_name"):
cur_cate = re.findall(r'"(.*?)"', entry[1])[0]
cate_dict[cur_key] = cur_cate
return cate_dict
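
The "bytes" column handled above appears to hold base64-encoded image file data as a string. A hedged round-trip sketch of what str2img and img2np expect, under that assumption:

import base64
import io

from PIL import Image

# Encode a small RGB image the way the DataFrame's "bytes" column is assumed
# to be encoded: raw image file bytes, base64-encoded, stored as a str.
img = Image.new("RGB", (8, 8), color=(255, 0, 0))
buf = io.BytesIO()
img.save(buf, format="PNG")
byte_str = base64.b64encode(buf.getvalue()).decode("utf-8")

# str2img(byte_str) would reopen this as a PIL Image, and
# img2np(byte_str, resize=(640, 640)) would return a 3-channel array.
decoded = Image.open(io.BytesIO(base64.b64decode(bytes(byte_str, "utf-8"))))
assert decoded.size == (8, 8)
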

67 changes: 51 additions & 16 deletions src/main/python/tf/util/init.py
@@ -1,14 +1,15 @@
import argparse
import os
import re
import zipfile

from pyspark import SparkConf, SparkContext, SQLContext
import re
import os


def init_spark(master, aut_jar):
conf = SparkConf()
conf.set("spark.jars", aut_jar)
conf_path = os.path.dirname(os.path.abspath(__file__))+"/spark.conf"
conf_path = os.path.dirname(os.path.abspath(__file__)) + "/spark.conf"
conf_dict = read_conf(conf_path)
for item, value in conf_dict.items():
conf.set(item, value)
@@ -18,29 +19,63 @@ def init_spark(master, aut_jar):


def get_args():
parser = argparse.ArgumentParser(description='PySpark for Web Archive Image Retrieval.')
parser.add_argument('--web_archive', help='Path to warcs.', default='/tuna1/scratch/nruest/geocites/warcs')
parser.add_argument('--aut_jar', help='Path to compiled aut jar.', default='aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar')
parser.add_argument('--spark', help='Path to Apache Spark.', default='spark-2.3.2-bin-hadoop2.7/bin')
parser.add_argument('--master', help='Apache Spark master IP address and port.', default='spark://127.0.1.1:7077')
parser.add_argument('--img_model', help='Model for image processing.', default='ssd')
parser.add_argument('--filter_size', nargs='+', type=int, help='Filter out images smaller than filter_size', default=[640, 640])
parser.add_argument('--output_path', help='Path to image model output.', default='warc_res')
parser = argparse.ArgumentParser(
description="PySpark for Web Archive Image Retrieval."
)
parser.add_argument(
"--web_archive",
help="Path to warcs.",
default="/tuna1/scratch/nruest/geocites/warcs",
)
parser.add_argument(
"--aut_jar",
help="Path to compiled aut jar.",
default="aut/target/aut-0.17.1-SNAPSHOT-fatjar.jar",
)
parser.add_argument(
"--spark", help="Path to Apache Spark.", default="spark-2.3.2-bin-hadoop2.7/bin"
)
parser.add_argument(
"--master",
help="Apache Spark master IP address and port.",
default="spark://127.0.1.1:7077",
)
parser.add_argument(
"--img_model", help="Model for image processing.", default="ssd"
)
parser.add_argument(
"--filter_size",
nargs="+",
type=int,
help="Filter out images smaller than filter_size",
default=[640, 640],
)
parser.add_argument(
"--output_path", help="Path to image model output.", default="warc_res"
)
return parser.parse_args()


def zip_model_module(PYAUT_DIR):
zip = zipfile.ZipFile(os.path.join(PYAUT_DIR, "tf", "model.zip"), "w")
zip.write(os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"), os.path.join("model", "__init__.py"))
zip.write(os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"), os.path.join("model", "object_detection.py"))
zip.write(os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"), os.path.join("model", "preprocess.py"))
zip.write(
os.path.join(PYAUT_DIR, "tf", "model", "__init__.py"),
os.path.join("model", "__init__.py"),
)
zip.write(
os.path.join(PYAUT_DIR, "tf", "model", "object_detection.py"),
os.path.join("model", "object_detection.py"),
)
zip.write(
os.path.join(PYAUT_DIR, "tf", "model", "preprocess.py"),
os.path.join("model", "preprocess.py"),
)


def read_conf(conf_path):
conf_dict = {}
with open(conf_path) as f:
for line in f:
conf = re.findall(r'\S+', line.strip())
conf = re.findall(r"\S+", line.strip())
conf_dict[conf[0]] = conf[1]
return conf_dict
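
The spark.conf file read here is gitignored above, so its format is only implied by read_conf: one whitespace-separated key/value pair per line, first token mapped to second. A sketch with hypothetical contents (read_conf is the function defined in util/init.py above):

import tempfile

# Hypothetical spark.conf contents; the real file is not committed.
sample = "spark.executor.memory 16g\nspark.driver.maxResultSize 4g\n"
with tempfile.NamedTemporaryFile("w", suffix=".conf", delete=False) as f:
    f.write(sample)

print(read_conf(f.name))
# {'spark.executor.memory': '16g', 'spark.driver.maxResultSize': '4g'}
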