Skip to content

Commit

Permalink
MLBF
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinMind committed Sep 26, 2024
1 parent ecd4748 commit ba87c89
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ def add_arguments(self, parser):
'the database',
default=None,
)
parser.add_argument(
'--block-type',
help='Block type to export',
required=True,
choices=['hard', 'soft'],
)

def load_json(self, json_path):
with open(json_path) as json_file:
Expand All @@ -37,7 +43,7 @@ def load_json(self, json_path):

def handle(self, *args, **options):
log.debug('Exporting blocklist to file')
mlbf = MLBF.generate_from_db(options.get('id'))
mlbf = MLBF.generate_from_db(options.get('id'), options.get('block_type'))

if options.get('block_guids_input'):
mlbf.blocked_items = list(
Expand Down
37 changes: 22 additions & 15 deletions src/olympia/blocklist/mlbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,24 @@ def generate_mlbf(stats, blocked, not_blocked):
return cascade


def fetch_blocked_from_db():
def fetch_blocked_from_db(block_type='hard'):
from olympia.blocklist.models import BlockVersion

qs = BlockVersion.objects.filter(version__file__is_signed=True).values_list(
if block_type == 'soft':
qs = BlockVersion.objects.soft_blocked()
else:
qs = BlockVersion.objects.hard_blocked()

blocked_versions = qs.values_list(
'block__guid', 'version__version', 'version_id', named=True
)

all_versions = {
block_version.version_id: (
block_version.block__guid,
block_version.version__version,
)
for block_version in qs
for block_version in blocked_versions
}
return all_versions

Expand All @@ -74,9 +80,10 @@ def fetch_all_versions_from_db(excluding_version_ids=None):
class MLBF:
KEY_FORMAT = '{guid}:{version}'

def __init__(self, id_):
def __init__(self, id_, block_type):
# simplify later code by assuming always a string
self.id = str(id_)
self.block_type = block_type
self.storage = SafeStorage(root_setting='MLBF_STORAGE_PATH')

@classmethod
Expand All @@ -89,7 +96,7 @@ def hash_filter_inputs(cls, input_list):

@property
def _blocked_path(self):
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, 'blocked.json')
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, f'blocked_{self.block_type}.json')

@cached_property
def blocked_items(self):
Expand All @@ -103,7 +110,7 @@ def write_blocked_items(self):

@property
def _not_blocked_path(self):
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, 'notblocked.json')
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, f'notblocked_{self.block_type}.json')

@cached_property
def not_blocked_items(self):
Expand All @@ -117,11 +124,11 @@ def write_not_blocked_items(self):

@property
def filter_path(self):
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, 'filter')
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, f'filter_{self.block_type}')

@property
def _stash_path(self):
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, 'stash.json')
return os.path.join(settings.MLBF_STORAGE_PATH, self.id, f'stash_{self.block_type}.json')

@cached_property
def stash_json(self):
Expand All @@ -135,7 +142,7 @@ def generate_and_write_filter(self):
self.write_not_blocked_items()

bloomfilter = generate_mlbf(
stats=stats, blocked=self.blocked_items, not_blocked=self.not_blocked_items
stats=stats, blocked=self.blocked_items, not_blocked=self.not_blocked_items, block_type=self.block_type
)

# write bloomfilter
Expand All @@ -145,7 +152,7 @@ def generate_and_write_filter(self):
bloomfilter.tofile(filter_file)
stats['mlbf_filesize'] = os.stat(mlbf_path).st_size

log.info(json.dumps(stats))
log.info(f'Generated bloom filter for {self.block_type} blocks: {json.dumps(stats)}')

@classmethod
def generate_diffs(cls, previous, current):
Expand Down Expand Up @@ -196,12 +203,12 @@ def blocks_changed_since_previous(self, previous_bloom_filter):
return len(self.blocked_items)

@classmethod
def load_from_storage(cls, *args, **kwargs):
return StoredMLBF(*args, **kwargs)
def load_from_storage(cls, id_, block_type='hard'):
return StoredMLBF(id_, block_type)

@classmethod
def generate_from_db(cls, *args, **kwargs):
return DatabaseMLBF(*args, **kwargs)
def generate_from_db(cls, id_, block_type='hard'):
return DatabaseMLBF(id_, block_type)


class StoredMLBF(MLBF):
Expand All @@ -219,7 +226,7 @@ def not_blocked_items(self):
class DatabaseMLBF(MLBF):
@cached_property
def blocked_items(self):
blocked_ids_to_versions = fetch_blocked_from_db()
blocked_ids_to_versions = fetch_blocked_from_db(self.block_type)
blocked = blocked_ids_to_versions.values()
# cache version ids so query in not_blocked_items is efficient
self._version_excludes = blocked_ids_to_versions.keys()
Expand Down
23 changes: 23 additions & 0 deletions src/olympia/blocklist/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,34 @@ def get_blocks_from_guids(cls, guids):
return blocks


class BlockVersionQuerySet(models.QuerySet):
def _blocked(self, hard):
return self.filter(hard=hard, version__file__is_signed=True)

def soft_blocked(self):
return self._blocked(hard=False)

def hard_blocked(self):
return self._blocked(hard=True)


class BlockVersionManager(ManagerBase):
_queryset_class = BlockVersionQuerySet

def soft_blocked(self):
return self.get_queryset().soft_blocked()

def hard_blocked(self):
return self.get_queryset().hard_blocked()


class BlockVersion(ModelBase):
version = models.OneToOneField(Version, on_delete=models.CASCADE)
block = models.ForeignKey(Block, on_delete=models.CASCADE)
hard = models.BooleanField(default=True)

objects = BlockVersionManager()

def __str__(self) -> str:
blocktype = 'hard' if self.hard else 'soft'
return f'Block.id={self.block_id} ({blocktype}) -> Version.id={self.version_id}'
Expand Down

0 comments on commit ba87c89

Please sign in to comment.