Skip to content
17 changes: 17 additions & 0 deletions script/get-dataset-mixtral/_cm.json
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@
"CM_DOWNLOAD_CHECKSUM": "78823c13e0e73e518872105c4b09628b"
},
"group": "download-source"
},
"generate-test-data.#":{
"base": [
"mlcommons-storage"
],
"env":{
"CM_DATASET_MIXTRAL_TEST_DATA_SIZE": "#",
"CM_DATASET_MIXTRAL_GENERATE_TEST_DATA": "yes"
},
"deps": [
{
"tags": "get,generic-python-lib,_package.pandas"
},
{
"tags": "get,python3"
}
]
}
}
}
6 changes: 6 additions & 0 deletions script/get-dataset-mixtral/customize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ def preprocess(i):

env = i['env']

if env.get('CM_DATASET_MIXTRAL_GENERATE_TEST_DATA', '') == "yes":
env['CM_DATASET_MIXTRAL_TEST_DATA_GENERATED_PATH'] = os.path.join(os.getcwd(), "mixtral-test-dataset.pkl")

return {'return':0}


Expand All @@ -15,4 +18,7 @@ def postprocess(i):

env['CM_DATASET_MIXTRAL_PREPROCESSED_PATH'] = env['CM_DATASET_PREPROCESSED_PATH']

if env.get('CM_DATASET_MIXTRAL_GENERATE_TEST_DATA', '') == "yes":
env['CM_DATASET_MIXTRAL_PREPROCESSED_PATH'] = env['CM_DATASET_MIXTRAL_TEST_DATA_GENERATED_PATH']

return {'return':0}
40 changes: 40 additions & 0 deletions script/get-dataset-mixtral/generate-test-dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import pandas as pd
import argparse
import os

def main():
# Set up argument parser
parser = argparse.ArgumentParser(description="Sample test dataset from the original dataset.")
parser.add_argument('--dataset-path', required=True, help="Path to the input dataset (pickle file).")
parser.add_argument('--output-path', default=os.path.join(os.getcwd(),"mixtral-test-dataset.pkl"), help="Path to save the output dataset (pickle file).")
parser.add_argument('--samples', default=2, help="Number of entries to be extracted from each group.")

args = parser.parse_args()
dataset_path = args.dataset_path
output_path = args.output_path
no_of_samples = int(args.samples)

try:
# Load the dataset from the specified pickle file
print(f"Loading dataset from {dataset_path}...")
df = pd.read_pickle(dataset_path)

# Check if 'group' column exists
if 'dataset' not in df.columns:
raise ValueError("The input dataset must contain a 'dataset' column to identify data set groups.")

# Sample 2 entries from each group
print(f"Sampling {no_of_samples} entries from each group...")
sampled_df = df.groupby('dataset').apply(lambda x: x.sample(n=no_of_samples)).reset_index(drop=True)

# Save the sampled dataset to the specified output path
print(f"Saving the sampled dataset to {output_path}...")
sampled_df.to_pickle(output_path)

print("Dataset processing and saving completed successfully!")
except Exception as e:
print(f"Error: {e}")
exit(1)

if __name__ == '__main__':
main()
5 changes: 5 additions & 0 deletions script/get-dataset-mixtral/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/bin/bash

if [[ "$CM_DATASET_MIXTRAL_GENERATE_TEST_DATA" == "yes" ]]; then
${CM_PYTHON_BIN_WITH_PATH} ${CM_TMP_CURRENT_SCRIPT_PATH}/generate-test-dataset.py --dataset-path ${CM_DATASET_PREPROCESSED_PATH} --output-path ${CM_DATASET_MIXTRAL_TEST_DATA_GENERATED_PATH} --samples ${CM_DATASET_MIXTRAL_TEST_DATA_SIZE}
fi