@@ -66,6 +66,8 @@ def on_step_end(self, optimizer):
6666 self .step += 1
6767 time = datetime .now ().strftime ("%d-%m-%y %H:%M:%S:%f" )
6868 self .logger .critical (f"{ time } - ✨Step { self .step } ended✨" )
69+ time = datetime .now ().strftime ("%d-%m-%y %H:%M:%S:%f" )
70+ self .logger .critical (f"{ time } - ✨Step { self .step } ended✨" )
6971 for i , (prompt , score ) in enumerate (zip (optimizer .prompts , optimizer .scores )):
7072 self .logger .critical (f"*** Prompt { i } : Score: { score } " )
7173 self .logger .critical (f"{ prompt } " )
@@ -80,40 +82,48 @@ def on_train_end(self, optimizer, logs=None):
8082 logs: Additional information to log.
8183 """
8284 time = datetime .now ().strftime ("%d-%m-%y %H:%M:%S:%f" )
85+ time = datetime .now ().strftime ("%d-%m-%y %H:%M:%S:%f" )
8386 if logs is None :
8487 self .logger .critical (f"{ time } - Training ended" )
88+ self .logger .critical (f"{ time } - Training ended" )
8589 else :
8690 self .logger .critical (f"{ time } - Training ended - { logs } " )
91+ self .logger .critical (f"{ time } - Training ended - { logs } " )
8792
8893 return True
8994
9095
91- class CSVCallback (Callback ):
92- """Callback for saving optimization progress to a CSV file.
96+ class FileOutputCallback (Callback ):
97+ """Callback for saving optimization progress to a specified file type .
9398
94- This callback saves prompts and scores at each step to a CSV file.
99+ This callback saves information about each step to a file.
95100
96101 Attributes:
97- dir (str): Directory the CSV file is saved to.
102+ dir (str): Directory the file is saved to.
98103 step (int): The current step number.
104+ file_type (str): The type of file to save the output to.
99105 """
100106
101- def __init__ (self , dir ):
102- """Initialize the CSVCallback .
107+ def __init__ (self , dir , file_type : Literal [ "parquet" , "csv" ] = "parquet" ):
108+ """Initialize the FileOutputCallback .
103109
104110 Args:
105111 dir (str): Directory the CSV file is saved to.
112+ file_type (str): The type of file to save the output to.
106113 """
107114 if not os .path .exists (dir ):
108115 os .makedirs (dir )
109116
110- self .dir = dir
111- self .dir = dir
117+ self .file_type = file_type
118+
119+ if file_type == "parquet" :
120+ self .path = dir + "/step_results.parquet"
121+ elif file_type == "csv" :
122+ self .path = dir + "/step_results.csv"
123+ else :
124+ raise ValueError (f"File type { file_type } not supported." )
125+
112126 self .step = 0
113- self .input_tokens = 0
114- self .output_tokens = 0
115- self .start_time = datetime .now ()
116- self .step_time = datetime .now ()
117127
118128 def on_step_end (self , optimizer ):
119129 """Save prompts and scores to csv.
@@ -125,47 +135,24 @@ def on_step_end(self, optimizer):
125135 df = pd .DataFrame (
126136 {
127137 "step" : [self .step ] * len (optimizer .prompts ),
128- "input_tokens" : [optimizer .meta_llm .input_token_count - self . input_tokens ] * len (optimizer .prompts ),
129- "output_tokens" : [optimizer .meta_llm .output_token_count - self . output_tokens ] * len (optimizer .prompts ),
130- "time_elapsed " : [( datetime .now () - self . step_time ).total_seconds ()] * len (optimizer .prompts ),
138+ "input_tokens" : [optimizer .meta_llm .input_token_count ] * len (optimizer .prompts ),
139+ "output_tokens" : [optimizer .meta_llm .output_token_count ] * len (optimizer .prompts ),
140+ "time " : [datetime .now ().total_seconds ()] * len (optimizer .prompts ),
131141 "score" : optimizer .scores ,
132142 "prompt" : optimizer .prompts ,
133143 }
134144 )
135- self .step_time = datetime .now ()
136- self .input_tokens = optimizer .meta_llm .input_token_count
137- self .output_tokens = optimizer .meta_llm .output_token_count
138-
139- if not os .path .exists (self .dir + "step_results.csv" ):
140- df .to_csv (self .dir + "step_results.csv" , index = False )
141- else :
142- df .to_csv (self .dir + "step_results.csv" , mode = "a" , header = False , index = False )
143-
144- return True
145-
146- def on_train_end (self , optimizer ):
147- """Called at the end of training.
148-
149- Args:
150- optimizer: The optimizer object that called the callback.
151- """
152- df = pd .DataFrame (
153- dict (
154- steps = self .step ,
155- input_tokens = optimizer .meta_llm .input_token_count ,
156- output_tokens = optimizer .meta_llm .output_token_count ,
157- time_elapsed = (datetime .now () - self .start_time ).total_seconds (),
158- time = datetime .now (),
159- score = np .array (optimizer .scores ).mean (),
160- best_prompts = str (optimizer .prompts ),
161- ),
162- index = [0 ],
163- )
164145
165- if not os .path .exists (self .dir + "train_results.csv" ):
166- df .to_csv (self .dir + "train_results.csv" , index = False )
167- else :
168- df .to_csv (self .dir + "train_results.csv" , mode = "a" , header = False , index = False )
146+ if self .file_type == "parquet" :
147+ if self .step == 1 :
148+ df .to_parquet (self .path , index = False )
149+ else :
150+ df .to_parquet (self .path , mode = "a" , index = False )
151+ elif self .file_type == "csv" :
152+ if self .step == 1 :
153+ df .to_csv (self .path , index = False )
154+ else :
155+ df .to_csv (self .path , mode = "a" , header = False , index = False )
169156
170157 return True
171158
0 commit comments