-
Notifications
You must be signed in to change notification settings - Fork 17
Classify flags into general, linear, nn categories #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
…me flags orders in main.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please help change settings in https://github.com/ntumlgroup/LibMultiLabel/blob/master/docs/conf.py#L48-L52 to prevent writing sg_execution_times.rst in Sphinx 5.
sphinx_gallery_conf = {
...
"write_computation_times": False,
...
}
- prevent writing sg_execution_times.rst in Sphinx 5
- optimize code in docs/cli/classifier.py
- reformat the above scripts
# Make the repository root importable when this script is run from docs/:
# this file lives two directories below the root.
current_dir = os.path.dirname(os.path.abspath(__file__))
lib_path = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
sys.path.insert(0, lib_path)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| current_dir = os.path.dirname(os.path.abspath(__file__)) | |
| lib_path = os.path.abspath(os.path.join(current_dir, "..", "..")) | |
| sys.path.insert(0, lib_path) | |
| sys.path.insert(0, os.getcwd()) |
def classify(raw_flags):
    """Group option flags into general/linear/nn categories based on which
    source files reference their config keys.

    Args:
        raw_flags: raw argparse flag definitions, as consumed by
            ``fetch_option_flags``.

    Returns:
        dict mapping each category name to a list of
        ``{"name", "description"}`` entries (with ``--`` escaped for reST),
        plus a ``"details"`` list recording every file/line where each key
        is used.
    """
    category_set = {"general": set(), "linear": set(), "nn": set()}
    flags = fetch_option_flags(raw_flags)
    allowed_keys = {flag["instruction"] for flag in flags}

    # Scan every source file for `config.<key>` usages of the allowed keys.
    # Fix: usage_map was a defaultdict(list) whose values were sets assigned
    # directly, so the default factory never fired — a plain dict is correct.
    usage_map = {}  # file path -> set of keys used in that file
    collected = {}  # key -> list of {"file", "lines"} usage records
    for file_path in fetch_all_files():
        detailed_results = find_config_usages_in_file(file_path, allowed_keys)
        if detailed_results:
            usage_map[file_path] = set(detailed_results.keys())
            for key, usage in detailed_results.items():
                collected.setdefault(key, []).append(usage)

    # A key inherits the category of every file it appears in; keys seen in
    # more than one category are consolidated into "general" afterwards.
    for path, keys in usage_map.items():
        category, _ = classify_file_category(path)
        category_set[category] |= keys
    category_set = move_duplicates_together(category_set, "general")

    # Attach the resolved category to each flag; keys never seen in any
    # scanned file default to "general".
    for flag in flags:
        for category, keys in category_set.items():
            if flag["instruction"] in keys:
                flag["category"] = category
        flag.setdefault("category", "general")

    result = {}
    for flag in flags:
        result.setdefault(flag["category"], []).append(
            {"name": flag["name"].replace("--", r"\-\-"), "description": flag["description"]}
        )

    # Flatten the usage records: the first row for a key carries its name,
    # additional rows for the same key leave the name column blank.
    result["details"] = []
    for key, usages in collected.items():
        result["details"].append(
            {"name": key, "file": usages[0]["file"], "location": ", ".join(usages[0]["lines"])}
        )
        for extra in usages[1:]:
            result["details"].append(
                {"name": "", "file": extra["file"], "location": ", ".join(extra["lines"])}
            )
    return result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about simplify the data structure (e.g., unused detailed line numbers) after the spec is decided?
| def classify(raw_flags): | |
| category_set = {"general": set(), "linear": set(), "nn": set()} | |
| flags = fetch_option_flags(raw_flags) | |
| allowed_keys = set(flag["instruction"] for flag in flags) | |
| file_set = fetch_all_files() | |
| usage_map = defaultdict(list) | |
| collected = {} | |
| for file_path in file_set: | |
| detailed_results = find_config_usages_in_file(file_path, allowed_keys) | |
| if detailed_results: | |
| usage_map[file_path] = set(detailed_results.keys()) | |
| for k, v in detailed_results.items(): | |
| if k not in collected: | |
| collected[k] = [] | |
| collected[k].append(v) | |
| for path, keys in usage_map.items(): | |
| category, path = classify_file_category(path) | |
| category_set[category] = category_set[category].union(keys) | |
| category_set = move_duplicates_together(category_set, "general") | |
| for flag in flags: | |
| for k, v in category_set.items(): | |
| for i in v: | |
| if flag["instruction"] == i: | |
| flag["category"] = k | |
| if "category" not in flag: | |
| flag["category"] = "general" | |
| result = {} | |
| for flag in flags: | |
| if flag["category"] not in result: | |
| result[flag["category"]] = [] | |
| result[flag["category"]].append( | |
| {"name": flag["name"].replace("--", r"\-\-"), "description": flag["description"]} | |
| ) | |
| result["details"] = [] | |
| for k, v in collected.items(): | |
| result["details"].append({"name": k, "file": v[0]["file"], "location": ", ".join(v[0]["lines"])}) | |
| if len(v) > 1: | |
| for i in v[1:]: | |
| result["details"].append({"name": "", "file": i["file"], "location": ", ".join(i["lines"])}) | |
| return result | |
| def classify(raw_flags): | |
| category_set = {"general": set(), "linear": set(), "nn": set()} | |
| flags = fetch_option_flags(raw_flags) | |
| allowed_keys = set(flag["instruction"] for flag in flags) | |
| file_set = fetch_all_files() | |
| for file_path in file_set: | |
| find_config_usages_in_file(file_path, allowed_keys, category_set) | |
| category_set = move_duplicates_together(category_set) | |
| result = defaultdict(list) | |
| for flag in raw_flags: | |
| for category, keys in category_set.items(): | |
| for key in keys: | |
| if key in flag["name"]: | |
| result[category].append(flag) | |
| return result |
def find_config_usages_in_file(file_path, allowed_keys):
    """Find which keys in *allowed_keys* are referenced as ``config.<key>``
    in the file at *file_path*.

    For ``main.py`` only the body of ``main()`` is scanned, and reported
    line numbers are offset so they still match the file on disk.

    Args:
        file_path: path of the file to scan.
        allowed_keys: set of config attribute names to look for.

    Returns:
        dict mapping each found key to ``{"file": <repo-relative path>,
        "lines": [<line numbers as strings>]}``.  Empty dict when the file
        cannot be read as UTF-8 text.
    """
    pattern = re.compile(r"\bconfig\.([a-zA-Z_][a-zA-Z0-9_]*)")
    detailed_results = {}
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            lines = f.readlines()
    except (IOError, UnicodeDecodeError):
        # Fix: return the same type as the success path ({} not []) —
        # callers iterate .keys()/.items() on the result.
        return {}

    _, path = classify_file_category(file_path)

    # Fix: main_start was unbound (NameError at the append below) when a
    # main.py had no "def main(" line; default to no offset.
    main_start = 0
    if file_path.endswith("main.py"):
        # Restrict the scan to the main() function body.
        for idx, line in enumerate(lines):
            if line.startswith("def main("):
                lines = lines[idx:]
                main_start = idx
                break
        # Cut at the first top-level statement after the def line.
        # Fix: enumerate from 1 — the old `enumerate(lines[1:])` started at
        # 0, so `lines[:i]` truncated one line too early (and emptied the
        # list when the line right after the def was top-level).
        for i, line in enumerate(lines[1:], start=1):
            if line and line[0] not in (" ", "\t") and line.strip() != "":
                lines = lines[:i]
                break

    for i, line in enumerate(lines, start=1):
        for key in pattern.findall(line):
            if key in allowed_keys:
                if key not in detailed_results:
                    detailed_results[key] = {"file": path, "lines": []}
                # main_start shifts numbers back to full-file coordinates;
                # it is 0 for every file other than a truncated main.py.
                detailed_results[key]["lines"].append(str(i + main_start))
    return detailed_results
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar in this function,
| def find_config_usages_in_file(file_path, allowed_keys): | |
| pattern = re.compile(r"\bconfig\.([a-zA-Z_][a-zA-Z0-9_]*)") | |
| detailed_results = {} | |
| try: | |
| with open(file_path, "r", encoding="utf-8") as f: | |
| lines = f.readlines() | |
| except (IOError, UnicodeDecodeError): | |
| return [] | |
| _, path = classify_file_category(file_path) | |
| if file_path.endswith("main.py"): | |
| for idx in range(len(lines)): | |
| if lines[idx].startswith("def main("): | |
| lines = lines[idx:] | |
| main_start = idx | |
| break | |
| for i, line in enumerate(lines[1:]): | |
| if line and line[0] not in (" ", "\t") and line.strip() != "": | |
| lines = lines[:i] | |
| break | |
| for i, line in enumerate(lines, start=1): | |
| matches = pattern.findall(line) | |
| for key in matches: | |
| if key in allowed_keys: | |
| if key not in detailed_results: | |
| detailed_results[key] = {"file": path, "lines": []} | |
| if file_path.endswith("main.py"): | |
| detailed_results[key]["lines"].append(str(i + main_start)) | |
| else: | |
| detailed_results[key]["lines"].append(str(i)) | |
| return detailed_results | |
| def find_config_usages_in_file(file_path, allowed_keys, category_set): | |
| pattern = re.compile(r"\bconfig\.([a-zA-Z_][a-zA-Z0-9_]*)") | |
| try: | |
| with open(file_path, "r", encoding="utf-8") as f: | |
| lines = f.readlines() | |
| except (IOError, UnicodeDecodeError): | |
| return [] | |
| # get start line in main.py | |
| if file_path.endswith("main.py"): | |
| for idx in range(len(lines)): | |
| if lines[idx].startswith("def main("): | |
| lines = lines[idx:] | |
| break | |
| all_str = " ".join(lines) | |
| matches = set(pattern.findall(all_str)) & allowed_keys | |
| category = classify_file_category(file_path)[0] | |
| for key in matches: | |
| category_set[category].add(key) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TBD: try-catch here can be removed, as we want to see error instantly when someone change the files (not like production settings)
def classify_file_category(path):
    """Map a file path to its flag category.

    Args:
        path: path of a source file under the repository root
            (``lib_path``, a module-level global).

    Returns:
        ``(category, relative_path)`` where category is ``"linear"``,
        ``"nn"`` or ``"general"`` and relative_path is the POSIX-style
        path relative to the repository root.
    """
    rel = Path(path).relative_to(lib_path)
    rel_posix = rel.as_posix()

    # Classify by the path below the top-level package directory when
    # there is one; otherwise use the relative path itself.
    parts = rel.parts
    inner = Path(*parts[1:]).as_posix() if len(parts) > 1 else rel_posix

    if inner.startswith("linear"):
        return "linear", rel_posix
    if inner.startswith(("torch", "nn")):
        return "nn", rel_posix
    return "general", rel_posix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After simplifying find_config_usage_in_file, we no longer need return_path here.
Let's discuss how it can be simplified further.
def move_duplicates_together(data, keep):
    """Consolidate items shared between categories into the *keep* category.

    Any item present in two or more of the sets in *data* is added to
    ``data[keep]`` and removed from every other set, so the returned sets
    are pairwise disjoint.

    Args:
        data: dict mapping category name -> set of items; mutated in place.
        keep: the category that absorbs all shared items.

    Returns:
        The same dict, after consolidation.
    """
    categories = list(data)

    # Collect every item that occurs in more than one category.
    shared = set()
    for idx, first in enumerate(categories):
        for second in categories[idx + 1:]:
            shared.update(data[first] & data[second])

    # Shared items live only in the `keep` category from now on.
    data[keep].update(shared)
    for category in categories:
        if category != keep:
            data[category].difference_update(shared)

    return data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TBD: readability
What does this PR do?
(Some descriptions here...)
Test CLI & API (bash tests/autotest.sh): test the APIs used by main.py.
Check API Document
If any new APIs are added, please check if the description of the APIs is added to API document.
Test quickstart & API (bash tests/docs/test_changed_document.sh): if any APIs used in quickstarts or tutorials are modified, please run this test to check that the current examples still run correctly after the modified APIs are released.