22from timeit import default_timer as timer
33from typing import List
44
5- import jsonlines
6- import mlflow
75import torch
86import typer
97
1917
2018
2119@app .command ()
22- def predict (
23- data_file : str , model_path : str , sys_id : SystemPrompts , instruc_id : InstructionPrompts
24- ) -> List [str ]:
20+ def eval (data_file : str , model_path : str , sys_id : SystemPrompts , instruc_id : InstructionPrompts ) -> List [str ]:
21+ import jsonlines
22+ import mlflow
23+
24+ mlflow .autolog ()
25+
2526 run = mlflow .active_run ()
2627
2728 sys_prompt = SYS [sys_id ]
@@ -33,7 +34,6 @@ def predict(
3334 logger .info (f"running predict with { data_file } " )
3435 logger .info (f"model path: { model_path } " )
3536
36- # predictions = []
3737 with jsonlines .open (data_file ) as reader :
3838 items = [item for item in reader ]
3939 inputs = [item ["instruction" ] for item in items ]
@@ -57,6 +57,26 @@ def predict(
5757 return predictions
5858
5959
60+ @app .command ()
61+ def generate (
62+ python_file : str ,
63+ model_path : str = "meta-llama/llama-2-7b-chat-hf" ,
64+ output : str = "output.txt" ,
65+ sys_id : SystemPrompts = SystemPrompts .SYS_1 ,
66+ instruc_id : InstructionPrompts = InstructionPrompts .INSTR_SWEETP_1 ,
67+ ) -> None :
68+ with open (python_file , "r" ) as f :
69+ inputs = [f .read ()]
70+ sys_prompt = SYS [sys_id ]
71+ instr_prompt = INSTR [instruc_id ]
72+ pred = Predictor (model_path )
73+ predictions = pred .predict (sys_prompt , instr_prompt , inputs )
74+ assert len (predictions ) == 1 , f"Expected only one output, got { len (predictions )} "
75+ logger .info (f"Writing output to { output } " )
76+ with open (output , "w" ) as f :
77+ f .write (predictions [0 ])
78+
79+
6080@app .command ()
6181def import_model (model_name : str ) -> None :
6282 pass
@@ -65,5 +85,4 @@ def import_model(model_name: str) -> None:
6585if __name__ == "__main__" :
6686 logger .info (f"Torch version: { torch .__version__ } , Cuda available: { torch .cuda .is_available ()} " )
6787
68- mlflow .autolog ()
6988 app ()
0 commit comments