SIGMOD benchmark configurations (#581)
Commit ee1904d, 1 parent fe37530
Showing 17 changed files with 2,143 additions and 2 deletions.
@@ -0,0 +1,78 @@
import json

import matplotlib.pyplot as plt
import numpy as np

# Load the statistics from the JSON file
with open("dataset_stats.json", "r") as f:
    stats = json.load(f)

# Sort classes by sample count and get the top 50
class_counts = stats["train"]["per_class"]

print("total classes: " + str(len(class_counts)))

sorted_classes = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)

print("total sorted classes: " + str(len(sorted_classes)))

top_classes = sorted_classes[:50]
top_class_names, top_class_samples = zip(*top_classes)
num_top_samples = sum(count for _, count in top_classes)

print(f"total top classes: {len(top_classes)} num_samples in there: {num_top_samples}")

# Sum the samples of all other classes
other_samples = sum(count for _, count in sorted_classes[50:])
avg_samples_class = other_samples / len(sorted_classes[50:])
print("average samples / class in other: " + str(avg_samples_class))
print("remaining classes: " + str(len(sorted_classes[50:])))

# Print the top classes and their sample counts for verification
print("Top classes and their sample counts:")
for class_name, sample_count in top_classes:
    print(f"{class_name}: {sample_count}")
print(f"Other: {other_samples}")

# Plot bar chart of samples per class for the top 50 classes and "Other"
plt.figure(figsize=(14, 6))
plt.bar(top_class_names + ("Other",), top_class_samples + (other_samples,), color="skyblue")
plt.title("Number of Samples per Top 50 Classes and Other")
plt.xlabel("Class")
plt.ylabel("Number of Samples")
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

# Prepare data for the stacked bar chart
years = sorted(stats["train"]["per_year_and_class"].keys())
top_classes_set = set(top_class_names)

# Initialize the data structure for the stacked bar chart
stacked_data = {year: {class_name: 0 for class_name in top_class_names + ("Other",)} for year in years}

# Populate the data structure with actual counts
for year, classes_in_year in stats["train"]["per_year_and_class"].items():
    for class_name, count in classes_in_year.items():
        if class_name in top_classes_set:
            stacked_data[year][class_name] += count
        else:
            stacked_data[year]["Other"] += count

# Plot stacked bar chart of samples per class within each year for the top 50 classes and "Other"
plt.figure(figsize=(14, 6))
bottom = np.zeros(len(years))

for class_name in top_class_names + ("Other",):
    samples_per_year = [stacked_data[year][class_name] for year in years]
    plt.bar(years, samples_per_year, bottom=bottom, label=class_name)
    bottom = np.add(bottom, samples_per_year)

plt.title("Number of Samples per Top 50 Classes and Other within Each Year")
plt.xlabel("Year")
plt.ylabel("Number of Samples")
plt.legend(title="Classes", bbox_to_anchor=(1.05, 1), loc="upper left")
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()
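The script above assumes a dataset_stats.json produced elsewhere in the benchmark tooling; that file is not part of this diff. A minimal sketch that fabricates one with the expected shape, purely to exercise the plots end to end (all class names, years, and counts below are invented, and at least 51 classes are needed so the "Other" bucket is non-empty):

import json
import random

random.seed(0)

# Hypothetical dataset_stats.json covering exactly the keys the script reads
# ("train" -> "per_class" and "per_year_and_class"); 60 synthetic classes so
# that both the top-50 slice and the "Other" bucket are populated.
classes = [f"class_{i}" for i in range(60)]
years = [str(y) for y in range(2015, 2021)]
per_year_and_class = {year: {c: random.randint(1, 100) for c in classes} for year in years}
per_class = {c: sum(per_year_and_class[y][c] for y in years) for c in classes}

with open("dataset_stats.json", "w") as f:
    json.dump({"train": {"per_class": per_class, "per_year_and_class": per_year_and_class}}, f)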
@@ -0,0 +1,50 @@
import json

import matplotlib.pyplot as plt

# Load the statistics from the JSON file
with open("hierarchy_stats.json", "r") as f:
    stats = json.load(f)

split = "train"

# Get the list of all classes
all_classes = stats[split]["per_class"].keys()
num_classes = len(set(all_classes))

print(f"there are {num_classes} classes")

# Get the years available in the dataset
years = sorted(stats[split]["per_year_and_class"].keys())

num_years = len(years)

# results[i] accumulates, over all classes, the share of a class's samples that
# falls i - (num_years // 2) years away from that class's peak year
results = [0 for _ in range(num_years + 1)]

for class_name in [str(i) for i in range(num_classes)]:
    samples_per_year = [stats[split]["per_year_and_class"].get(year, {}).get(class_name, 0) for year in years]

    total_samples = sum(samples_per_year)
    if total_samples == 0:
        # Skip classes without samples in this split to avoid division by zero
        continue

    # Find the year in which this class has the most samples
    max_samples = -1
    max_year_idx = -1
    for year_idx, samples in enumerate(samples_per_year):
        if samples > max_samples:
            max_year_idx = year_idx
            max_samples = samples

    # Accumulate each year's sample share, re-centered on the peak year
    for i in range(len(results)):
        base_idx = max_year_idx + i - (num_years // 2)
        if base_idx < 0 or base_idx >= len(samples_per_year):
            continue
        results[i] += samples_per_year[base_idx] / total_samples

print(results)
plt.figure(figsize=(10, 5))
plt.bar(range(-num_years // 2, (num_years // 2) + 1, 1), results, color="cornflowerblue")
plt.title("Distribution of Sample Shares Around Each Class's Peak Year")
plt.xlabel("Offset from Peak Year")
plt.ylabel("Accumulated Sample Share")
plt.xticks(range(-num_years // 2, (num_years // 2) + 1, 1), rotation=45)
plt.tight_layout()
plt.show()
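This script expects a hierarchy_stats.json with integer-like class names ("0", "1", ...); that file is also not part of this diff. A small sketch that writes a synthetic one so the plot can be smoke-tested (the classes, years, and counts are invented; every class peaks in 2020, so the offset-0 bar should dominate):

import json

# Hypothetical hierarchy_stats.json with three classes over four years;
# per_class totals match the per-year counts, and all classes peak in 2020.
synthetic = {
    "train": {
        "per_class": {"0": 70, "1": 70, "2": 70},
        "per_year_and_class": {
            "2018": {"0": 5, "1": 10, "2": 5},
            "2019": {"0": 10, "1": 20, "2": 5},
            "2020": {"0": 45, "1": 30, "2": 55},
            "2021": {"0": 10, "1": 10, "2": 5},
        },
    }
}

with open("hierarchy_stats.json", "w") as f:
    json.dump(synthetic, f)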
Empty file.
Empty file.
@@ -0,0 +1,171 @@
from __future__ import annotations

from modyn.config import (
    CheckpointingConfig,
    LrSchedulerConfig,
    OptimizationCriterion,
    OptimizerConfig,
    OptimizerParamGroup,
)
from modyn.config.schema.pipeline import (
    AccuracyMetricConfig,
    DataConfig,
    EvalDataConfig,
    EvaluationConfig,
    F1ScoreMetricConfig,
    FullModelStrategy,
    ModelConfig,
    ModynPipelineConfig,
    Pipeline,
    PipelineModelStorageConfig,
    SelectionStrategy,
    TimeTriggerConfig,
    TrainingConfig,
)
from modyn.config.schema.pipeline.evaluation.handler import EvalHandlerConfig
from modyn.config.schema.pipeline.evaluation.strategy.slicing import SlicingEvalStrategyConfig


def gen_arxiv_training_conf(
    optimizer: str, lr: float, gpu_device: str, lr_scheduler: LrSchedulerConfig | None, num_epochs: int, seed: int
):
    if optimizer == "SGD":
        opti_conf = OptimizerConfig(
            name="default",
            algorithm="SGD",
            source="PyTorch",
            param_groups=[
                OptimizerParamGroup(module="model", config={"lr": lr, "momentum": 0.9, "weight_decay": 0.01})
            ],
        )
    elif optimizer == "AdamW":
        opti_conf = OptimizerConfig(
            name="default",
            algorithm="AdamW",
            source="PyTorch",
            param_groups=[OptimizerParamGroup(module="model", config={"lr": lr, "weight_decay": 0.01})],
        )
    else:
        raise ValueError(optimizer)

    return TrainingConfig(
        gpus=1,
        device=gpu_device,
        dataloader_workers=1,
        use_previous_model=True,
        initial_model="random",
        batch_size=128,
        optimizers=[opti_conf],
        optimization_criterion=OptimizationCriterion(name="CrossEntropyLoss"),
        checkpointing=CheckpointingConfig(activated=False),
        lr_scheduler=lr_scheduler,
        epochs_per_trigger=num_epochs,
        shuffle=True,
        amp=False,
        seed=seed,
    )


def gen_arxiv_config(
    config_id: str,
    num_epochs: int,
    gpu_device: str,
    selection_strategy: SelectionStrategy,
    lr_scheduler: LrSchedulerConfig | None,
    model: str,
    dataset: str,
    num_classes: int,
    seed: int,
    optimizer: str,
    lr: float,
) -> ModynPipelineConfig:
    del model  # ignored for now
    del dataset
    del num_classes

    return ModynPipelineConfig(
        pipeline=Pipeline(name=f"arxiv_{config_id}", description="Arxiv data selection config", version="0.0.1"),
        model=ModelConfig(id="ArticleNet", config={"num_classes": 172}),
        model_storage=PipelineModelStorageConfig(full_model_strategy=FullModelStrategy(name="PyTorchFullModel")),
        training=gen_arxiv_training_conf(optimizer, lr, gpu_device, lr_scheduler, num_epochs, seed),
        selection_strategy=selection_strategy,
        data=DataConfig(
            dataset_id="arxiv",
            transformations=[],
            bytes_parser_function=(
                "def bytes_parser_function(data: memoryview) -> str:\n"
                "    return str(data, 'utf8')"
            ),
            tokenizer="DistilBertTokenizerTransform",
        ),
        trigger=TimeTriggerConfig(every="1d", start_timestamp=0),
        evaluation=EvaluationConfig(
            handlers=[
                EvalHandlerConfig(
                    name="exactmatrix",
                    execution_time="after_pipeline",
                    models="matrix",
                    datasets=["arxiv-test"],
                    strategy=SlicingEvalStrategyConfig(eval_every="1d", eval_start_from=0, eval_end_at=1400000),
                )
            ],
            after_pipeline_evaluation_workers=2,
            after_training_evaluation_workers=2,
            device=gpu_device,
            result_writers=["json"],
            datasets=[
                EvalDataConfig(
                    dataset_id=dataset,
                    bytes_parser_function=(
                        "def bytes_parser_function(data: memoryview) -> str:\n"
                        "    return str(data, 'utf8')"
                    ),
                    tokenizer="DistilBertTokenizerTransform",
                    batch_size=256,
                    dataloader_workers=1,
                    metrics=[
                        AccuracyMetricConfig(
                            evaluation_transformer_function=(
                                "import torch\n"
                                "def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:\n"
                                "    return torch.argmax(model_output, dim=-1)"
                            ),
                            topn=1,
                        ),
                        AccuracyMetricConfig(evaluation_transformer_function="", topn=2),
                        AccuracyMetricConfig(evaluation_transformer_function="", topn=5),
                        F1ScoreMetricConfig(
                            evaluation_transformer_function=(
                                "import torch\n"
                                "def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:\n"
                                "    return torch.argmax(model_output, dim=-1)"
                            ),
                            num_classes=172,
                            average="weighted",
                        ),
                        F1ScoreMetricConfig(
                            evaluation_transformer_function=(
                                "import torch\n"
                                "def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:\n"
                                "    return torch.argmax(model_output, dim=-1)"
                            ),
                            num_classes=172,
                            average="macro",
                        ),
                        F1ScoreMetricConfig(
                            evaluation_transformer_function=(
                                "import torch\n"
                                "def evaluation_transformer_function(model_output: torch.Tensor) -> torch.Tensor:\n"
                                "    return torch.argmax(model_output, dim=-1)"
                            ),
                            num_classes=172,
                            average="micro",
                        ),
                    ],
                )
                for dataset in ["arxiv", "arxiv-test"]
            ],
        ),
    )
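To illustrate how these helpers compose, a minimal sketch of building one arXiv pipeline config. The selection strategy placeholder and all argument values below are invented for illustration and are not taken from this commit:

# `my_selection_strategy` stands in for whichever SelectionStrategy instance a
# benchmark run would use; its construction is outside the scope of this file.
my_selection_strategy: SelectionStrategy = ...  # hypothetical placeholder

pipeline_config = gen_arxiv_config(
    config_id="adamw_lr5e-5",  # hypothetical run label
    num_epochs=1,
    gpu_device="cuda:0",
    selection_strategy=my_selection_strategy,
    lr_scheduler=None,  # no LR scheduling in this sketch
    model="ArticleNet",  # currently ignored by gen_arxiv_config
    dataset="arxiv",  # currently ignored
    num_classes=172,  # currently ignored
    seed=42,
    optimizer="AdamW",
    lr=5e-5,
)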