 from autora.doc.runtime.prompts import PROMPTS, PromptIds
 from autora.doc.util import get_prompts_from_file

-app = typer.Typer()
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
 )
 logger = logging.getLogger(__name__)
+logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")
+app = typer.Typer()


 @app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
@@ -83,9 +84,9 @@ def eval( |
83 | 84 | mlflow.log_param("prompt_id", prompt_id) |
84 | 85 | mlflow.log_param("model_path", model_path) |
85 | 86 | mlflow.log_param("data_file", data_file) |
86 | | - prompt = PROMPTS[prompt_id] |
87 | | - pred = Predictor(model_path) |
88 | | - return eval_prompt(data_file, pred, prompt, param_dict) |
| 87 | + prompt = PROMPTS[prompt_id] |
| 88 | + pred = Predictor(model_path) |
| 89 | + return eval_prompt(data_file, pred, prompt, param_dict) |
89 | 90 |
|
90 | 91 |
|
91 | 92 | def load_data(data_file: str) -> Tuple[List[str], List[str]]: |
@@ -175,6 +176,4 @@ def read_text(file: str) -> str: |


 if __name__ == "__main__":
-    logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")
-
     app()