Commit 4ad94da
Add ruff for linting, remove flake8, remove isort, remove pylint (#94)
gsheni authored Mar 25, 2024
1 parent ed3645b commit 4ad94da
Showing 12 changed files with 318 additions and 243 deletions.
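
The ruff configuration that replaces the removed tools lives in the project's configuration file rather than in the Makefile; that file is among the 12 changed files but is not expanded in this view. As a rough sketch only (the pyproject.toml location and every value below are assumptions, not the commit's actual settings), a ruff setup standing in for flake8, isort, and pylint could look like this:

[tool.ruff]
# Assumed values; the commit's real line length and Python target are not shown.
line-length = 99
target-version = "py38"

[tool.ruff.lint]
# One rule family per replaced tool: pycodestyle ("E", "W") and pyflakes ("F")
# cover flake8, "I" covers isort, "PL" covers pylint, and "D" (pydocstyle)
# covers the docstring checks flake8 ran before.
select = ["E", "W", "F", "I", "PL", "D"]

[tool.ruff.lint.per-file-ignores]
# Mirrors the old `flake8 --ignore=D tests` behavior.
"tests/**" = ["D"]

With a single configuration like this, `ruff check .` subsumes the old lint-deepecho and lint-tests targets, and `ruff format .` takes over the autoflake/autopep8/isort fix-lint pipeline, as the Makefile diff below shows.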
25 changes: 6 additions & 19 deletions Makefile
@@ -79,30 +79,17 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a

 # LINT TARGETS

-.PHONY: lint-deepecho
-lint-deepecho: ## check style with flake8 and isort
-	flake8 deepecho
-	isort -c --recursive deepecho
-	pylint deepecho --rcfile=setup.cfg
-
-.PHONY: lint-tests
-lint-tests: ## check style with flake8 and isort
-	flake8 --ignore=D tests
-	isort -c --recursive tests
-
 .PHONY: lint
-lint: ## Run all code style checks
-	invoke lint
+lint:
+	ruff check .
+	ruff format . --check

 .PHONY: fix-lint
-fix-lint: ## fix lint issues using autoflake, autopep8, and isort
-	find deepecho tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
-	autopep8 --in-place --recursive --aggressive deepecho tests
-	isort --apply --atomic --recursive deepecho tests
+fix-lint:
+	ruff check --fix .
+	ruff format .

 # TEST TARGETS

 .PHONY: test-unit
 test-unit: ## run unit tests using pytest
 	invoke unit
33 changes: 25 additions & 8 deletions deepecho/models/base.py
@@ -6,7 +6,7 @@
 from deepecho.sequences import assemble_sequences


-class DeepEcho():
+class DeepEcho:
     """The base class for DeepEcho models."""

     _verbose = True
@@ -28,7 +28,13 @@ def _validate(sequences, context_types, data_types):
             data_types:
                 See `fit`.
         """
-        dtypes = set(['continuous', 'categorical', 'ordinal', 'count', 'datetime'])
+        dtypes = set([
+            'continuous',
+            'categorical',
+            'ordinal',
+            'count',
+            'datetime',
+        ])
         assert all(dtype in dtypes for dtype in context_types)
         assert all(dtype in dtypes for dtype in data_types)
@@ -99,8 +105,15 @@ def _get_data_types(data, data_types, columns):

         return dtypes_list

-    def fit(self, data, entity_columns=None, context_columns=None,
-            data_types=None, segment_size=None, sequence_index=None):
+    def fit(
+        self,
+        data,
+        entity_columns=None,
+        context_columns=None,
+        data_types=None,
+        segment_size=None,
+        sequence_index=None,
+    ):
         """Fit the model to a dataframe containing time series data.

         Args:
@@ -135,8 +148,7 @@ def fit(self, data, entity_columns=None, context_columns=None,
         if segment_size is not None and not isinstance(segment_size, int):
             if sequence_index is None:
                 raise TypeError(
-                    '`segment_size` must be of type `int` if '
-                    'no `sequence_index` is given.'
+                    '`segment_size` must be of type `int` if ' 'no `sequence_index` is given.'
                 )
             if data[sequence_index].dtype.kind != 'M':
                 raise TypeError(
@@ -161,7 +173,12 @@
         data_types = self._get_data_types(data, data_types, self._data_columns)
         context_types = self._get_data_types(data, data_types, self._context_columns)
         sequences = assemble_sequences(
-            data, self._entity_columns, self._context_columns, segment_size, sequence_index)
+            data,
+            self._entity_columns,
+            self._context_columns,
+            segment_size,
+            sequence_index,
+        )

         # Validate and fit
         self._validate(sequences, context_types, data_types)
@@ -242,7 +259,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
             # Reformat as a DataFrame
             group = pd.DataFrame(
                 dict(zip(self._data_columns, sequence)),
-                columns=self._data_columns
+                columns=self._data_columns,
             )
             group[self._entity_columns] = entity_values
             for column, value in zip(self._context_columns, context_values):
38 changes: 23 additions & 15 deletions deepecho/models/basic_gan.py
@@ -13,10 +13,13 @@


 def _expand_context(data, context):
-    return torch.cat([
-        data,
-        context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1])
-    ], dim=2)
+    return torch.cat(
+        [
+            data,
+            context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1]),
+        ],
+        dim=2,
+    )


 class BasicGenerator(torch.nn.Module):
@@ -65,7 +68,7 @@ def forward(self, context=None, sequence_length=None):
         """
         latent = torch.randn(
             size=(sequence_length, context.size(0), self.latent_size),
-            device=self.device
+            device=self.device,
         )
         latent = _expand_context(latent, context)

@@ -150,8 +153,16 @@ class BasicGANModel(DeepEcho):
     _model_data_size = None
     _generator = None

-    def __init__(self, epochs=1024, latent_size=32, hidden_size=16,
-                 gen_lr=1e-3, dis_lr=1e-3, cuda=True, verbose=True):
+    def __init__(
+        self,
+        epochs=1024,
+        latent_size=32,
+        hidden_size=16,
+        gen_lr=1e-3,
+        dis_lr=1e-3,
+        cuda=True,
+        verbose=True,
+    ):
         self._epochs = epochs
         self._gen_lr = gen_lr
         self._dis_lr = dis_lr
@@ -211,7 +222,7 @@ def _index_map(columns, types):
                     'type': column_type,
                     'min': np.min(values),
                     'max': np.max(values),
-                    'indices': (dimensions, dimensions + 1)
+                    'indices': (dimensions, dimensions + 1),
                 }
                 dimensions += 2

@@ -221,10 +232,7 @@
                     indices[value] = dimensions
                     dimensions += 1

-                mapping[column] = {
-                    'type': column_type,
-                    'indices': indices
-                }
+                mapping[column] = {'type': column_type, 'indices': indices}

             else:
                 raise ValueError(f'Unsupported type: {column_type}')
@@ -317,7 +325,7 @@ def _value_to_tensor(self, tensor, value, properties):
             self._one_hot_encode(tensor, value, properties)

         else:
-            raise ValueError() # Theoretically unreachable
+            raise ValueError()  # Theoretically unreachable

     def _data_to_tensor(self, data):
         """Convert the input data to the corresponding tensor.
@@ -370,7 +378,7 @@ def _tensor_to_data(self, tensor):
             elif column_type in ('categorical', 'ordinal'):
                 value = self._one_hot_decode(tensor, row, properties)
             else:
-                raise ValueError() # Theoretically unreachable
+                raise ValueError()  # Theoretically unreachable

             column_data.append(value)

@@ -412,7 +420,7 @@ def _truncate(self, generated):
             end_flag = sequence[:, self._data_size]
             if (end_flag == 1.0).any():
                 cut_idx = end_flag.detach().cpu().numpy().argmax()
-                sequence[cut_idx + 1:] = 0.0
+                sequence[cut_idx + 1 :] = 0.0

     def _generate(self, context, sequence_length=None):
         generated = self._generator(
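
Nearly all of the Python changes in base.py and basic_gan.py are mechanical output of ruff format, a Black-compatible formatter: calls and signatures that exceed the line length are exploded to one argument per line with a trailing comma (the `fit` and `__init__` signatures), adjacent implicit string concatenations are joined onto a single line when they fit (the `segment_size` TypeError message), and slices whose bounds are compound expressions gain symmetric spacing around the colon (`sequence[cut_idx + 1 :]`).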
(9 more changed files not shown in this view.)
