
Commit f02625b

Merge pull request #1 from huggingface/add-pipeline-script
Add pipeline script
2 parents 8bfd6aa + 3708845 commit f02625b

File tree: 3 files changed (+179, -7 lines changed)


README.md

Lines changed: 16 additions & 5 deletions
````diff
@@ -3,20 +3,20 @@
 ## Install
 
 ```bash
-pip install sklearn umap sentence_transformers faiss-cpu plotly matplotlib datasets
+pip install scikit-learn umap-learn sentence_transformers faiss-cpu plotly matplotlib datasets
 ```
 
 ## Usage
 
 Run pipeline and visualize results:
 
 ```python
-from src.text_cluster import ClusterClassifier
+from src.text_clustering import ClusterClassifier
 from datasets import load_dataset
 
 SAMPLE = 100_000
 
-texts = load_dataset("HuggingFaceFW/FW-12-12-2023-CC-2023-06").select(range(SAMPLE))["content"]
+texts = load_dataset("HuggingFaceFW/FW-12-12-2023-CC-2023-06", split="train").select(range(SAMPLE))["content"]
 
 cc = ClusterClassifier(embed_device="mps")
 
@@ -32,7 +32,7 @@ cc.save("./cc_100k")
 
 Load classifier and run inference:
 ```python
-from src.text_cluster import ClusterClassifier
+from src.text_clustering import ClusterClassifier
 
 cc = ClusterClassifier(embed_device="mps")
 
@@ -44,4 +44,15 @@ cc.show()
 
 # classify new texts with k-nearest neighbour search
 cluster_labels, embeddings = cc.infer(some_texts, top_k=1)
-```
+```
+
+You can also run the pipeline using a script with:
+```bash
+# run a new pipeline
+python run_pipeline.py --mode run --save_load_path './cc_100k' --n_samples 100000 --build_hf_ds
+# load existing pipeline
+python run_pipeline.py --mode load --save_load_path './cc_100k' --build_hf_ds
+# inference mode on new texts from an input dataset
+python run_pipeline.py --mode infer --save_load_path './cc_100k' --n_samples <NB_INFERENCE_SAMPLES> --input_dataset <HF_DATA_FOR_INFERENCE>
+```
+The `build_hf_ds` flag builds HF datasets for the files and the clusters and pushes them to the Hub; these can be used directly in the FW visualization space. In `infer` mode, we push the clusters dataset by default.
````
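If you want to inspect the pushed datasets outside the visualization space, they load back like any other Hub dataset. A minimal sketch, assuming the defaults in `run_pipeline.py` (`--username loubnabnl`, `--save_load_path ./cc_100k`), which would give the hypothetical repo ids `loubnabnl/cc_100k` and `loubnabnl/cc_100k_clusters`; both are pushed as private, so a Hub token with access is needed:

```python
from datasets import load_dataset

# Hypothetical repo ids derived from the run_pipeline.py defaults
# (--username loubnabnl, --save_load_path ./cc_100k); adjust to your own setup.
files_ds = load_dataset("loubnabnl/cc_100k", split="train")
clusters_ds = load_dataset("loubnabnl/cc_100k_clusters", split="train")

print(files_ds)
print(clusters_ds)
```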

run_pipeline.py

Lines changed: 157 additions & 0 deletions
```python
import argparse
import textwrap

import pandas as pd
import numpy as np
from datasets import Dataset, load_dataset

from src.text_clustering import ClusterClassifier


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_samples", type=int, default=100_000)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--save_load_path", type=str, default="./cc_100k")
    parser.add_argument(
        "--input_dataset", type=str, default="HuggingFaceFW/FW-12-12-2023-CC-2023-06"
    )
    parser.add_argument("--input_content", type=str, default="content")
    parser.add_argument(
        "--mode",
        choices=["run", "load", "infer"],
        default="run",
        help="Run the pipeline from scratch/load existing model to build hf datasets or to infer on new texts",
    )
    parser.add_argument(
        "--inference_repo_name",
        type=str,
        default="infer_fw_on_ultrachat",
        help="HF repo name for the clusters dataset in inference mode",
    )
    parser.add_argument(
        "--build_hf_ds",
        action="store_true",
        help="Builds HF datasets used for space visualization and pushes them to the hub",
    )
    parser.add_argument("--username", type=str, default="loubnabnl")
    return parser.parse_args()


def build_hf_data_clusters(cc, texts=None, labels=None):
    """
    Build an HF dataset containing information on each cluster.

    Args:
        cc: ClusterClassifier object.
        texts: list of texts used for inference mode.
        labels: list of cluster labels corresponding to the texts for inference mode.

    If `texts` and `labels` are not provided, the function will use the data available in `cc`
    to construct the dataset. Otherwise it will run in inference mode on texts.
    """
    cluster_data = []
    for cluster_id in cc.label2docs.keys():
        if cluster_id == -1:
            continue

        # inference mode
        if texts is not None and labels is not None:
            labels_array = np.array(labels)
            files_in_cluster = np.where(labels_array == cluster_id)[0]
            examples = [texts[doc_id] for doc_id in files_in_cluster]
        else:
            doc_ids = cc.label2docs[cluster_id]
            examples = [cc.texts[doc_id] for doc_id in doc_ids]

        cluster_info = {
            "cluster_id": cluster_id,
            "summary": cc.cluster_summaries[cluster_id],
            "examples": examples,
        }

        if not texts:
            cluster_info["position"] = cc.cluster_centers[cluster_id]

        cluster_data.append(cluster_info)

    return Dataset.from_pandas(pd.DataFrame(cluster_data))


def build_hf_data_files(cc):
    """
    Build an HF dataset containing information on each file and the cluster they belong to
    """

    df = pd.DataFrame(
        data={
            "X": cc.projections[:, 0],
            "Y": cc.projections[:, 1],
            "labels": cc.cluster_labels,
            "content_display": [textwrap.fill(txt[:1024], 64) for txt in cc.texts],
        }
    )
    return Dataset.from_pandas(df)

def build_and_push(cc, args):
    """Build HF files & clusters datasets and push them to the hub"""
    print("Building HF datasets...")
    ds = build_hf_data_files(cc)
    data_clusters = build_hf_data_clusters(cc)
    print(f"Files dataset {ds}\nClusters dataset {data_clusters}")

    repo_name = args.save_load_path.split("/")[-1]
    print(f"Pushing to the hub at {repo_name}...")
    ds.push_to_hub(f"{args.username}/{repo_name}", private=True)
    data_clusters.push_to_hub(f"{args.username}/{repo_name}_clusters", private=True)


def main():
    args = get_args()
    cc = ClusterClassifier(embed_device=args.device)

    if args.mode == "run":
        # Run a new pipeline on texts
        texts = load_dataset(args.input_dataset, split="train", token=True).select(
            range(args.n_samples)
        )[args.input_content]

        _, _, summaries = cc.fit(texts)
        print(f"10 example Summaries:\n{[e for e in summaries.values()][:10]}")

        cc.save(args.save_load_path)
        print(f"Saved clusters in {args.save_load_path}.")

        if args.build_hf_ds:
            build_and_push(cc, args)

    elif args.mode == "infer":
        # Run inference mode on texts using an existing pipeline
        cc.load(args.save_load_path)
        print(
            f"Running inference on {args.n_samples} samples of {args.input_dataset} using clusters in {args.save_load_path}."
        )
        texts = load_dataset(args.input_dataset, split="train", token=True).select(
            range(args.n_samples)
        )[args.input_content]
        cluster_labels, _ = cc.infer(texts, top_k=1)

        ds = build_hf_data_clusters(cc, texts, cluster_labels)
        target_repo = f"{args.username}/{args.inference_repo_name}"
        print(f"Pushing to hub at {target_repo}...")
        ds.push_to_hub(target_repo, private=True)

    else:
        # Load existing pipeline
        if args.build_hf_ds:
            cc.load(args.save_load_path)
            build_and_push(cc, args)
        else:
            print("Using mode=load but build_hf_ds is False, nothing to be done.")

    print("Done 🎉")


if __name__ == "__main__":
    main()
```
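To make the two code paths of `build_hf_data_clusters` concrete, here is a small self-contained sketch that calls it on a hypothetical stand-in object exposing only the attributes the function reads; in a real run you would pass a fitted `ClusterClassifier` instead:

```python
from types import SimpleNamespace

# Hypothetical stand-in for a fitted ClusterClassifier; only the attributes
# that build_hf_data_clusters() touches are filled in.
cc_stub = SimpleNamespace(
    label2docs={-1: [], 0: [0, 2], 1: [1]},
    texts=["a recipe for soup", "a football match report", "how to bake bread"],
    cluster_summaries={-1: "None", 0: "Cooking", 1: "Sports"},
    cluster_centers={0: [0.1, 0.4], 1: [0.8, 0.2]},
)

# Fit mode: examples come from cc.texts and each row carries the cluster position.
ds_fit = build_hf_data_clusters(cc_stub)

# Inference mode: examples come from new texts grouped by their inferred labels.
new_texts = ["grilled vegetables recipe", "basketball playoff recap"]
ds_infer = build_hf_data_clusters(cc_stub, texts=new_texts, labels=[0, 1])

print(ds_fit.column_names)    # ['cluster_id', 'summary', 'examples', 'position']
print(ds_infer.column_names)  # ['cluster_id', 'summary', 'examples']
```

Note that the `position` column only appears in fit mode (the `if not texts` branch), so the inference-mode dataset has one fewer column.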

src/text_clustering.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -14,6 +14,7 @@
 import os
 import json
 from collections import Counter
+from tqdm import tqdm
 
 logging.basicConfig(level=logging.INFO)
 
@@ -127,7 +128,7 @@ def infer(self, texts, top_k=1):
 
         dist, neighbours = self.faiss_index.search(embeddings, top_k)
         inferred_labels = []
-        for i in range(embeddings.shape[0]):
+        for i in tqdm(range(embeddings.shape[0])):
             labels = [self.cluster_labels[doc] for doc in neighbours[i]]
             inferred_labels.append(Counter(labels).most_common(1)[0][0])
 
@@ -198,7 +199,10 @@ def save(self, folder):
 
         with open(f'{folder}/texts.json', 'w') as f:
             json.dump(self.texts, f)
-
+
+        with open(f"{folder}/mistral_prompt.txt", "w") as f:
+            f.write(DEFAULT_INSTRUCTION)
+
         if self.cluster_summaries is not None:
             with open(f'{folder}/cluster_summaries.json', 'w') as f:
                 json.dump(self.cluster_summaries, f)
```
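For context on the loop that now gets a tqdm progress bar: `infer` labels each new text with the most common cluster label among its `top_k` FAISS neighbours. A minimal standalone sketch of that majority vote, with made-up neighbour indices standing in for the output of `self.faiss_index.search`:

```python
from collections import Counter

# Made-up data: cluster labels of the fitted documents, and the indices of the
# top_k=3 nearest fitted documents for each of two new texts.
cluster_labels = [0, 0, 1, 1, 1]
neighbours = [[0, 1, 2], [2, 3, 4]]

inferred_labels = []
for row in neighbours:
    labels = [cluster_labels[doc] for doc in row]
    inferred_labels.append(Counter(labels).most_common(1)[0][0])  # majority label wins

print(inferred_labels)  # [0, 1]
```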
