Skip to content

Commit f48e253

Browse files
committed
refactor to add documentation, clarify variable names, add test cases, and better encapsulate behaviors (among other things to simplify testing)
1 parent 4cfb11f commit f48e253

File tree

5 files changed

+303
-92
lines changed

5 files changed

+303
-92
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
build
44
dist
55
*.egg-info
6+
.idea

ssm-diff

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,53 @@
11
#!/usr/bin/env python
22
from __future__ import print_function
3-
from states import *
4-
import states.helpers as helpers
3+
54
import argparse
65
import os
76

7+
from states import states
8+
from states.helpers import DiffResolver
9+
10+
11+
def configure_endpoints(args):
12+
# pre-configure resolver, but still accept remote and local at runtime
13+
diff_resolver = DiffResolver.configure(force=args.force)
14+
return states.ParameterStore(args.profile, diff_resolver, paths=args.path), states.YAMLFile(args.filename, paths=args.path)
15+
816

917
def init(args):
10-
r, l = RemoteState(args.profile), LocalState(args.filename)
11-
l.save(r.get(flat=False, paths=args.path))
18+
"""Create a local YAML file from the SSM Parameter Store (per configs in args)"""
19+
remote, local = configure_endpoints(args)
20+
local.save(remote.clone())
1221

1322

1423
def pull(args):
15-
dictfilter = lambda x, y: dict([ (i,x[i]) for i in x if i in set(y) ])
16-
r, l = RemoteState(args.profile), LocalState(args.filename)
17-
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
18-
if args.force:
19-
ref_set = diff.changed().union(diff.removed()).union(diff.unchanged())
20-
target_set = diff.added()
21-
else:
22-
ref_set = diff.unchanged().union(diff.removed())
23-
target_set = diff.added().union(diff.changed())
24-
state = dictfilter(diff.ref, ref_set)
25-
state.update(dictfilter(diff.target, target_set))
26-
l.save(helpers.unflatten(state))
24+
"""Update local YAML file with changes in the SSM Parameter Store (per configs in args)"""
25+
remote, local = configure_endpoints(args)
26+
local.save(remote.pull(local.get()))
2727

2828

2929
def apply(args):
30-
r, _, diff = plan(args)
31-
30+
"""Apply local changes to the SSM Parameter Store"""
31+
remote, local = configure_endpoints(args)
3232
print("\nApplying changes...")
3333
try:
34-
r.apply(diff)
34+
remote.push(local.get())
3535
except Exception as e:
3636
print("Failed to apply changes to remote:", e)
3737
print("Done.")
3838

3939

4040
def plan(args):
41-
r, l = RemoteState(args.profile), LocalState(args.filename)
42-
diff = helpers.FlatDictDiffer(r.get(paths=args.path), l.get(paths=args.path))
41+
"""Print a representation of the changes that would be applied to SSM Parameter Store if applied (per config in args)"""
42+
remote, local = configure_endpoints(args)
43+
diff = remote.dry_run(local.get())
4344

4445
if diff.differ:
45-
diff.print_state()
46+
print(diff.describe_diff())
4647
else:
4748
print("Remote state is up to date.")
4849

49-
return r, l, diff
50+
return remote, local, diff
5051

5152

5253
if __name__ == "__main__":

states/helpers.py

Lines changed: 76 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,101 @@
1-
from termcolor import colored
2-
from copy import deepcopy
31
import collections
2+
from copy import deepcopy
3+
from functools import partial
44

5+
from termcolor import colored
56

6-
class FlatDictDiffer(object):
7-
def __init__(self, ref, target):
8-
self.ref, self.target = ref, target
9-
self.ref_set, self.target_set = set(ref.keys()), set(target.keys())
10-
self.isect = self.ref_set.intersection(self.target_set)
7+
8+
class DiffResolver(object):
9+
"""Determines diffs between two dicts, where the remote copy is considered the baseline"""
10+
def __init__(self, remote, local, force=False):
11+
self.remote_flat, self.local_flat = self._flatten(remote), self._flatten(local)
12+
self.remote_set, self.local_set = set(self.remote_flat.keys()), set(self.local_flat.keys())
13+
self.intersection = self.remote_set.intersection(self.local_set)
14+
self.force = force
1115

1216
if self.added() or self.removed() or self.changed():
1317
self.differ = True
1418
else:
1519
self.differ = False
1620

21+
@classmethod
22+
def configure(cls, *args, **kwargs):
23+
return partial(cls, *args, **kwargs)
24+
1725
def added(self):
18-
return self.target_set - self.isect
26+
"""Returns a (flattened) dict of added leaves i.e. {"full/path": value, ...}"""
27+
return self.local_set - self.intersection
1928

2029
def removed(self):
21-
return self.ref_set - self.isect
30+
"""Returns a (flattened) dict of removed leaves i.e. {"full/path": value, ...}"""
31+
return self.remote_set - self.intersection
2232

2333
def changed(self):
24-
return set(k for k in self.isect if self.ref[k] != self.target[k])
34+
"""Returns a (flattened) dict of changed leaves i.e. {"full/path": value, ...}"""
35+
return set(k for k in self.intersection if self.remote_flat[k] != self.local_flat[k])
2536

2637
def unchanged(self):
27-
return set(k for k in self.isect if self.ref[k] == self.target[k])
38+
"""Returns a (flattened) dict of unchanged leaves i.e. {"full/path": value, ...}"""
39+
return set(k for k in self.intersection if self.remote_flat[k] == self.local_flat[k])
2840

29-
def print_state(self):
41+
def describe_diff(self):
42+
"""Return a (multi-line) string describing all differences"""
43+
description = ""
3044
for k in self.added():
31-
print(colored("+", 'green'), "{} = {}".format(k, self.target[k]))
45+
description += colored("+", 'green'), "{} = {}".format(k, self.local_flat[k]) + '\n'
3246

3347
for k in self.removed():
34-
print(colored("-", 'red'), k)
48+
description += colored("-", 'red'), k + '\n'
3549

3650
for k in self.changed():
37-
print(colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.ref[k], self.target[k]))
38-
39-
40-
def flatten(d, pkey='', sep='/'):
41-
items = []
42-
for k in d:
43-
new = pkey + sep + k if pkey else k
44-
if isinstance(d[k], collections.MutableMapping):
45-
items.extend(flatten(d[k], new, sep=sep).items())
51+
description += colored("~", 'yellow'), "{}:\n\t< {}\n\t> {}".format(k, self.remote_flat[k], self.local_flat[k]) + '\n'
52+
53+
return description
54+
55+
def _flatten(self, d, current_path='', sep='/'):
56+
"""Convert a nested dict structure into a "flattened" dict i.e. {"full/path": "value", ...}"""
57+
items = []
58+
for k in d:
59+
new = current_path + sep + k if current_path else k
60+
if isinstance(d[k], collections.MutableMapping):
61+
items.extend(self._flatten(d[k], new, sep=sep).items())
62+
else:
63+
items.append((sep + new, d[k]))
64+
return dict(items)
65+
66+
def _unflatten(self, d, sep='/'):
67+
"""Converts a "flattened" dict i.e. {"full/path": "value", ...} into a nested dict structure"""
68+
output = {}
69+
for k in d:
70+
add(
71+
obj=output,
72+
path=k,
73+
value=d[k],
74+
sep=sep,
75+
)
76+
return output
77+
78+
def merge(self):
79+
"""Generate a merge of the local and remote dicts, following configurations set during __init__"""
80+
dictfilter = lambda original, keep_keys: dict([(i, original[i]) for i in original if i in set(keep_keys)])
81+
if self.force:
82+
# Overwrite local changes (i.e. only preserve added keys)
83+
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
84+
prior_set = self.changed().union(self.removed()).union(self.unchanged())
85+
current_set = self.added()
4686
else:
47-
items.append((sep + new, d[k]))
48-
return dict(items)
49-
50-
51-
def add(obj, path, value):
52-
parts = path.strip("/").split("/")
87+
# Preserve added keys and changed keys
88+
# NOTE: Currently the system cannot tell the difference between a remote delete and a local add
89+
prior_set = self.unchanged().union(self.removed())
90+
current_set = self.added().union(self.changed())
91+
state = dictfilter(original=self.remote_flat, keep_keys=prior_set)
92+
state.update(dictfilter(original=self.local_flat, keep_keys=current_set))
93+
return self._unflatten(state)
94+
95+
96+
def add(obj, path, value, sep='/'):
97+
"""Add value to the `obj` dict at the specified path"""
98+
parts = path.strip(sep).split(sep)
5399
last = len(parts) - 1
54100
for index, part in enumerate(parts):
55101
if index == last:
@@ -61,7 +107,7 @@ def add(obj, path, value):
61107
def search(state, path):
62108
result = state
63109
for p in path.strip("/").split("/"):
64-
if result.get(p):
110+
if result.clone(p):
65111
result = result[p]
66112
else:
67113
result = {}
@@ -71,16 +117,6 @@ def search(state, path):
71117
return output
72118

73119

74-
def unflatten(d):
75-
output = {}
76-
for k in d:
77-
add(
78-
obj=output,
79-
path=k,
80-
value=d[k])
81-
return output
82-
83-
84120
def merge(a, b):
85121
if not isinstance(b, dict):
86122
return b

0 commit comments

Comments
 (0)