@@ -1,7 +1,7 @@
 import itertools
 import logging
 from timeit import default_timer as timer
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 import nltk
 import torch
@@ -20,13 +20,13 @@
 logger = logging.getLogger(__name__)
 
 
-def evaluate_documentation(predictions: List[List[str]], references: List[str]) -> Tuple[float, float]:
+def evaluate_documentation(predictions: List[str], references: List[str]) -> Tuple[float, float]:
     nltk.download("wordnet")
 
     # Tokenize references
     tokenized_references = [ref.split() for ref in references]
     # Currently there is only 1 prediction for 1 reference, need to avg in future
-    tokenized_predictions = [pred[0].split() if pred else [] for pred in predictions]
+    tokenized_predictions = [pred.split() if pred else [] for pred in predictions]
 
     # Calculate BLEU score with smoothing function
     # SmoothingFunction().method1 is used to avoid zero scores for n-grams not found in the reference.
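With this signature change, callers pass one flat generated string per reference instead of a singleton list. The rest of the function body sits outside this hunk, but the comments point to corpus-level BLEU with SmoothingFunction().method1 plus METEOR; the sketch below shows one way that scoring could look on the new flat inputs using standard NLTK calls (the sample strings are made up, and this is illustrative rather than the exact implementation):

    from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
    from nltk.translate.meteor_score import meteor_score

    references = ["Return the sum of a and b."]        # made-up reference docstrings
    predictions = ["Returns the sum of a and b."]      # now List[str]: one string per reference

    tokenized_references = [ref.split() for ref in references]
    tokenized_predictions = [pred.split() if pred else [] for pred in predictions]

    # corpus_bleu expects a list of reference-lists per hypothesis; method1 smoothing
    # keeps the score non-zero when some n-gram order never matches.
    bleu = corpus_bleu(
        [[ref] for ref in tokenized_references],
        tokenized_predictions,
        smoothing_function=SmoothingFunction().method1,
    )
    # Recent NLTK releases require pre-tokenized input for METEOR (hence the wordnet download).
    meteor = sum(
        meteor_score([ref], pred)
        for ref, pred in zip(tokenized_references, tokenized_predictions)
    ) / len(tokenized_predictions)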
@@ -55,16 +55,13 @@ def eval(
     param: List[str] = typer.Option(
         [], help="Additional float parameters to pass to the model as name=float pairs"
     ),
-) -> List[List[str]]:
-    import jsonlines
+) -> Tuple[List[str], float, float]:
     import mlflow
 
     mlflow.autolog()
-
-    param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
     run = mlflow.active_run()
+    param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
 
-    prompt = PROMPTS[prompt_id]
     if run is None:
         run = mlflow.start_run()
     with run:
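A quick worked example of the name=float parsing retained here (values are made up): repeated `param` options such as `temperature=0.2` become a float-valued dict that is later splatted into Predictor.predict.

    param = ["temperature=0.2", "top_p=0.95"]    # hypothetical values collected by typer.Option
    param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
    # param_dict == {"temperature": 0.2, "top_p": 0.95}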
@@ -75,36 +72,51 @@ def eval(
         mlflow.log_param("prompt_id", prompt_id)
         mlflow.log_param("model_path", model_path)
         mlflow.log_param("data_file", data_file)
+        prompt = PROMPTS[prompt_id]
+        pred = Predictor(model_path)
+        return eval_prompt(data_file, pred, prompt, param_dict)
+
+
+def load_data(data_file: str) -> Tuple[List[str], List[str]]:
+    import jsonlines
+
+    with jsonlines.open(data_file) as reader:
+        items = [item for item in reader]
+    inputs = [f"{item['instruction']}" for item in items]
+    labels = [item["output"] for item in items]
+    return inputs, labels
+
+
+def eval_prompt(
+    data_file: str, pred: Predictor, prompt: str, param_dict: Dict[str, float]
+) -> Tuple[List[str], float, float]:
+    import mlflow
 
-        with jsonlines.open(data_file) as reader:
-            items = [item for item in reader]
-            inputs = [item["instruction"] for item in items]
-            labels = [item["output"] for item in items]
-
-        pred = Predictor(model_path)
-        timer_start = timer()
-        predictions = pred.predict(prompt, inputs, **param_dict)
-        timer_end = timer()
-        bleu, meteor = evaluate_documentation(predictions, labels)
-        pred_time = timer_end - timer_start
-        mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
-        for i in range(len(inputs)):
-            mlflow.log_text(labels[i], f"label_{i}.txt")
-            mlflow.log_text(inputs[i], f"input_{i}.py")
-            for j in range(len(predictions[i])):
-                mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
-        mlflow.log_text("bleu_score is ", str(bleu))
-        mlflow.log_text("meteor_score is ", str(meteor))
-
-        # flatten predictions for counting tokens
-        predictions_flat = list(itertools.chain.from_iterable(predictions))
-        tokens = pred.tokenize(predictions_flat)["input_ids"]
-        total_tokens = sum([len(token) for token in tokens])
-        mlflow.log_metric("total_tokens", total_tokens)
-        mlflow.log_metric("tokens/sec", total_tokens / pred_time)
-        mlflow.log_metric("bleu_score", round(bleu, 5))
-        mlflow.log_metric("meteor_score", round(meteor, 5))
-        return predictions
+    inputs, labels = load_data(data_file)
+
+    timer_start = timer()
+    predictions = pred.predict(prompt, inputs, **param_dict)
+    timer_end = timer()
+    bleu, meteor = evaluate_documentation(predictions, labels)
+    pred_time = timer_end - timer_start
+    mlflow.log_metric("prediction_time/doc", pred_time / (len(inputs)))
+    for i in range(len(inputs)):
+        mlflow.log_text(labels[i], f"label_{i}.txt")
+        mlflow.log_text(inputs[i], f"input_{i}.py")
+        for j in range(len(predictions[i])):
+            mlflow.log_text(predictions[i][j], f"prediction_{i}_{j}.txt")
+    mlflow.log_text("bleu_score is ", str(bleu))
+    mlflow.log_text("meteor_score is ", str(meteor))
+
+    # flatten predictions for counting tokens
+    predictions_flat = list(itertools.chain.from_iterable(predictions))
+    tokens = pred.tokenize(predictions_flat)["input_ids"]
+    total_tokens = sum([len(token) for token in tokens])
+    mlflow.log_metric("total_tokens", total_tokens)
+    mlflow.log_metric("tokens/sec", total_tokens / pred_time)
+    mlflow.log_metric("bleu_score", round(bleu, 5))
+    mlflow.log_metric("meteor_score", round(meteor, 5))
+    return predictions, bleu, meteor
 
 
 @app.command()
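The new `load_data` helper assumes a JSON Lines file whose records carry `instruction` and `output` keys, and `eval_prompt` now owns prediction, scoring, and MLflow logging. A minimal sketch of that data shape and the refactored call path (file name, record contents, and parameter values are hypothetical):

    import jsonlines

    # Write a tiny sample file with the keys load_data reads.
    with jsonlines.open("eval_sample.jsonl", mode="w") as writer:
        writer.write({"instruction": "def add(a, b):\n    return a + b",
                      "output": "Return the sum of a and b."})

    inputs, labels = load_data("eval_sample.jsonl")
    # inputs[0] is the instruction text, labels[0] the reference documentation.

    # eval_prompt then runs the model and logs metrics/artifacts to MLflow, e.g.:
    # predictions, bleu, meteor = eval_prompt("eval_sample.jsonl", pred, prompt, {"temperature": 0.2})
    # where `pred` is a Predictor and `prompt` a PROMPTS entry, as wired up in eval() above.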
@@ -126,7 +138,7 @@ def generate(
     prompt = PROMPTS[prompt_id]
     pred = Predictor(model_path)
     # grab first result since we only passed one input
-    predictions = pred.predict(prompt, [input], **param_dict)[0]
+    predictions = pred.predict(prompt, [input], **param_dict)
     assert len(predictions) == 1, f"Expected only one output, got {len(predictions)}"
     logger.info(f"Writing output to {output}")
     with open(output, "w") as f:
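Dropping the trailing `[0]` matches the flattened predict output seen in the earlier hunks: if `Predictor.predict` now returns one string per input, indexing into it first would hand the assert a bare string, so `len()` would count characters instead of outputs. Roughly (assumed shapes, not actual model output):

    outputs = ["Return the sum of a and b."]    # assumed: one generated string per input
    assert len(outputs) == 1                    # counts outputs, as the assert intends
    text = outputs[0]                           # the single prediction written to `output`
    # With the old [0], len(...) would have measured the string's length instead.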