Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/lint_and_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ jobs:
- name: Install Dependences
run: pip install -r requirements.txt -r requirements-dev.txt

- name: Run Setup Script
run: python setup.py install

- name: Linting
run: flake8

Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,6 @@ venv.bak/

# mypy
.mypy_cache/

# macOS
.DS_Store
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
repos:
- repo: https://github.com/pre-commit/mirrors-autopep8
rev: '4b4928307f1e6e8c9e02570ef705364f47ddb6dc' # Use the sha / tag you want to point at
hooks:
Expand Down
96 changes: 93 additions & 3 deletions mffpy/tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
ANY KIND, either express or implied.
"""
from datetime import datetime
from os import makedirs
from os.path import join
from os import makedirs, listdir
from os.path import join, splitext

import pytest
import json
Expand All @@ -35,7 +35,7 @@ def test_writer_receives_bad_init_data():


def test_writer_doesnt_overwrite(tmpdir):
"""test that `mffpy.Writer` doesn't overwrite existing files"""
"""test that `mffpy.Writer` doesn't overwrite existing files by default"""
dirname = join(str(tmpdir), 'testdir.mff')
makedirs(dirname, exist_ok=True)
with pytest.raises(AssertionError, match='File.*exists already'):
Expand Down Expand Up @@ -75,6 +75,96 @@ def test_writer_writes(tmpdir):
assert layout.name == device


def test_writer_can_overwrite(tmpdir):
"""test that the Writer does overwrite existing files"""
dirname = join(str(tmpdir), 'testdir2.mff')
# create some data and add it to a binary writer
device = 'HydroCel GSN 256 1.0'
num_samples = 10
num_channels = 256
sampling_rate = 128
b = BinWriter(sampling_rate=sampling_rate, data_type='EEG')
data = np.random.randn(num_channels, num_samples).astype(np.float32)
b.add_block(data)
# create an mffpy.Writer and add a file info, and the binary file
W = Writer(dirname)
startdatetime = datetime.strptime(
'1984-02-18T14:00:10.000000+0100', XML._time_format)
W.addxml('fileInfo', recordTime=startdatetime)
W.add_coordinates_and_sensor_layout(device)
W.addbin(b)
W.write()

# list files
files = listdir(dirname)
assert 'info.xml' in files
assert 'coordinates.xml' in files
assert 'sensorLayout.xml' in files

# add a directory inside
makedirs(join(dirname, 'test'))

# create new writer to overwrite
b = BinWriter(sampling_rate=sampling_rate, data_type='EEG')
data2 = np.random.randn(num_channels, num_samples).astype(np.float32)
b.add_block(data2)
W = Writer(dirname, overwrite=True)
W.addbin(b)
W.write()

# compare files with
files = listdir(dirname)
assert 'info.xml' not in files
assert 'coordinates.xml' not in files
assert 'sensorLayout.xml' not in files
assert 'test' not in files

# read
R = Reader(dirname)
with pytest.raises(FileNotFoundError):
R.startdatetime
with pytest.raises(FileNotFoundError):
R.directory.filepointer('sensorLayout')
read_data = R.get_physical_samples_from_epoch(R.epochs[0])
assert 'EEG' in read_data
read_data, t0 = read_data['EEG']
assert t0 == 0.0
assert np.allclose(read_data, data2)
assert not np.allclose(read_data, data)

# test writer can 'overwrite' if there is nothing to overwrite
dirname = join(str(tmpdir), 'testdir3.mff')
W = Writer(dirname, overwrite=True)
W.addbin(b)
W.write()


def test_overwrite_mfz(tmpdir):
"""Test mffdir and mfz file are overwritten when overwrite is on"""
mfzpath = join(tmpdir, 'test.mfz')
mffpath = splitext(mfzpath)[0] + '.mff'
time1 = datetime.strptime('1984-02-18T14:00:10.000000+0100',
"%Y-%m-%dT%H:%M:%S.%f%z")
time2 = datetime.strptime('1973-10-23T14:00:10.000000+0100',
"%Y-%m-%dT%H:%M:%S.%f%z")

W1 = Writer(mfzpath)
W1.addxml('fileInfo', recordTime=time1)
W1.write()

for p in [mffpath, mfzpath]:
R1 = Reader(p)
assert R1.startdatetime == time1

W2 = Writer(mfzpath, overwrite=True)
W2.addxml('fileInfo', recordTime=time2)
W2.write()

for p in [mffpath, mfzpath]:
R2 = Reader(p)
assert R2.startdatetime == time2


def test_writer_writes_multple_bins(tmpdir):
"""test that `mffpy.Writer` can write multiple binary files"""
dirname = join(str(tmpdir), 'multiple_bins.mff')
Expand Down
18 changes: 13 additions & 5 deletions mffpy/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
ANY KIND, either express or implied.
"""
from os import makedirs
from os import makedirs, remove
from os.path import splitext, exists, join
from shutil import rmtree
from subprocess import check_output
import xml.etree.ElementTree as ET

Expand All @@ -30,7 +31,8 @@

class Writer:

def __init__(self, filename: str):
def __init__(self, filename: str, overwrite: bool = False):
self.overwrite = bool(overwrite)
self.filename = filename
self.files: Dict[str, Any] = {}
self.num_bin_files = 0
Expand All @@ -41,6 +43,8 @@ def __init__(self, filename: str):
def create_directory(self):
"""Creates the directory for the recording."""
if not self.file_created:
if self.overwrite and exists(self.mffdir):
rmtree(self.mffdir)
makedirs(self.mffdir, exist_ok=False)
self.file_created = True

Expand All @@ -59,6 +63,9 @@ def write(self):

# convert from .mff to .mfz
if self.ext == '.mfz':
mfzpath = splitext(self.mffdir)[0] + '.mfz'
if self.overwrite and exists(mfzpath):
remove(mfzpath)
check_output(['mff2mfz.py', self.mffdir])

def export_to_json(self, data):
Expand Down Expand Up @@ -124,7 +131,8 @@ def filename(self, fn: str):
"""check filename with .mff/.mfz extension does not exist"""
base, ext = splitext(fn)
assert ext in ('.mff', '.mfz', '.json')
assert not exists(fn), f"File '{fn}' exists already"
if ext == '.mfz':
assert not exists(base + '.mff')
if not self.overwrite:
assert not exists(fn), f"File '{fn}' exists already"
if ext == '.mfz':
assert not exists(base + '.mff')
self._filename = fn