Commit 9a80de0

Migrate to DASWOW-based cell classifier
1 parent 5d1b2c6 commit 9a80de0

13 files changed (+932, −234 lines)

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -387,3 +387,4 @@ callsites-jupyternb-real-world-benchmark/notebooks/*.py
 evaluation/extended_dataset/DASWOW_dataset_fixed.csv
 .~lock*
 scripts/results
+daswow/models/

daswow/CellFeatures.py

Lines changed: 541 additions & 0 deletions
Large diffs are not rendered by default.
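Since the 541-line CellFeatures.py is not rendered, the sketch below is a hypothetical reconstruction of the interface it must expose, inferred solely from how daswow/daswow_model.py (further down in this diff) uses it; the column names come from the inference code, everything else is an assumption.

# Hypothetical interface sketch, inferred from usage in daswow_model.py;
# the actual 541-line implementation is not shown in this diff.
import pandas as pd


class CellFeatures:
    def get_cell_features_nb(self, nb_path):
        """Parse a notebook and return a pandas DataFrame with one row per
        cell, carrying at least the columns the inference code relies on:
        'text', 'linesofcomment', 'linesofcode', 'variable_count',
        'function_count'.
        """
        raise NotImplementedError  # real extraction logic lives in the unrendered file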

daswow/__init__.py

Whitespace-only changes.

daswow/daswow_model.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
# %%
import os
from string import punctuation

import joblib
import numpy as np
from nltk.corpus import stopwords

from daswow.CellFeatures import CellFeatures
from daswow.model_download import download_models_from_github_release

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
MODELS_PATH = os.path.join(SCRIPT_DIR, "models")


class Preprocessing:
    """Turns the raw per-cell feature dataframe into cleaned text."""

    def __init__(self, df):
        self.df = df
        self.features = ["text"]
        self.stopWords = set(stopwords.words("english"))

    def remove_stopwords(self, words):
        return [w for w in words if w not in self.stopWords]

    def set_column(self, col, newcol):
        self.df[newcol] = self.df[col].apply(self.combine_lists_to_text)
        return self.df

    def custom_text_preprocessing(self, s):
        # Keep '.', '#' and '_' (meaningful in code); replace all other
        # punctuation with spaces, then drop pure numbers, stopwords,
        # and single-character tokens.
        favourite_punc = [".", "#", "_"]
        if s:
            for char in punctuation:
                if char not in favourite_punc:
                    s = s.replace(char, " ")
            s = " ".join(
                [
                    "" if word.replace(".", "").isdigit() else word
                    for word in s.split(" ")
                ]
            )
            s = " ".join(self.remove_stopwords(s.lower().split(" ")))
            s = " ".join([word.strip() for word in s.split(" ") if len(word) > 1])
        return s

    def combine_lists_to_text(self, obj):
        # Flatten a (possibly nested) list of tokens into one string.
        text = ""
        if obj:
            try:
                if isinstance(obj, list):
                    for element in obj:
                        if isinstance(element, list):
                            for e in element:
                                text = text + " " + str(e)
                        else:
                            text = text + " " + str(element)
                elif isinstance(obj, str):
                    text = text + " " + obj
            except Exception:
                print("expecting string or list, found %s" % type(obj))

        return text.strip().lower()

    def set_lexical(self, features):
        # Collect the configured lexical feature columns into 'new_text'.
        new_text = []
        for _idx, row in self.df.iterrows():
            tokens = []
            for each in features:
                if isinstance(row[each], list):
                    tokens = tokens + row[each]
                else:
                    tokens = tokens + [row[each]]
            new_text.append(tokens)
        self.df["new_text"] = new_text
        return self.df

    def process(self):
        self.df = self.set_lexical(self.features)
        self.df["new_text"] = self.df["text"].apply(self.combine_lists_to_text)
        self.df["new_text"] = self.df["new_text"].apply(self.custom_text_preprocessing)
        return self.df


class DASWOWInference:
    """Classifies notebook cells into DASWOW data-science phases."""

    def __init__(self, nb_path, models_path=MODELS_PATH):
        cf = CellFeatures()
        self.df = cf.get_cell_features_nb(nb_path)

        download_models_from_github_release()

        self.preprocesser = Preprocessing(self.df)
        self.model = joblib.load(f"{models_path}/rf_code_scaled.pkl")
        self.tfidf = joblib.load(f"{models_path}/tfidf_vectorizer.pkl")
        self.selector = joblib.load(f"{models_path}/selector.pkl")
        self.ss = joblib.load(f"{models_path}/scaler.pkl")
        self.stopWords = set(stopwords.words("english"))
        self.stat_features = [
            "linesofcomment",
            "linesofcode",
            "variable_count",
            "function_count",
        ]
        self.labels = [
            "helper_functions",
            "load_data",
            "data_preprocessing",
            "data_exploration",
            "modelling",
            "evaluation",
            "prediction",
            "result_visualization",
            "save_results",
            "comment_only",
        ]

    def remove_stopwords(self, words):
        return [w for w in words if w not in self.stopWords]

    def preprocess(self):
        self.df = self.preprocesser.process()
        return True

    def vectorize(self):
        return self.tfidf.transform(self.df["new_text"])

    def select_features(self, text):
        return self.selector.transform(text)

    def set_statistical_features(self, text):
        # Append the per-cell statistical features to the TF-IDF matrix.
        X_copy = text.toarray()
        for each in self.stat_features:
            X_copy = np.c_[X_copy, self.df[each].values]
        return X_copy

    def scale_features(self, text):
        return self.ss.transform(text)

    def predict(self):
        self.preprocess()
        cells_features = self.vectorize()
        cells_features = self.select_features(cells_features)
        cells_features = self.set_statistical_features(cells_features)
        cells_features = self.scale_features(cells_features)
        prediction = self.model.predict(cells_features)
        # Convert the multi-label binary matrix into label names per cell.
        prediction = [
            [self.labels[i] for i, p in enumerate(pred) if p == 1]
            for pred in prediction
        ]
        return prediction


if __name__ == "__main__":
    nb_path = "/mnt/Projects/PhD/Research/Student-Thesis/7_Akshita/daswow-data-science-code-analysis/.scrapy/user_study_notebooks/user_study_notebooks/cyclegan-with-data-augmentation.ipynb"
    infer = DASWOWInference(nb_path=nb_path)
    infer.predict()
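
For orientation, a minimal usage sketch of the new class, assuming the package is importable as daswow and that NLTK's English stopword list has been fetched once; the notebook path is a placeholder.

import nltk

nltk.download("stopwords")  # one-time setup; Preprocessing needs the English stopword list

from daswow.daswow_model import DASWOWInference

# 'my_notebook.ipynb' is a placeholder path.
infer = DASWOWInference(nb_path="my_notebook.ipynb")
for cell_idx, labels in enumerate(infer.predict()):
    print(cell_idx, labels)  # each cell may carry several labels, e.g. ['load_data']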

daswow/model_download.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import os

import requests

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))


def download_models_from_github_release(
    repo_owner="secure-software-engineering",
    repo_name="HeaderGen",
    release_tag="models",
    asset_names=("rf_code_scaled.pkl", "scaler.pkl", "selector.pkl", "tfidf_vectorizer.pkl"),
    download_path=f"{SCRIPT_DIR}/models/",
):
    """Downloads specific files from a GitHub release.

    Args:
        repo_owner (str): The owner of the GitHub repository.
        repo_name (str): The name of the GitHub repository.
        release_tag (str): The tag name of the release (e.g., 'v1.0.0').
        asset_names (list): The names of the asset files to download.
        download_path (str): The local path where the files should be saved.
    """
    # Create the download path if it does not exist yet.
    if not os.path.exists(download_path):
        os.makedirs(download_path)

    # Skip assets that are already present locally; build a new list rather
    # than removing items from the list being iterated over.
    asset_names = [
        name
        for name in asset_names
        if not os.path.exists(os.path.join(download_path, name))
    ]
    if not asset_names:
        return

    # API endpoint to get release info.
    url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/tags/{release_tag}"

    response = requests.get(url)
    response.raise_for_status()  # Raise an exception for bad response status codes.

    release_data = response.json()

    for asset_name in asset_names:
        # Find the download URL of the asset.
        asset_url = None
        for asset in release_data["assets"]:
            if asset["name"] == asset_name:
                asset_url = asset["browser_download_url"]
                break

        if not asset_url:
            raise ValueError(f"Asset '{asset_name}' not found in the release.")

        # Download the file in chunks.
        response = requests.get(asset_url, stream=True)
        response.raise_for_status()

        file_path = os.path.join(download_path, asset_name)
        with open(file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)

        print(f"File downloaded to: {file_path}")

framework_models/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -90,6 +90,19 @@
     ],
 }
 
+DASWOW_PHASES = {
+    "helper_functions": "Helper Functions",
+    "load_data": "Load Data",
+    "data_preprocessing": "Data Preprocessing",
+    "data_exploration": "Data Exploration",
+    "modelling": "Modelling",
+    "evaluation": "Evaluation",
+    "prediction": "Prediction",
+    "result_visualization": "Result Visualization",
+    "save_results": "Save Results",
+    "comment_only": "Comment Only",
+}
+
 
 def get_high_level_phase(phase):
     for _k, _v in PHASE_GROUPS.items():
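
The new DASWOW_PHASES mapping converts the classifier's snake_case labels into display names; a short sketch of the intended lookup:

from framework_models import DASWOW_PHASES

# Map raw classifier output to human-readable phase names.
predicted = ["load_data", "result_visualization"]
print([DASWOW_PHASES[label] for label in predicted])
# ['Load Data', 'Result Visualization']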

headergen/cli.py

Lines changed: 1 addition & 1 deletion
@@ -114,4 +114,4 @@ def types(input, output, json_output):
 @cli.command()
 def server():
     """Start Server"""
-    uvicorn.run(app, host="0.0.0.0", port=54068)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
