
Commit 4c79762

chore: Update dependencies, logging (#39)
1 parent 5615549 commit 4c79762

5 files changed: +15 -11 lines changed

azureml/conda.yml

Lines changed: 2 additions & 2 deletions
@@ -10,8 +10,8 @@ dependencies:
   - typer
   - jsonlines
   - accelerate>=0.24.1
-  - bitsandbytes>=0.41.2.post2
-  - transformers>=4.35.2
+  - bitsandbytes>=0.42.0
+  - transformers>=4.37.2
   - xformers
   - scipy
   - nltk
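
The bumps above raise the floor to bitsandbytes>=0.42.0 and transformers>=4.37.2. A quick, hedged sketch (not part of this commit) for checking an existing environment against those minimums; it assumes the `packaging` distribution is importable:

# Sketch only: compare installed versions to the minimums pinned above.
# Assumes `packaging` is installed (it typically ships alongside pip).
from importlib.metadata import version

from packaging.version import Version

for pkg, minimum in [("bitsandbytes", "0.42.0"), ("transformers", "4.37.2")]:
    installed = version(pkg)
    ok = Version(installed) >= Version(minimum)
    print(f"{pkg} {installed} (need >={minimum}): {'ok' if ok else 'too old'}")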

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ dev = [
     "hf_transfer",
 ]
 pipelines = ["jsonlines", "mlflow", "nltk", "sentence-transformers>=2.3.1"]
+# NOTE: When updating dependencies, in particular cuda/azure ml, make sure to update the azureml/conda.yaml too
 azure = ["azureml-core", "azureml-mlflow"]
 cuda = ["bitsandbytes>=0.42.0", "accelerate>=0.24.1", "xformers"]
 
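
The added NOTE asks future edits to keep the cuda/azure extras in sync with azureml/conda.yml. A hedged helper sketch (not in the repo) that cross-checks the two files; it assumes PyYAML is installed, Python 3.11+ for tomllib, and execution from the repository root:

# Sketch only: report cuda extras from pyproject.toml that are missing from
# azureml/conda.yml. Assumes extras live under [project.optional-dependencies]
# and that conda may nest pip requirements under a {"pip": [...]} entry.
import tomllib  # stdlib on Python 3.11+

import yaml

with open("pyproject.toml", "rb") as f:
    cuda_extras = tomllib.load(f)["project"]["optional-dependencies"]["cuda"]

with open("azureml/conda.yml") as f:
    deps = yaml.safe_load(f)["dependencies"]

conda_specs = [d for d in deps if isinstance(d, str)]
conda_specs += [s for d in deps if isinstance(d, dict) for s in d.get("pip", [])]

for spec in cuda_extras:
    print(f"{spec}: {'ok' if spec in conda_specs else 'MISSING from conda.yml'}")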

src/autora/doc/pipelines/main.py

Lines changed: 5 additions & 6 deletions
@@ -12,12 +12,13 @@
 from autora.doc.runtime.prompts import PROMPTS, PromptIds
 from autora.doc.util import get_prompts_from_file
 
-app = typer.Typer()
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s %(module)s.%(funcName)s(): %(message)s",
 )
 logger = logging.getLogger(__name__)
+logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")
+app = typer.Typer()
 
 
 @app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
@@ -83,9 +84,9 @@ def eval(
         mlflow.log_param("prompt_id", prompt_id)
         mlflow.log_param("model_path", model_path)
         mlflow.log_param("data_file", data_file)
-        prompt = PROMPTS[prompt_id]
-        pred = Predictor(model_path)
-        return eval_prompt(data_file, pred, prompt, param_dict)
+    prompt = PROMPTS[prompt_id]
+    pred = Predictor(model_path)
+    return eval_prompt(data_file, pred, prompt, param_dict)
 
 
 def load_data(data_file: str) -> Tuple[List[str], List[str]]:
@@ -175,6 +176,4 @@ def read_text(file: str) -> str:
 
 
 if __name__ == "__main__":
-    logger.info(f"Torch version: {torch.__version__} , Cuda available: {torch.cuda.is_available()}")
-
     app()
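
Net effect of this file's change: the Torch/CUDA banner moves from the __main__ guard to module scope, so it is emitted once at import time for any code path that imports the module, not only direct script runs. A hedged illustration, assuming the package is installed in the current environment:

# Sketch only: importing the pipelines module now configures logging and emits
# something like "Torch version: 2.1.0 , Cuda available: False" immediately.
import autora.doc.pipelines.main  # noqa: F401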

src/autora/doc/runtime/predict_hf.py

Lines changed: 5 additions & 1 deletion
@@ -10,6 +10,7 @@
 logger = logging.getLogger(__name__)
 
 quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "autora-doc/Llama-2-7b-chat-hf-nf4"}
+non_quantized_models = {"meta-llama/Llama-2-7b-chat-hf": "autora-doc/Llama-2-7b-chat-hf"}
 
 
 def preprocess_code(code: str) -> str:
@@ -91,6 +92,7 @@ def tokenize(self, input: List[str]) -> Dict[str, List[List[int]]]:
     @staticmethod
     def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
         if torch.cuda.is_available():
+            logger.info("CUDA is available, attempting to load quantized model")
             from transformers import BitsAndBytesConfig
 
             config = {"device_map": "auto"}
@@ -108,4 +110,6 @@ def get_config(model_path: str) -> Tuple[str, Dict[str, str]]:
             )
             return model_path, config
         else:
-            return model_path, {}
+            logger.info("CUDA is not available, loading non-quantized model")
+            mapped_path = non_quantized_models.get(model_path, model_path)
+            return mapped_path, {}
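
With the new non_quantized_models mapping, get_config falls back to the full-precision weights when CUDA is absent, instead of returning the original model id with empty kwargs. A hedged usage sketch that mirrors the updated test below:

# Sketch only: force the no-CUDA branch and observe the mapped path.
from unittest import mock

from autora.doc.runtime.predict_hf import Predictor

with mock.patch("torch.cuda.is_available", return_value=False):
    path, config = Predictor.get_config("meta-llama/Llama-2-7b-chat-hf")

print(path)    # "autora-doc/Llama-2-7b-chat-hf" (non-quantized mapping)
print(config)  # {} (no quantization or device kwargs on CPU)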

tests/test_predict_hf.py

Lines changed: 2 additions & 2 deletions
@@ -36,6 +36,6 @@ def test_get_config_cuda(mock: mock.Mock) -> None:
 
 @mock.patch("torch.cuda.is_available", return_value=False)
 def test_get_config_nocuda(mock: mock.Mock) -> None:
-    model, config = Predictor.get_config(MODEL_WITH_QUANTIZED)
-    assert model == MODEL_WITH_QUANTIZED
+    model, config = Predictor.get_config(MODEL_NO_QUANTIZED)
+    assert model == MODEL_NO_QUANTIZED
     assert len(config) == 0
