Skip to content

Commit

Permalink
add RobustL(1|2)Loss aliases for backwards compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
janosh committed Feb 6, 2024
1 parent a7344c1 commit a8da6c4
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 18 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ ci:

repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.9
rev: v0.2.1
hooks:
- id: ruff
args: [--fix]
Expand Down Expand Up @@ -42,7 +42,7 @@ repos:
- id: format-ipy-cells

- repo: https://github.com/kynan/nbstripout
rev: 0.6.1
rev: 0.7.1
hooks:
- id: nbstripout
args: [--drop-empty-cells, --keep-output]
5 changes: 5 additions & 0 deletions aviary/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,8 @@ def robust_l2_loss(pred_mean: Tensor, pred_log_std: Tensor, target: Tensor) -> T
"""
loss = 0.5 * (pred_mean - target) ** 2 * torch.exp(-2 * pred_log_std) + pred_log_std
return torch.mean(loss)


# aliases for backwards compatibility
RobustL1Loss = robust_l1_loss
RobustL2Loss = robust_l2_loss
2 changes: 1 addition & 1 deletion aviary/wrenformer/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def df_to_in_mem_dataloader(
for idx, tensor in enumerate(initial_embeddings):
inputs[idx] = tensor.to(device)

ids = (df[id_col] if id_col in df else df.index).to_numpy()
ids = df.get(id_col, df.index).to_numpy()
return InMemoryDataLoader(
[inputs, targets, ids], collate_fn=collate_batch, **kwargs
)
10 changes: 5 additions & 5 deletions examples/notebooks/Roost.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "0",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -20,7 +20,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -42,7 +42,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -74,7 +74,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -135,7 +135,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
Expand Down
10 changes: 5 additions & 5 deletions examples/notebooks/Wren.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "0",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -20,7 +20,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "1",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -42,7 +42,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "2",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -69,7 +69,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "3",
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -130,7 +130,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": null,
"id": "4",
"metadata": {},
"outputs": [],
"source": [
Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ no_implicit_optional = false
[tool.ruff]
target-version = "py38"
include = ["**/pyproject.toml", "*.ipynb", "*.py", "*.pyi"]
select = [
lint.select = [
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"D", # pydocstyle
Expand Down Expand Up @@ -103,7 +103,7 @@ select = [
"W", # pycodestyle warning
"YTT", # flake8-2020
]
ignore = [
lint.ignore = [
"C408", # Unnecessary dict call - rewrite as a literal
"D100", # Missing docstring in public module
"D104", # Missing docstring in public package
Expand All @@ -114,9 +114,9 @@ ignore = [
"PLR", # pylint refactor
"PT006", # pytest-parametrize-names-wrong-type
]
pydocstyle.convention = "google"
isort.known-third-party = ["wandb"]
lint.pydocstyle.convention = "google"
lint.isort.known-third-party = ["wandb"]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"tests/*" = ["D"]
"examples/notebooks/*.py" = ["E402"]

0 comments on commit a8da6c4

Please sign in to comment.