@@ -12,9 +12,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--n_samples", type=int, default=100_000)
     parser.add_argument("--device", type=str, default="cuda")
-    parser.add_argument(
-        "--save_load_path", type=str, default="./fw_afaik_topics_100k_guard_s"
-    )
+    parser.add_argument("--save_load_path", type=str, default="./cc_100k")
     parser.add_argument(
         "--input_dataset", type=str, default="HuggingFaceFW/FW-12-12-2023-CC-2023-06"
     )
@@ -26,7 +24,9 @@ def get_args():
         help="Run the pipeline from scratch/load existing model to build hf datasets or to infer on new texts",
     )
     parser.add_argument(
-        "--inference_repo_name", type=str, default="infer_fw_on_ultrachat",
+        "--inference_repo_name",
+        type=str,
+        default="infer_fw_on_ultrachat",
         help="HF repo name for the clusters dataset in inference mode",
     )
     parser.add_argument(
@@ -47,7 +47,7 @@ def build_hf_data_clusters(cc, texts=None, labels=None):
         texts: list of texts used for inference mode.
         labels: list of cluster labels corresponding to the texts for inference mode.
 
-    If `texts` and `labels` are not provided, the function will use the data available in `cc`
+    If `texts` and `labels` are not provided, the function will use the data available in `cc`
     to construct the dataset. Otherwise it will run in inference mode on texts.
     """
     cluster_data = []
@@ -96,6 +96,19 @@ def build_hf_data_files(cc):
     return ds
 
 
+def build_and_push(cc, args):
+    """Build HF files & clusters datasets and push them to the hub"""
+    print("Building HF datasets...")
+    ds = build_hf_data_clusters(cc)
+    data_clusters = build_hf_data_files(cc)
+    print(f"Files dataset {ds}\nClusters dataset {data_clusters}")
+
+    repo_name = args.save_load_path.split("/")[-1]
+    print(f"Pushing to the hub at {repo_name}...")
+    ds.push_to_hub(f"{args.username}/{repo_name}", private=True)
+    data_clusters.push_to_hub(f"{args.username}/{repo_name}_clusters", private=True)
+
+
 def main():
     args = get_args()
     cc = ClusterClassifier(embed_device=args.device)
@@ -112,41 +125,34 @@ def main():
         cc.save(args.save_load_path)
         print(f"Saved clusters in {args.save_load_path}.")
 
+        if args.build_hf_ds:
+            build_and_push(cc, args)
+
     elif args.mode == "infer":
-        cc.load(args.save_load_path)
-        # run inference mode on texts using an existing pipeline
-        cc.load(args.save_load_path)
-        print(
-            f"Running inference on {args.n_samples} samples of {args.input_dataset} using clusters in {args.save_load_path}"
-        )
-        texts = load_dataset(args.input_dataset, split="train", token=True).select(
-            range(args.n_samples)
-        )[args.input_content]
-        cluster_labels, _ = cc.infer(texts, top_k=1)
-        ds = build_hf_data_clusters(cc, texts, cluster_labels)
-        print("Pushing to hub...")
-        ds.push_to_hub(f"{args.username}/{args.inference_repo_name}", private=True)
+        # Run inference mode on texts using an existing pipeline
+        cc.load(args.save_load_path)
+        print(
+            f"Running inference on {args.n_samples} samples of {args.input_dataset} using clusters in {args.save_load_path}."
+        )
+        texts = load_dataset(args.input_dataset, split="train", token=True).select(
+            range(args.n_samples)
+        )[args.input_content]
+        cluster_labels, _ = cc.infer(texts, top_k=1)
 
-    else:
-        if not args.build_hf_ds:
-            print("Using mode=load but build_hf_ds is False, nothing to be done.")
+        ds = build_hf_data_clusters(cc, texts, cluster_labels)
+        target_repo = f"{args.username}/{args.inference_repo_name}"
+        print(f"Pushing to hub at {target_repo}...")
+        ds.push_to_hub(f"{target_repo}", private=True)
 
-    if args.build_hf_ds:
-        print("Building HF clustering datasets...")
-        if args.mode == "load":
+    else:
+        # Load existing pipeline
+        if args.build_hf_ds:
             cc.load(args.save_load_path)
-        ds = build_hf_data_clusters(cc)
-        data_clusters = build_hf_data_files(cc)
-        print(f"Files dataset {ds}\nClusters dataset {data_clusters}")
-
-        repo_name = args.save_load_path.split("/")[-1]
-        print(f"Pushing to the hub at {repo_name}...")
-        ds.push_to_hub(f"{args.username}/{repo_name}", private=True)
-        data_clusters.push_to_hub(
-            f"{args.username}/{repo_name}_clusters", private=True
-        )
+            build_and_push(cc, args)
+        else:
+            print("Using mode=load but build_hf_ds is False, nothing to be done.")
 
-    print("Done 🎉! ")
+    print("Done 🎉")
 
 
 if __name__ == "__main__":
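
For anyone who wants to sanity-check the refactored `infer` path without pushing anything to the Hub, below is a minimal sketch of the same flow run programmatically. The `ClusterClassifier` import path and the "content" column name are assumptions for illustration, not fixed by this diff:

    from datasets import load_dataset
    from text_clustering import ClusterClassifier  # assumed import path

    # Load a pipeline previously saved by a `--mode run` invocation
    # under the new default save path.
    cc = ClusterClassifier(embed_device="cuda")
    cc.load("./cc_100k")

    # Take a small sample of the input dataset; "content" is an
    # assumed text column, mirroring args.input_content.
    texts = load_dataset(
        "HuggingFaceFW/FW-12-12-2023-CC-2023-06", split="train", token=True
    ).select(range(1_000))["content"]

    # Assign each text to its top cluster, as the infer branch does.
    cluster_labels, _ = cc.infer(texts, top_k=1)
    print(f"Assigned {len(cluster_labels)} cluster labels")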