Skip to content

Commit

Permalink
Merge pull request #93 from sdv-dev/add_ruff
Browse files Browse the repository at this point in the history
Add ruff for linting, remove flake8, remove isort, remove pylint
  • Loading branch information
gsheni authored Mar 25, 2024
2 parents ed3645b + e0b7ef2 commit 3380134
Show file tree
Hide file tree
Showing 13 changed files with 507 additions and 289 deletions.
15 changes: 5 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ install-develop: clean-build clean-pyc ## install the package in editable mode a

.PHONY: lint-deepecho
lint-deepecho: ## check style with flake8 and isort
flake8 deepecho
isort -c --recursive deepecho
pylint deepecho --rcfile=setup.cfg

.PHONY: lint-tests
lint-tests: ## check style with flake8 and isort
Expand All @@ -92,17 +89,15 @@ lint-tests: ## check style with flake8 and isort

.PHONY: lint
lint: ## Run all code style checks
invoke lint
ruff check .
ruff format . --check

.PHONY: fix-lint
fix-lint: ## fix lint issues using autoflake, autopep8, and isort
find deepecho tests -name '*.py' | xargs autoflake --in-place --remove-all-unused-imports --remove-unused-variables
autopep8 --in-place --recursive --aggressive deepecho tests
isort --apply --atomic --recursive deepecho tests

fix-lint: ## fix lint issues using ruff
ruff check --fix .
ruff format .

# TEST TARGETS

.PHONY: test-unit
test-unit: ## run unit tests using pytest
invoke unit
Expand Down
4 changes: 3 additions & 1 deletion deepecho/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@

def load_demo():
"""Load the demo DataFrame."""
return pd.read_csv(os.path.join(_DATA_PATH, 'demo.csv'), parse_dates=['date'])
return pd.read_csv(
os.path.join(_DATA_PATH, 'demo.csv'), parse_dates=['date']
)
46 changes: 36 additions & 10 deletions deepecho/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from deepecho.sequences import assemble_sequences


class DeepEcho():
class DeepEcho:
"""The base class for DeepEcho models."""

_verbose = True
Expand All @@ -28,7 +28,13 @@ def _validate(sequences, context_types, data_types):
data_types:
See `fit`.
"""
dtypes = set(['continuous', 'categorical', 'ordinal', 'count', 'datetime'])
dtypes = set([
'continuous',
'categorical',
'ordinal',
'count',
'datetime',
])
assert all(dtype in dtypes for dtype in context_types)
assert all(dtype in dtypes for dtype in data_types)

Expand Down Expand Up @@ -94,13 +100,22 @@ def _get_data_types(data, data_types, columns):
elif kind == 'M':
dtypes_list.append('datetime')
else:
error = f'Unsupported data_type for column {column}: {dtype}'
error = (
f'Unsupported data_type for column {column}: {dtype}'
)
raise ValueError(error)

return dtypes_list

def fit(self, data, entity_columns=None, context_columns=None,
data_types=None, segment_size=None, sequence_index=None):
def fit(
self,
data,
entity_columns=None,
context_columns=None,
data_types=None,
segment_size=None,
sequence_index=None,
):
"""Fit the model to a dataframe containing time series data.
Args:
Expand Down Expand Up @@ -131,7 +146,9 @@ def fit(self, data, entity_columns=None, context_columns=None,
such as integer values or datetimes.
"""
if not entity_columns and segment_size is None:
raise TypeError('If the data has no `entity_columns`, `segment_size` must be given.')
raise TypeError(
'If the data has no `entity_columns`, `segment_size` must be given.'
)
if segment_size is not None and not isinstance(segment_size, int):
if sequence_index is None:
raise TypeError(
Expand Down Expand Up @@ -159,9 +176,16 @@ def fit(self, data, entity_columns=None, context_columns=None,
self._data_columns.remove(sequence_index)

data_types = self._get_data_types(data, data_types, self._data_columns)
context_types = self._get_data_types(data, data_types, self._context_columns)
context_types = self._get_data_types(
data, data_types, self._context_columns
)
sequences = assemble_sequences(
data, self._entity_columns, self._context_columns, segment_size, sequence_index)
data,
self._entity_columns,
self._context_columns,
segment_size,
sequence_index,
)

# Validate and fit
self._validate(sequences, context_types, data_types)
Expand Down Expand Up @@ -212,7 +236,9 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
"""
if context is None:
if num_entities is None:
raise TypeError('Either context or num_entities must be not None')
raise TypeError(
'Either context or num_entities must be not None'
)

context = self._context_values.sample(num_entities, replace=True)
context = context.reset_index(drop=True)
Expand Down Expand Up @@ -242,7 +268,7 @@ def sample(self, num_entities=None, context=None, sequence_length=None):
# Reformat as a DataFrame
group = pd.DataFrame(
dict(zip(self._data_columns, sequence)),
columns=self._data_columns
columns=self._data_columns,
)
group[self._entity_columns] = entity_values
for column, value in zip(self._context_columns, context_values):
Expand Down
108 changes: 76 additions & 32 deletions deepecho/models/basic_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@


def _expand_context(data, context):
return torch.cat([
data,
context.unsqueeze(0).expand(data.shape[0], context.shape[0], context.shape[1])
], dim=2)
return torch.cat(
[
data,
context.unsqueeze(0).expand(
data.shape[0], context.shape[0], context.shape[1]
),
],
dim=2,
)


class BasicGenerator(torch.nn.Module):
Expand Down Expand Up @@ -47,7 +52,9 @@ class BasicGenerator(torch.nn.Module):
Device to which this Module is associated to.
"""

def __init__(self, context_size, latent_size, hidden_size, data_size, device):
def __init__(
self, context_size, latent_size, hidden_size, data_size, device
):
super().__init__()
self.latent_size = latent_size
self.rnn = torch.nn.GRU(context_size + latent_size, hidden_size)
Expand All @@ -65,7 +72,7 @@ def forward(self, context=None, sequence_length=None):
"""
latent = torch.randn(
size=(sequence_length, context.size(0), self.latent_size),
device=self.device
device=self.device,
)
latent = _expand_context(latent, context)

Expand Down Expand Up @@ -150,8 +157,16 @@ class BasicGANModel(DeepEcho):
_model_data_size = None
_generator = None

def __init__(self, epochs=1024, latent_size=32, hidden_size=16,
gen_lr=1e-3, dis_lr=1e-3, cuda=True, verbose=True):
def __init__(
self,
epochs=1024,
latent_size=32,
hidden_size=16,
gen_lr=1e-3,
dis_lr=1e-3,
cuda=True,
verbose=True,
):
self._epochs = epochs
self._gen_lr = gen_lr
self._dis_lr = dis_lr
Expand Down Expand Up @@ -211,7 +226,7 @@ def _index_map(columns, types):
'type': column_type,
'min': np.min(values),
'max': np.max(values),
'indices': (dimensions, dimensions + 1)
'indices': (dimensions, dimensions + 1),
}
dimensions += 2

Expand All @@ -221,10 +236,7 @@ def _index_map(columns, types):
indices[value] = dimensions
dimensions += 1

mapping[column] = {
'type': column_type,
'indices': indices
}
mapping[column] = {'type': column_type, 'indices': indices}

else:
raise ValueError(f'Unsupported type: {column_type}')
Expand All @@ -239,21 +251,31 @@ def _analyze_data(self, sequences, context_types, data_types):
- Index map and dimensions for the context.
- Index map and dimensions for the data.
"""
sequence_lengths = np.array([len(sequence['data'][0]) for sequence in sequences])
sequence_lengths = np.array([
len(sequence['data'][0]) for sequence in sequences
])
self._max_sequence_length = np.max(sequence_lengths)
self._fixed_length = (sequence_lengths == self._max_sequence_length).all()
self._fixed_length = (
sequence_lengths == self._max_sequence_length
).all()

# Concatenate all the context sequences together
context = []
for column in range(len(context_types)):
context.append([sequence['context'][column] for sequence in sequences])
context.append([
sequence['context'][column] for sequence in sequences
])

self._context_map, self._context_size = self._index_map(context, context_types)
self._context_map, self._context_size = self._index_map(
context, context_types
)

# Concatenate all the data sequences together
data = []
for column in range(len(data_types)):
data.append(sum([sequence['data'][column] for sequence in sequences], []))
data.append(
sum([sequence['data'][column] for sequence in sequences], [])
)

self._data_map, self._data_size = self._index_map(data, data_types)

Expand Down Expand Up @@ -317,7 +339,7 @@ def _value_to_tensor(self, tensor, value, properties):
self._one_hot_encode(tensor, value, properties)

else:
raise ValueError() # Theoretically unreachable
raise ValueError() # Theoretically unreachable

def _data_to_tensor(self, data):
"""Convert the input data to the corresponding tensor.
Expand Down Expand Up @@ -366,11 +388,13 @@ def _tensor_to_data(self, tensor):
for row in range(sequence_length):
if column_type in ('continuous', 'count'):
round_value = column_type == 'count'
value = self._denormalize(tensor, row, properties, round_value=round_value)
value = self._denormalize(
tensor, row, properties, round_value=round_value
)
elif column_type in ('categorical', 'ordinal'):
value = self._one_hot_decode(tensor, row, properties)
else:
raise ValueError() # Theoretically unreachable
raise ValueError() # Theoretically unreachable

column_data.append(value)

Expand All @@ -394,10 +418,14 @@ def _transform(self, data):
if column_type in ('continuous', 'count'):
value_idx, missing_idx = properties['indices']
data[:, :, value_idx] = torch.tanh(data[:, :, value_idx])
data[:, :, missing_idx] = torch.sigmoid(data[:, :, missing_idx])
data[:, :, missing_idx] = torch.sigmoid(
data[:, :, missing_idx]
)
elif column_type in ('categorical', 'ordinal'):
indices = list(properties['indices'].values())
data[:, :, indices] = torch.nn.functional.softmax(data[:, :, indices])
data[:, :, indices] = torch.nn.functional.softmax(
data[:, :, indices]
)

return data

Expand All @@ -412,7 +440,7 @@ def _truncate(self, generated):
end_flag = sequence[:, self._data_size]
if (end_flag == 1.0).any():
cut_idx = end_flag.detach().cpu().numpy().argmax()
sequence[cut_idx + 1:] = 0.0
sequence[cut_idx + 1 :] = 0.0

def _generate(self, context, sequence_length=None):
generated = self._generator(
Expand All @@ -426,7 +454,9 @@ def _generate(self, context, sequence_length=None):

return generated

def _discriminator_step(self, discriminator, discriminator_opt, data_context, context):
def _discriminator_step(
self, discriminator, discriminator_opt, data_context, context
):
real_scores = discriminator(data_context)

fake = self._generate(context)
Expand Down Expand Up @@ -470,8 +500,12 @@ def _build_fit_artifacts(self):
hidden_size=self._hidden_size,
).to(self._device)

generator_opt = torch.optim.Adam(self._generator.parameters(), lr=self._gen_lr)
discriminator_opt = torch.optim.Adam(discriminator.parameters(), lr=self._dis_lr)
generator_opt = torch.optim.Adam(
self._generator.parameters(), lr=self._gen_lr
)
discriminator_opt = torch.optim.Adam(
discriminator.parameters(), lr=self._dis_lr
)

return discriminator, generator_opt, discriminator_opt

Expand Down Expand Up @@ -513,11 +547,17 @@ def fit_sequences(self, sequences, context_types, data_types):
"""
self._analyze_data(sequences, context_types, data_types)

data = self._build_tensor(self._data_to_tensor, sequences, 'data', dim=1)
context = self._build_tensor(self._context_to_tensor, sequences, 'context', dim=0)
data = self._build_tensor(
self._data_to_tensor, sequences, 'data', dim=1
)
context = self._build_tensor(
self._context_to_tensor, sequences, 'context', dim=0
)
data_context = _expand_context(data, context)

discriminator, generator_opt, discriminator_opt = self._build_fit_artifacts()
discriminator, generator_opt, discriminator_opt = (
self._build_fit_artifacts()
)

iterator = range(self._epochs)
if self._verbose:
Expand All @@ -539,7 +579,9 @@ def fit_sequences(self, sequences, context_types, data_types):
if self._verbose:
d_loss = discriminator_score.item()
g_loss = generator_score.item()
iterator.set_description(f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}')
iterator.set_description(
f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}'
)

def sample_sequence(self, context, sequence_length=None):
"""Sample a single sequence conditioned on context.
Expand All @@ -554,7 +596,9 @@ def sample_sequence(self, context, sequence_length=None):
A list of lists (data) corresponding to the types specified
in data_types when fit was called.
"""
context = self._context_to_tensor(context).unsqueeze(0).to(self._device)
context = (
self._context_to_tensor(context).unsqueeze(0).to(self._device)
)

with torch.no_grad():
generated = self._generate(context, sequence_length)
Expand Down
Loading

0 comments on commit 3380134

Please sign in to comment.