Description
Bug description
As of Lightning 2.3.0, `save_hyperparameters` no longer seems to respect linked arguments.
Based on my investigation, this seems to be caused by #18105. That change appears to have introduced other errors, which have since been resolved, but as far as I can tell this one persists in the latest release (2.4.0) and on the master branch (66508ff).
What version are you seeing the problem on?
v2.3, v2.4, master
How to reproduce the bug
Save the following script as `lightning_cli_save_hyperaparams_error_on_link_args.py`:
import torch
import torch.nn
from torch import nn
from torch.utils.data import Dataset
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from typing import List, Dict
class MWE_Model(pl.LightningModule):
    """
    Small heterogeneous-sequence model used to reproduce the issue.

    One ``Conv2d`` stem is built per (sensor, channels) pair and one
    ``Conv2d`` head per task, both derived from ``dataset_stats``. In this
    script ``dataset_stats`` is linked in from the datamodule by
    ``MWE_LightningCLI`` when the model is instantiated.

    NOTE: per this bug report, with Lightning 2.3 / 2.4 / master the linked
    ``dataset_stats`` value is missing from the saved ``hparams.yaml`` even
    though it is passed to ``__init__``.

    Example:
        >>> dataset = MWE_Dataset()
        >>> self = MWE_Model(dataset_stats=dataset.dataset_stats)
        >>> batch = [dataset[i] for i in range(2)]
        >>> self.forward(batch)
    """

    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        """
        Args:
            sorting (bool): if True, sort the (sensor, channels, num_bands)
                tuples and the tasks (by ``name``) before building modules,
                so construction order does not depend on set iteration
                order (which can vary between runs for string tuples).
            dataset_stats (Dict | None): required. The keys read are
                ``known_modalities`` (dicts with ``sensor``, ``channels``,
                ``num_bands``) and ``known_tasks`` (dicts with ``name``,
                ``classes``).
            d_model (int): feature width of every token.

        Raises:
            ValueError: if ``dataset_stats`` is None.
        """
        super().__init__()
        # Record constructor arguments as hyperparameters (see the NOTE in
        # the class docstring about linked arguments).
        self.save_hyperparameters()
        if dataset_stats is None:
            raise ValueError('must be given dataset stats')
        self.d_model = d_model
        self.dataset_stats = dataset_stats
        self.known_sensorchan = {
            (mode['sensor'], mode['channels'], mode['num_bands'])
            for mode in self.dataset_stats['known_modalities']
        }
        self.known_tasks = self.dataset_stats['known_tasks']
        if sorting:
            self.known_sensorchan = sorted(self.known_sensorchan)
            self.known_tasks = sorted(self.known_tasks, key=lambda t: t['name'])

        # Construct stems based on the dataset: stems[sensor][channels]
        self.stems = torch.nn.ModuleDict()
        for sensor, channels, num_bands in self.known_sensorchan:
            if sensor not in self.stems:
                self.stems[sensor] = torch.nn.ModuleDict()
            self.stems[sensor][channels] = torch.nn.Conv2d(
                num_bands, self.d_model, kernel_size=1)

        # Backbone is small generic transformer
        self.backbone = torch.nn.Transformer(
            d_model=self.d_model,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            batch_first=True,
        )

        # Construct heads based on the dataset: one per task
        self.heads = torch.nn.ModuleDict()
        for head_info in self.known_tasks:
            head_name = head_info['name']
            num_classes = len(head_info['classes'])
            self.heads[head_name] = torch.nn.Conv2d(
                self.d_model, num_classes, kernel_size=1)

    @property
    def main_device(self):
        """ Helper to get a device for the model (None if it has no state). """
        first_tensor = next(iter(self.state_dict().values()), None)
        return None if first_tensor is None else first_tensor.device

    @staticmethod
    def _length_mask(sequences, max_len, device):
        """Return a bool (len(sequences), max_len) mask, True on real (unpadded) positions."""
        lengths = torch.tensor([len(seq) for seq in sequences], device=device)
        positions = torch.arange(max_len, device=device)
        return positions[None, :] < lengths[:, None]

    def tokenize_inputs(self, item: Dict):
        """
        Process a single batch item's heterogeneous sequence into a flat list
        of tokens for the encoder and decoder.

        Returns:
            Tuple: ``(in_tokens, out_tokens)``, each of shape
            ``(num_tokens, d_model)``, or ``(None, None)`` if the item has no
            inputs or no outputs.
        """
        device = self.device
        input_sequence = []
        for input_item in item['inputs']:
            stem = self.stems[input_item['sensor_code']][input_item['channel_code']]
            out = stem(input_item['data'])
            # Assumes unbatched (C, H, W) data, as MWE_Dataset produces; this
            # yields one d_model-wide token per spatial location.
            tokens = out.view(self.d_model, -1).T
            input_sequence.append(tokens)
        output_sequence = []
        for output_item in item['outputs']:
            # Decoder queries are random placeholders of the requested size.
            shape = tuple(output_item['dims']) + (self.d_model,)
            tokens = torch.rand(shape, device=device).view(-1, self.d_model)
            output_sequence.append(tokens)
        if not input_sequence or not output_sequence:
            return None, None
        in_tokens = torch.concat(input_sequence, dim=0)
        out_tokens = torch.concat(output_sequence, dim=0)
        return in_tokens, out_tokens

    def forward(self, batch: List[Dict]) -> List[Dict]:
        """
        Runs prediction on multiple batch items. The input is assumed to be an
        uncollated list of dictionaries, each containing information about some
        heterogeneous sequence. The output is a corresponding list of
        dictionaries containing the logits for each head (an empty dict for
        items that produced no tokens).
        """
        batch_in_tokens = []
        batch_out_tokens = []
        valid_batch_indexes = []
        # Prepopulate an output for each input
        batch_logits = [{} for _ in range(len(batch))]

        # Handle heterogeneous style inputs on a per-item level
        for batch_idx, item in enumerate(batch):
            in_tokens, out_tokens = self.tokenize_inputs(item)
            if in_tokens is not None:
                valid_batch_indexes.append(batch_idx)
                batch_in_tokens.append(in_tokens)
                batch_out_tokens.append(out_tokens)

        # Some batch items might not be valid
        valid_batch_size = len(valid_batch_indexes)
        if not valid_batch_size:
            # No inputs were valid
            return batch_logits

        # Pad everything into a batch to be more efficient. Masks (True =
        # real token) come from the true sequence lengths instead of from a
        # comparison against a sentinel padding value. A sentinel comparison
        # would wrongly drop any real activation that happened to be <= the
        # sentinel.
        input_seqs = nn.utils.rnn.pad_sequence(
            batch_in_tokens, batch_first=True, padding_value=0.0)
        output_seqs = nn.utils.rnn.pad_sequence(
            batch_out_tokens, batch_first=True, padding_value=0.0)
        input_masks = self._length_mask(
            batch_in_tokens, input_seqs.shape[1], input_seqs.device)
        output_masks = self._length_mask(
            batch_out_tokens, output_seqs.shape[1], output_seqs.device)

        # Key padding masks are inverted: True marks positions to ignore.
        decoded = self.backbone(
            src=input_seqs,
            tgt=output_seqs,
            src_key_padding_mask=~input_masks,
            tgt_key_padding_mask=~output_masks,
        )

        B = valid_batch_size
        # Note output h/w is hardcoded here and uses the fact that the mwe only
        # has one task; could be generalized.
        oh, ow = 3, 3
        decoded_features = decoded.view(B, -1, oh, ow, self.d_model)
        decoded_masks = output_masks.view(B, -1, oh, ow)

        # Reconstruct outputs corresponding to the inputs
        for batch_idx, feat, mask in zip(valid_batch_indexes, decoded_features, decoded_masks):
            item_feat = feat[mask].view(-1, oh, ow, self.d_model).permute(0, 3, 1, 2)
            item_logits = batch_logits[batch_idx]
            for head_name, head_layer in self.heads.items():
                item_logits[head_name] = head_layer(item_feat)
        return batch_logits

    def forward_step(self, batch: List[Dict], with_loss=False, stage='unspecified'):
        """
        Generic forward step used for test / train / validation.

        Returns:
            Dict: ``logits`` (per-item head logits) and, when ``with_loss`` is
            True, ``loss`` (sum of per-head dummy MSE losses, or None when no
            item produced logits).
        """
        batch_logits: List[Dict] = self.forward(batch)
        outputs = {'logits': batch_logits}
        if with_loss:
            losses = []
            valid_batch_size = 0
            for item, item_logits in zip(batch, batch_logits):
                if item_logits:
                    valid_batch_size += 1
                for head_name, head_logits in item_logits.items():
                    head_target = torch.stack([
                        label['data'] for label in item['labels']
                        if label['head'] == head_name
                    ], dim=0)
                    # dummy loss function
                    head_loss = torch.nn.functional.mse_loss(head_logits, head_target)
                    losses.append(head_loss)
            total_loss = sum(losses) if losses else None
            if total_loss is not None:
                self.log(f'{stage}_loss', total_loss, prog_bar=True, batch_size=valid_batch_size, sync_dist=True)
            # Always set (possibly None) so training_step can test for it.
            outputs['loss'] = total_loss
        return outputs

    def training_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='train')
        if outputs['loss'] is None:
            return None
        return outputs

    def validation_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='val')
        return outputs

    def test_step(self, batch, batch_idx=None):
        outputs = self.forward_step(batch, with_loss=True, stage='test')
        return outputs
class MWE_Dataset(Dataset):
    """
    A dataset that produces heterogeneous outputs.

    Each item is a dict of variable-length lists (inputs, outputs, labels).
    For that reason the loader from ``make_loader`` does not collate: a batch
    is a plain list of item dicts.

    Example:
        >>> self = MWE_Dataset()
        >>> self[0]
    """

    def __init__(self, max_items_per_epoch=100):
        super().__init__()
        self.max_items_per_epoch = max_items_per_epoch
        # NOTE: this is numpy's global random module, not a dedicated
        # generator, so its state is shared with any other np.random user.
        self.rng = np.random
        self.dataset_stats = {
            'known_modalities': [
                {'sensor': 'sensor1', 'channels': 'rgb', 'num_bands': 3, 'dims': (23, 23)},
            ],
            'known_tasks': [
                {'name': 'class', 'classes': ['a', 'b', 'c', 'd', 'e'], 'dims': (3, 3)},
            ]
        }

    def __len__(self):
        return self.max_items_per_epoch

    def __getitem__(self, index) -> Dict:
        """
        Returns:
            Dict: containing
                * inputs - a list of observations
                * outputs - a list of what we want to predict
                * labels - ground truth if we have it

        Note:
            ``index`` is not used; every call draws fresh random data.
        """
        inputs = []
        outputs = []
        labels = []
        max_timesteps_per_item = 5
        num_frames = max_timesteps_per_item
        # Probability of dropping an input; 0 keeps every input.
        p_drop_input = 0
        for frame_index in range(num_frames):
            had_input = 0
            # In general we may have any number of observations per frame
            for modality in self.dataset_stats['known_modalities']:
                sensor = modality['sensor']
                channels = modality['channels']
                c = modality['num_bands']
                h, w = modality['dims']
                # Randomly include each sensorchan on each frame
                if self.rng.rand() >= p_drop_input:
                    had_input = 1
                    inputs.append({
                        'type': 'input',
                        'channel_code': channels,
                        'sensor_code': sensor,
                        'frame_index': frame_index,
                        'data': torch.rand(c, h, w),
                    })
            # Only request predictions / labels for frames that had an input.
            if had_input:
                for task_info in self.dataset_stats['known_tasks']:
                    task = task_info['name']
                    oh, ow = task_info['dims']
                    oc = len(task_info['classes'])
                    outputs.append({
                        'type': 'output',
                        'head': task,
                        'frame_index': frame_index,
                        'dims': (oh, ow),
                    })
                    labels.append({
                        'type': 'label',
                        'head': task,
                        'frame_index': frame_index,
                        'data': torch.rand(oc, oh, ow),
                    })
        item = {
            'inputs': inputs,
            'outputs': outputs,
            'labels': labels,
        }
        return item

    def make_loader(self, batch_size=1, num_workers=0, shuffle=False,
                    pin_memory=False):
        """
        Create a dataloader option with sensible defaults for the problem.

        Items are not collated: each batch is a list of item dicts. The
        builtin ``list`` is used as the collate function instead of a
        ``lambda`` because it is picklable. Worker processes need a picklable
        collate function when ``num_workers > 0`` under the ``spawn`` start
        method, which is the default on macOS and Windows.
        """
        loader = torch.utils.data.DataLoader(
            self, batch_size=batch_size, num_workers=num_workers,
            shuffle=shuffle, pin_memory=pin_memory,
            collate_fn=list,
        )
        return loader
class MWE_Datamodule(pl.LightningDataModule):
    """
    Datamodule owning the train / vali / test ``MWE_Dataset`` instances.

    ``dataset_stats`` is None until ``setup`` has run, after which it holds
    the train dataset's stats.
    """

    def __init__(self, batch_size=1, num_workers=0, max_items_per_epoch=100):
        super().__init__()
        self.save_hyperparameters()
        self.torch_datasets = {}
        self.dataset_stats = None
        self.dataset_kwargs = dict(max_items_per_epoch=max_items_per_epoch)
        # setup() can be forced early by the CLI link hook and presumably
        # called again later by the trainer; this flag makes repeats no-ops.
        self._did_setup = False

    def setup(self, stage):
        """Build the three datasets on the first call; later calls do nothing."""
        if not self._did_setup:
            for split_name in ('train', 'test', 'vali'):
                self.torch_datasets[split_name] = MWE_Dataset(**self.dataset_kwargs)
            self.dataset_stats = self.torch_datasets['train'].dataset_stats
            self._did_setup = True
            print('Setup MWE_Datamodule')
            print(self.__dict__)

    def train_dataloader(self):
        return self._make_dataloader('train', shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader('vali')

    def test_dataloader(self):
        return self._make_dataloader('test')

    @property
    def train_dataset(self):
        return self.torch_datasets.get('train')

    @property
    def test_dataset(self):
        return self.torch_datasets.get('test')

    @property
    def vali_dataset(self):
        return self.torch_datasets.get('vali')

    def _make_dataloader(self, stage, shuffle=False):
        """Loader for split ``stage`` using the saved batch size / worker count."""
        hparams = self.hparams
        dataset = self.torch_datasets[stage]
        return dataset.make_loader(
            batch_size=hparams.batch_size,
            num_workers=hparams.num_workers,
            shuffle=shuffle,
            pin_memory=True,
        )
class MWE_LightningCLI(LightningCLI):
    """
    Customized LightningCLI to ensure the expected model inputs / outputs are
    coupled with what the dataset is able to provide.
    """

    def add_arguments_to_parser(self, parser):
        """Link the datamodule's ``dataset_stats`` into the model, then defer to the base class."""

        def _dataset_stats_of(datamodule):
            # HACK: the stats only exist after setup(), so force it here
            # before the linked value is read.
            if not datamodule._did_setup:
                datamodule.setup('fit')
            return datamodule.dataset_stats

        # apply_on="instantiate": the compute function receives the
        # instantiated datamodule object, not its parsed config.
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            apply_on="instantiate",
            compute_fn=_dataset_stats_of,
        )
        super().add_arguments_to_parser(parser)
def main():
    """Entry point: build the customized CLI around the MWE model and datamodule."""
    cli_kwargs = {
        'model_class': MWE_Model,
        'datamodule_class': MWE_Datamodule,
    }
    MWE_LightningCLI(**cli_kwargs)
if __name__ == '__main__':
    # CommandLine:
    #     cd ~/code/geowatch/dev/mwe/
    main()
Apologies for the length of the MWE; it could probably be a few hundred lines shorter, but I had it on hand and it demonstrates the issue well enough. The `link_arguments` call and the model's `__init__` are the important parts:
class MWE_LightningCLI(LightningCLI):
    """
    Customized LightningCLI to ensure the expected model inputs / outputs are
    coupled with what the dataset is able to provide.
    """

    def add_arguments_to_parser(self, parser):
        """Link the datamodule's ``dataset_stats`` into the model, then defer to the base class."""

        def _dataset_stats_of(datamodule):
            # HACK: the stats only exist after setup(), so force it here
            # before the linked value is read.
            if not datamodule._did_setup:
                datamodule.setup('fit')
            return datamodule.dataset_stats

        # apply_on="instantiate": the compute function receives the
        # instantiated datamodule object, not its parsed config.
        parser.link_arguments(
            "data",
            "model.dataset_stats",
            apply_on="instantiate",
            compute_fn=_dataset_stats_of,
        )
        super().add_arguments_to_parser(parser)
class MWE_Model(pl.LightningModule):
    # Excerpt of the full model above: only the start of __init__ is shown.
    def __init__(self, sorting=False, dataset_stats=None, d_model=16):
        super().__init__()
        # dataset_stats arrives via the CLI link; in the reported versions it
        # is missing from the hyperparameters that this call ends up saving.
        self.save_hyperparameters()
        ...
With the above script saved as `lightning_cli_save_hyperaparams_error_on_link_args.py`, I invoke it as:
DEFAULT_ROOT_DIR=./mwe_train_dir
# The inline YAML must keep its nesting: every setting belongs under its
# section key (model / data / optimizer / trainer).
python lightning_cli_save_hyperaparams_error_on_link_args.py fit --config "
model:
  sorting: True
data:
  num_workers: 8
  batch_size: 2
  max_items_per_epoch: 200
optimizer:
  class_path: torch.optim.Adam
  init_args:
    lr: 1e-7
trainer:
  default_root_dir : $DEFAULT_ROOT_DIR
  accelerator : gpu
  devices : 1
  max_epochs: 100
"
# Pick the last checkpoint / hparams file in sorted path order so the choice
# is deterministic (glob order alone is not).
CKPT_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/checkpoints/*.ckpt'))[-1])")
HPARAM_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'))[-1])")
cat "$HPARAM_FPATH"
Error messages and logs
When using pytorch_lightning 2.2.5, running:
HPARAM_FPATH=$(python -c "import pathlib; print(sorted(pathlib.Path('$DEFAULT_ROOT_DIR/lightning_logs').glob('*/hparams.yaml'))[-1])")
cat "$HPARAM_FPATH"
correctly prints hyperparameters that include the linked `dataset_stats` argument:
sorting: true
dataset_stats:
  known_modalities:
  - sensor: sensor1
    channels: rgb
    num_bands: 3
    dims:
    - 23
    - 23
  known_tasks:
  - name: class
    classes:
    - a
    - b
    - c
    - d
    - e
    dims:
    - 3
    - 3
d_model: 16
batch_size: 2
num_workers: 8
max_items_per_epoch: 200
But on 2.4.0 and the latest master branch it incorrectly prints the following (note that `dataset_stats` is missing):
sorting: true
d_model: 16
_instantiator: pytorch_lightning.cli.instantiate_module
batch_size: 2
num_workers: 8
max_items_per_epoch: 200
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA GeForce RTX 3090
- NVIDIA GeForce RTX 3090
- available: True
- version: 12.4
- GPU:
- Lightning:
- lightning: 2.4.0
- lightning-utilities: 0.11.2
- perceiver-pytorch: 0.8.3
- performer-pytorch: 1.0.11
- pytorch-lightning: 2.4.0
- pytorch-msssim: 0.1.5
- pytorch-ranger: 0.1.1
- reformer-pytorch: 1.4.3
- torch: 2.4.0+cu124
- torch-liberator: 0.2.2
- torch-optimizer: 0.1.0
- torchaudio: 2.4.0+cu124
- torchmetrics: 0.11.0
- torchvision: 0.19.0
- Packages:
- absl-py: 1.4.0
- accelerate: 0.30.1
- addict: 2.4.0
- affine: 2.3.0
- aiobotocore: 2.5.4
- aiohttp: 3.9.5
- aiohttp-retry: 2.8.3
- aioitertools: 0.11.0
- aiosignal: 1.3.1
- alabaster: 0.7.16
- albumentations: 1.0.0
- amqp: 5.2.0
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.6.0
- anytree: 2.12.1
- appdirs: 1.4.4
- argcomplete: 3.5.0
- argo-workflows: 6.5.6
- arrow: 1.3.0
- asciitree: 0.3.3
- astor: 0.8.1
- astroid: 3.2.2
- asttokens: 2.4.1
- astunparse: 1.6.3
- asyncssh: 2.14.2
- atomicwrites: 1.4.0
- atpublic: 4.1.0
- attrs: 23.2.0
- auditwheel: 6.1.0
- autobahn: 24.4.2
- autodocsumm: 0.2.13
- automat: 22.10.0
- autopep8: 2.0.0
- axial-positional-embedding: 0.2.1
- babel: 2.15.0
- backports.tarfile: 1.2.0
- baron: 0.10.1
- bashlex: 0.18
- bcrypt: 4.1.3
- beautifulsoup4: 4.12.3
- bidict: 0.23.1
- billiard: 4.2.0
- black: 24.4.2
- blake3: 0.3.1
- bleach: 6.1.0
- blinker: 1.8.2
- boto: 2.49.0
- boto3: 1.28.17
- botocore: 1.31.17
- bpytop: 1.0.68
- bracex: 2.4
- brotli: 1.1.0
- build: 1.2.2
- cachecontrol: 0.14.0
- cachetools: 5.4.0
- celery: 5.4.0
- certifi: 2024.2.2
- cffi: 1.16.0
- cfgv: 3.4.0
- chardet: 5.2.0
- charset-normalizer: 2.0.12
- chromecontroller: 0.3.26
- cibuildwheel: 2.21.0
- cleo: 2.1.0
- click: 8.1.7
- click-didyoumean: 0.3.1
- click-plugins: 1.1.1
- click-repl: 0.3.0
- cligj: 0.7.2
- cloudpickle: 3.0.0
- cmake: 3.29.3
- cmd-queue: 0.1.21
- codecarbon: 2.2.4
- colorama: 0.4.6
- colormath: 3.0.0
- colt5-attention: 0.10.20
- comm: 0.2.2
- commonmark: 0.9.1
- configargparse: 1.7
- configobj: 5.0.8
- constantly: 23.10.4
- contourpy: 1.2.1
- coverage: 7.4.3
- crashtest: 0.4.1
- cryptography: 42.0.7
- cssutils: 2.10.2
- cycler: 0.12.1
- cython: 0.29.34
- dask: 2023.8.1
- dataframe-image: 0.1.13
- dataproperty: 1.0.1
- dbus-python: 1.3.2
- debugpy: 1.8.2
- decorator: 5.1.1
- defusedxml: 0.7.1
- delayed-image: 0.3.2
- delorean: 1.0.0
- detectron2: 0.6
- diceware: 0.10
- dictdiffer: 0.9.0
- diskcache: 5.6.3
- distinctipy: 1.2.1
- distlib: 0.3.8
- distro: 1.9.0
- docopt: 0.6.2
- docstring-parser: 0.16
- docutils: 0.20.1
- dominate: 2.9.1
- dpath: 2.1.6
- dtool-ibeis: 1.1.2
- dulwich: 0.22.1
- dvc: 3.51.2
- dvc-data: 3.15.1
- dvc-http: 2.32.0
- dvc-objects: 5.1.0
- dvc-render: 1.0.2
- dvc-s3: 3.2.0
- dvc-ssh: 4.1.1
- dvc-studio-client: 0.20.0
- dvc-task: 0.4.0
- einops: 0.6.0
- entrypoints: 0.4
- et-xmlfile: 1.1.0
- executing: 2.0.1
- faiss-cpu: 1.8.0
- fasteners: 0.17.3
- fastjsonschema: 2.19.1
- filelock: 3.15.4
- filterpy: 1.4.5
- fiona: 1.8.22
- fire: 0.4.0
- flake8: 7.0.0
- flask: 3.0.3
- flask-basicauth: 0.2.0
- flask-cors: 3.0.10
- flask-socketio: 5.3.6
- flatten-dict: 0.4.2
- flexcache: 0.3
- flexparser: 0.3.1
- flufl.lock: 7.1.1
- fonttools: 4.51.0
- frozenlist: 1.4.1
- fsspec: 2024.6.0
- funcy: 2.0
- futures-actors: 0.0.5
- fuzzywuzzy: 0.18.0
- fvcore: 0.1.5.post20221221
- gdal: 3.5.2
- geodatasets: 2023.12.0
- geographiclib: 2.0
- geojson: 3.0.1
- geomet: 1.1.0
- geopandas: 0.14.4
- geopy: 2.4.1
- geowatch: 0.18.4
- gevent: 24.2.1
- girder-client: 3.2.4.dev30+gcacd0e706
- git-of-theseus: 0.3.4
- git-python: 1.0.3
- git-well: 0.2.1
- gitdb: 4.0.11
- gitpython: 3.1.43
- google-api-core: 2.19.0
- google-api-python-client: 2.130.0
- google-auth: 2.29.0
- google-auth-httplib2: 0.2.0
- google-auth-oauthlib: 1.0.0
- googleapis-common-protos: 1.63.0
- grandalf: 0.8
- graphid: 0.1.0
- greenlet: 3.0.3
- grpcio: 1.63.0
- gto: 1.7.1
- guitool-ibeis: 2.2.0
- h11: 0.14.0
- h3: 3.7.7
- hardware: 0.31.0
- hkdf: 0.0.3
- html2image: 2.0.4.3
- httpcore: 0.16.3
- httplib2: 0.22.0
- httpx: 0.23.3
- huggingface-hub: 0.23.0
- humanize: 4.8.0
- hydra-core: 1.3.2
- hyperlink: 21.0.0
- ibeis: 2.3.2
- identify: 2.6.0
- idna: 3.7
- ijson: 3.2.1
- imageio: 2.34.1
- imagesize: 1.4.1
- importlib-metadata: 7.2.1
- importlib-resources: 6.4.0
- incremental: 24.7.2
- iniconfig: 2.0.0
- installer: 0.7.0
- instant-rst: 0.9.9.1
- iopath: 0.1.9
- ipykernel: 6.29.5
- ipython: 8.18.1
- ipython-genutils: 0.2.0
- isort: 5.13.2
- iterable-io: 1.0.0
- iterative-telemetry: 0.0.8
- itk: 5.4.0
- itk-core: 5.4.0
- itk-filtering: 5.4.0
- itk-io: 5.4.0
- itk-numerics: 5.4.0
- itk-registration: 5.4.0
- itk-segmentation: 5.4.0
- itsdangerous: 2.2.0
- jaraco.classes: 3.4.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.2
- jedi: 0.19.1
- jeepney: 0.8.0
- jellyfin-apiclient-python: 1.9.2
- jellyfin-migrator: 0.0.0
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- johnnydep: 1.20.4
- jq: 1.7.0
- jsonargparse: 4.32.1
- jsonnet: 0.20.0
- jsonpath: 0.82.2
- jsonschema: 4.19.2
- jsonschema-specifications: 2023.12.1
- jupyter-client: 8.6.1
- jupyter-core: 5.7.2
- jupyterlab-pygments: 0.3.0
- kafka-python: 2.0.2
- keyring: 24.3.1
- kiwisolver: 1.4.5
- kombu: 5.3.7
- kornia: 0.6.8
- kornia-rs: 0.1.3
- kubernetes: 29.0.0
- kwalop: 0.1.0
- kwarray: 0.6.19
- kwcoco: 0.8.5
- kwcoco-explorer: 0.0.1
- kwgis: 0.1.1
- kwimage: 0.10.1
- kwimage-ext: 0.2.1
- kwplot: 0.5.2
- kwutil: 0.3.3
- lark: 1.1.7
- lark-cython: 0.0.15
- lazy-loader: 0.3
- levenshtein: 0.25.1
- liberator: 0.1.0
- lightning: 2.4.0
- lightning-utilities: 0.11.2
- line-profiler: 4.1.3
- linkify-it-py: 2.0.3
- lit: 18.1.4
- livereload: 2.7.0
- llvmlite: 0.42.0
- local-attention: 1.9.1
- locket: 1.0.0
- lockfile: 0.12.2
- logmatic-python: 0.1.7
- lxml: 4.9.2
- magic-wormhole: 0.14.0
- markdown: 3.6
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- mathutf: 0.1.0
- matplotlib: 3.8.2
- matplotlib-inline: 0.1.7
- maturin: 1.7.4
- mbstrdecoder: 1.1.3
- mccabe: 0.7.0
- mdit-py-plugins: 0.4.1
- mdurl: 0.1.2
- mgrs: 1.4.6
- mistune: 3.0.2
- mkinit: 1.1.0
- mmcv: 2.0.0
- mmengine: 0.10.4
- monai: 0.8.0
- more-itertools: 8.12.0
- mpmath: 1.3.0
- msgpack: 1.0.8
- multidict: 6.0.5
- munch: 4.0.0
- mutagen: 1.47.0
- mypy: 1.10.0
- mypy-extensions: 1.0.0
- myst-parser: 3.0.1
- nbclient: 0.10.0
- nbconvert: 7.16.4
- nbformat: 5.10.4
- ndsampler: 0.7.9
- nest-asyncio: 1.6.0
- netharn: 0.6.2
- networkx: 3.3
- networkx-algo-common-subtree: 0.2.1
- nh3: 0.2.18
- nodeenv: 1.9.1
- nrtk: 0.11.0
- nrtk-explorer: 0.3.0
- numba: 0.59.1
- numcodecs: 0.13.0
- numexpr: 2.8.4
- numpy: 1.25.2
- nvidia-cublas-cu11: 11.10.3.66
- nvidia-cublas-cu12: 12.4.2.65
- nvidia-cuda-cupti-cu11: 11.7.101
- nvidia-cuda-cupti-cu12: 12.4.99
- nvidia-cuda-nvrtc-cu11: 11.7.99
- nvidia-cuda-nvrtc-cu12: 12.4.99
- nvidia-cuda-runtime-cu11: 11.7.99
- nvidia-cuda-runtime-cu12: 12.4.99
- nvidia-cudnn-cu11: 8.5.0.96
- nvidia-cudnn-cu12: 9.1.0.70
- nvidia-cufft-cu11: 10.9.0.58
- nvidia-cufft-cu12: 11.2.0.44
- nvidia-curand-cu11: 10.2.10.91
- nvidia-curand-cu12: 10.3.5.119
- nvidia-cusolver-cu11: 11.4.0.1
- nvidia-cusolver-cu12: 11.6.0.99
- nvidia-cusparse-cu11: 11.7.4.91
- nvidia-cusparse-cu12: 12.3.0.142
- nvidia-nccl-cu11: 2.14.3
- nvidia-nccl-cu12: 2.20.5
- nvidia-nvjitlink-cu12: 12.4.99
- nvidia-nvtx-cu11: 11.7.91
- nvidia-nvtx-cu12: 12.4.99
- oauthlib: 3.2.2
- omegaconf: 2.3.0
- openapi-python-client: 0.20.0
- openapi-python-generator: 0.5.0
- openapi-schema-pydantic: 1.2.4
- opencv-python-headless: 4.10.0.84
- openpyxl: 3.0.9
- opentimestamps: 0.4.5
- opentimestamps-client: 0.7.1
- ordered-set: 4.1.0
- orjson: 3.10.3
- osmnx: 1.9.4
- oyaml: 1.0
- packaging: 24.1
- pandas: 1.5.3
- pandocfilters: 1.5.1
- parse: 1.19.0
- parso: 0.8.4
- partd: 1.4.2
- pathspec: 0.12.1
- pathvalidate: 3.2.1
- patsy: 0.5.6
- pbr: 6.0.0
- pendulum: 3.0.0
- perceiver-pytorch: 0.8.3
- performer-pytorch: 1.0.11
- pexpect: 4.9.0
- pillow: 10.3.0
- pint: 0.24.3
- pip: 24.2
- pkginfo: 1.10.0
- platformdirs: 3.11.0
- plotly: 5.24.0
- plottool-ibeis: 2.3.0
- pls-dont-shadow-me: 1.0.0
- pluggy: 1.5.0
- pockets: 0.9.1
- poetry: 1.8.3
- poetry-core: 1.9.0
- poetry-plugin-export: 1.8.0
- pooch: 1.8.2
- portalocker: 2.10.1
- portion: 2.4.1
- pre-commit: 3.8.0
- prettytable: 3.11.0
- product-key-memory: 0.2.2
- progiter: 2.0.0
- prometheus-client: 0.20.0
- prompt-toolkit: 3.0.43
- proto-plus: 1.23.0
- protobuf: 4.25.3
- psutil: 5.9.6
- psycopg2-binary: 2.9.5
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- purepy-root-demo-pkg-lzcsutvo: 1.0.0
- purepy-src-demo-pkg: 1.0.0
- purepy-src-demo-pkg-dbrmcjpb: 1.0.0
- purepy-src-demo-pkg-lzcsutvo: 1.0.0
- py-cpuinfo: 9.0.0
- pyasn1: 0.6.0
- pyasn1-modules: 0.4.0
- pybsm: 0.5.1
- pycocotools: 2.0.7
- pycodestyle: 2.11.1
- pycparser: 2.22
- pycryptodomex: 3.20.0
- pydantic: 2.7.1
- pydantic-core: 2.18.2
- pydot: 2.0.0
- pyelftools: 0.31
- pyfiglet: 1.0.2
- pyflakes: 3.2.0
- pyflann-ibeis: 2.4.0
- pygame: 2.6.0
- pygit2: 1.15.0
- pygments: 2.18.0
- pygraphviz: 1.13
- pygtrie: 2.5.0
- pyhesaff: 2.1.1
- pylatex: 0+untagged.769.gb48e8ec
- pylatexenc: 3.0a29
- pymongo: 3.13.0
- pynacl: 1.5.0
- pynmea2: 1.19.0
- pynndescent: 0.5.12
- pynvim: 0.5.0
- pynvml: 11.5.0
- pyo3-example: 0.1.0
- pyopenssl: 24.1.0
- pyparsing: 3.1.2
- pyperclip: 1.8.2
- pypistats: 1.6.0
- pypng: 0.20220715.0
- pypogo: 0.1.0
- pyproj: 3.4.1
- pyproject-api: 1.7.1
- pyproject-hooks: 1.1.0
- pyqrcode: 1.2.1
- pyqt5: 5.15.10
- pyqt5-qt5: 5.15.2
- pyqt5-sip: 12.13.0
- pyqtree: 1.0.0
- pysocks: 1.7.1
- pystac: 1.10.1
- pystac-client: 0.8.1
- pytablewriter: 1.2.0
- pytest: 8.0.2
- pytest-cov: 5.0.0
- pytest-subtests: 0.13.1
- python-bitcoinlib: 0.12.2
- python-dateutil: 2.9.0.post1.dev3+g9eaa5de
- python-engineio: 4.9.1
- python-gitlab: 4.6.0
- python-json-logger: 2.0.7
- python-levenshtein: 0.25.1
- python-slugify: 8.0.4
- python-socketio: 5.11.3
- pytimeparse: 1.1.8
- pytorch-lightning: 2.4.0
- pytorch-msssim: 0.1.5
- pytorch-ranger: 0.1.1
- pytz: 2024.1
- pywavelets: 1.6.0
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- quantities: 0.15.0
- rapidfuzz: 3.9.1
- rasterio: 1.3.10
- readme-renderer: 44.0
- reconplogger: 4.16.1
- redbaron: 0.9.2
- referencing: 0.35.1
- reformer-pytorch: 1.4.3
- regex: 2024.5.10
- requests: 2.32.2
- requests-oauthlib: 2.0.0
- requests-toolbelt: 1.0.0
- responses: 0.25.3
- rfc3986: 1.5.0
- rgd-client: 0.2.7
- rgd-imagery-client: 0.2.7
- rich: 12.5.1
- rich-argparse: 1.1.0
- rpds-py: 0.18.1
- rply: 0.7.8
- rsa: 4.9
- rtree: 1.0.1
- ruamel.yaml: 0.17.32
- ruamel.yaml.clib: 0.2.8
- ruff: 0.4.5
- ruyaml: 0.91.0
- s3fs: 2024.6.0
- s3transfer: 0.6.2
- s5cmd: 0.2.0
- safer: 4.12.3
- safetensors: 0.4.3
- scikit-build: 0.17.6
- scikit-image: 0.21.0
- scikit-learn: 1.5.1
- scipy: 1.14.0
- scmrepo: 3.3.5
- scriptconfig: 0.7.16
- seaborn: 0.13.2
- secretstorage: 3.3.3
- semver: 3.0.2
- service-identity: 24.1.0
- setuptools: 67.7.2
- shapely: 2.0.1
- shellingham: 1.5.4
- shitspotter: 0.0.1
- shortuuid: 1.0.13
- shtab: 1.7.1
- simple-dvc: 0.2.2
- simple-websocket: 1.0.0
- simpleitk: 2.3.1
- simplejson: 3.19.2
- simplekml: 1.3.3
- six: 1.16.0
- smartflow: 3.1.3
- smmap: 5.0.1
- smqtk-classifier: 0.19.0
- smqtk-core: 0.19.0
- smqtk-dataprovider: 0.18.0
- smqtk-descriptors: 0.19.0
- smqtk-detection: 0.20.1
- smqtk-image-io: 0.17.1
- smqtk-indexing: 0.18.0
- smqtk-iqr: 0.15.1
- smqtk-relevancy: 0.17.0
- sniffio: 1.3.1
- snowballstemmer: 2.2.0
- snuggs: 1.4.7
- sortedcontainers: 2.4.0
- soupsieve: 2.5
- spake2: 0.8
- sphinx: 7.3.7
- sphinx-autoapi: 3.1.1
- sphinx-autobuild: 2024.4.16
- sphinx-autodoc-typehints: 2.3.0
- sphinx-reredirects: 0.1.3
- sphinx-rtd-theme: 2.0.0
- sphinxcontrib-applehelp: 1.0.8
- sphinxcontrib-devhelp: 1.0.6
- sphinxcontrib-htmlhelp: 2.0.5
- sphinxcontrib-jquery: 4.1
- sphinxcontrib-jsmath: 1.0.1
- sphinxcontrib-napoleon: 0.7
- sphinxcontrib-qthelp: 1.0.7
- sphinxcontrib-serializinghtml: 1.1.10
- sqlalchemy: 1.4.50
- sqlalchemy-utils: 0.41.2
- sqltrie: 0.11.0
- sshfs: 2024.4.1
- stack-data: 0.6.3
- starlette: 0.37.2
- statsmodels: 0.14.2
- structlog: 24.2.0
- sympy: 1.12
- tabledata: 1.3.3
- tabulate: 0.9.0
- tcolorpy: 0.1.6
- tempenv: 0.2.0
- tenacity: 9.0.0
- tensorboard: 2.14.0
- tensorboard-data-server: 0.7.2
- tensorrt-bindings: 8.6.1
- tensorrt-cu12: 10.0.1
- tensorrt-cu12-bindings: 10.0.1
- tensorrt-cu12-libs: 10.0.1
- tensorrt-libs: 8.6.1
- termcolor: 2.4.0
- text-unidecode: 1.3
- textual: 0.1.18
- threadpoolctl: 3.5.0
- tifffile: 2024.5.22
- timerit: 1.1.0
- timezonefinder: 6.5.2
- timm: 0.6.13
- tinycss2: 1.3.0
- tokenizers: 0.15.2
- toml: 0.10.2
- tomli: 2.0.1
- tomlkit: 0.12.5
- toolz: 0.12.1
- torch: 2.4.0+cu124
- torch-liberator: 0.2.2
- torch-optimizer: 0.1.0
- torchaudio: 2.4.0+cu124
- torchmetrics: 0.11.0
- torchvision: 0.19.0
- tornado: 6.4
- tox: 4.17.1
- tqdm: 4.64.1
- traitlets: 5.14.3
- trame: 3.6.1
- trame-client: 3.0.3
- trame-plotly: 3.0.2
- trame-quasar: 0.2.1
- trame-server: 3.0.1
- trame-vuetify: 2.7.0
- transformers: 4.37.2
- triton: 3.0.0
- trove-classifiers: 2024.9.12
- twine: 5.1.1
- twisted: 24.3.0
- txaio: 23.1.1
- txtorcon: 23.11.0
- typepy: 1.3.2
- typer: 0.12.3
- types-python-dateutil: 2.9.0.20240316
- types-pyyaml: 6.0.12.20240808
- types-requests: 2.32.0.20240907
- types-setuptools: 70.0.0.20240524
- typeshed-client: 2.5.1
- typing-extensions: 4.11.0
- tzdata: 2024.1
- tzlocal: 5.2
- ubelt: 1.3.6
- uc-micro-py: 1.0.3
- ujson: 5.6.0
- umap-learn: 0.5.6
- uncertainties: 3.2.2
- uritemplate: 4.1.1
- uritools: 4.0.2
- urllib3: 1.26.20
- utm: 0.7.0
- utool: 2.2.0
- uv: 0.3.4
- uvicorn: 0.29.0
- validators: 0.28.1
- vimtk: 0.5.0
- vine: 5.1.0
- virtualenv: 20.26.3
- voluptuous: 0.14.2
- vtool-ibeis: 2.3.0
- vtool-ibeis-ext: 0.1.1
- watchfiles: 0.21.0
- wcmatch: 8.5.2
- wcwidth: 0.2.13
- webencodings: 0.5.1
- websocket-client: 1.8.0
- websockets: 12.0
- werkzeug: 3.0.4
- wheel: 0.40.0
- wimpy: 0.6
- wrapt: 1.14.1
- wslink: 2.0.4
- wsproto: 1.2.0
- xarray: 0.17.0
- xcookie: 0.2.2
- xdev: 1.5.2
- xdoctest: 1.1.5
- xinspect: 0.2.0
- xmltodict: 0.12.0
- xxhash: 3.4.1
- yacs: 0.1.8
- yapf: 0.40.2
- yarl: 1.9.4
- yt-dlp: 2024.8.6
- zarr: 2.18.2
- zc.lockfile: 3.0.post1
- zipp: 3.18.1
- zipstream-ng: 1.7.1
- zope.event: 5.0
- zope.interface: 6.4.post2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.11.9
- release: 6.8.0-45-generic
- version: #45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024
More info
No response