Skip to content

Commit ceaca3a

Browse files
committed
rearrange code
1 parent c9b28d4 commit ceaca3a

File tree

2 files changed

+42
-36
lines changed

2 files changed

+42
-36
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,4 @@ python run_pipeline.py --mode load --save_load_path './cc_100k' --build_hf_ds
5555
# inference mode on new texts from an input dataset
5656
python run_pipeline.py --mode infer --save_load_path './cc_100k' --n_samples <NB_INFERENCE_SAMPLES> --input_dataset <HF_DATA_FOR_INFERENCE>
5757
```
58-
The `build_hf_ds` flag builds and pushes HF datasets for the files and clusters that can be directky used in the FW visualization space (we push the clusters dataset to the hub by default).
58+
The `build_hf_ds` flag builds and pushes HF datasets, for the files and clusters, that can be directly used in the FW visualization space. In `infer` mode, we push the clusters dataset by default.

run_pipeline.py

Lines changed: 41 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ def get_args():
1212
parser = argparse.ArgumentParser()
1313
parser.add_argument("--n_samples", type=int, default=100_000)
1414
parser.add_argument("--device", type=str, default="cuda")
15-
parser.add_argument(
16-
"--save_load_path", type=str, default="./fw_afaik_topics_100k_guard_s"
17-
)
15+
parser.add_argument("--save_load_path", type=str, default="./cc_100k")
1816
parser.add_argument(
1917
"--input_dataset", type=str, default="HuggingFaceFW/FW-12-12-2023-CC-2023-06"
2018
)
@@ -26,7 +24,9 @@ def get_args():
2624
help="Run the pipeline from scratch/load existing model to build hf datasets or to infer on new texts",
2725
)
2826
parser.add_argument(
29-
"--inference_repo_name", type=str, default="infer_fw_on_ultrachat",
27+
"--inference_repo_name",
28+
type=str,
29+
default="infer_fw_on_ultrachat",
3030
help="HF repo name for the clusters dataset in inference mode",
3131
)
3232
parser.add_argument(
@@ -47,7 +47,7 @@ def build_hf_data_clusters(cc, texts=None, labels=None):
4747
texts: list of texts used for inference mode.
4848
labels: list of cluster labels corresponding to the texts for inference mode.
4949
50-
If `texts` and `labels` are not provided, the function will use the data available in `cc`
50+
If `texts` and `labels` are not provided, the function will use the data available in `cc`
5151
to construct the dataset. Otherwise it will run in inference mode on texts.
5252
"""
5353
cluster_data = []
@@ -96,6 +96,19 @@ def build_hf_data_files(cc):
9696
return ds
9797

9898

99+
def build_and_push(cc, args):
100+
"""Build HF files & clusters datasets and push them to the hub"""
101+
print("Building HF datasets...")
102+
ds = build_hf_data_clusters(cc)
103+
data_clusters = build_hf_data_files(cc)
104+
print(f"Files dataset {ds}\nClusters dataset {data_clusters}")
105+
106+
repo_name = args.save_load_path.split("/")[-1]
107+
print(f"Pushing to the hub at {repo_name}...")
108+
ds.push_to_hub(f"{args.username}/{repo_name}", private=True)
109+
data_clusters.push_to_hub(f"{args.username}/{repo_name}_clusters", private=True)
110+
111+
99112
def main():
100113
args = get_args()
101114
cc = ClusterClassifier(embed_device=args.device)
@@ -112,41 +125,34 @@ def main():
112125
cc.save(args.save_load_path)
113126
print(f"Saved clusters in {args.save_load_path}.")
114127

128+
if args.build_hf_ds:
129+
build_and_push(cc, args)
130+
115131
elif args.mode == "infer":
116-
cc.load(args.save_load_path)
117-
# run inference mode on texts using an existing pipeline
118-
cc.load(args.save_load_path)
119-
print(
120-
f"Running inference on {args.n_samples} samples of {args.input_dataset} using clusters in {args.save_load_path}"
121-
)
122-
texts = load_dataset(args.input_dataset, split="train", token=True).select(
123-
range(args.n_samples)
124-
)[args.input_content]
125-
cluster_labels, _ = cc.infer(texts, top_k=1)
126-
ds = build_hf_data_clusters(cc, texts, cluster_labels)
127-
print("Pushing to hub...")
128-
ds.push_to_hub(f"{args.username}/{args.inference_repo_name}", private=True)
132+
# Run inference mode on texts using an existing pipeline
133+
cc.load(args.save_load_path)
134+
print(
135+
f"Running inference on {args.n_samples} samples of {args.input_dataset} using clusters in {args.save_load_path}."
136+
)
137+
texts = load_dataset(args.input_dataset, split="train", token=True).select(
138+
range(args.n_samples)
139+
)[args.input_content]
140+
cluster_labels, _ = cc.infer(texts, top_k=1)
129141

130-
else:
131-
if not args.build_hf_ds:
132-
print("Using mode=load but build_hf_ds is False, nothing to be done.")
142+
ds = build_hf_data_clusters(cc, texts, cluster_labels)
143+
target_repo = f"{args.username}/{args.inference_repo_name}"
144+
print(f"Pushing to hub at {target_repo}...")
145+
ds.push_to_hub(f"{target_repo}", private=True)
133146

134-
if args.build_hf_ds:
135-
print("Building HF clustering datasets...")
136-
if args.mode == "load":
147+
else:
148+
# Load existing pipeline
149+
if args.build_hf_ds:
137150
cc.load(args.save_load_path)
138-
ds = build_hf_data_clusters(cc)
139-
data_clusters = build_hf_data_files(cc)
140-
print(f"Files dataset {ds}\nClusters dataset {data_clusters}")
141-
142-
repo_name = args.save_load_path.split("/")[-1]
143-
print(f"Pushing to the hub at {repo_name}...")
144-
ds.push_to_hub(f"{args.username}/{repo_name}", private=True)
145-
data_clusters.push_to_hub(
146-
f"{args.username}/{repo_name}_clusters", private=True
147-
)
151+
build_and_push(cc, args)
152+
else:
153+
print("Using mode=load but build_hf_ds is False, nothing to be done.")
148154

149-
print("Done 🎉!")
155+
print("Done 🎉")
150156

151157

152158
if __name__ == "__main__":

0 commit comments

Comments
 (0)