Skip to content

Commit

Permalink
Use Dask temporary file utility (#5361)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Sep 29, 2021
1 parent 7e2fe5c commit 43d3866
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 33 deletions.
4 changes: 3 additions & 1 deletion distributed/cli/tests/test_dask_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
import requests
from click.testing import CliRunner

from dask.utils import tmpfile

import distributed
import distributed.cli.dask_scheduler
from distributed import Client, Scheduler
from distributed.compatibility import LINUX
from distributed.metrics import time
from distributed.utils import get_ip, get_ip_interface, tmpfile
from distributed.utils import get_ip, get_ip_interface
from distributed.utils_test import (
assert_can_connect_from_everywhere_4_6,
assert_can_connect_locally_4,
Expand Down
4 changes: 3 additions & 1 deletion distributed/cli/tests/test_dask_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@

import requests

from dask.utils import tmpfile

import distributed.cli.dask_worker
from distributed import Client, Scheduler
from distributed.compatibility import LINUX
from distributed.deploy.utils import nprocesses_nthreads
from distributed.metrics import time
from distributed.utils import parse_ports, sync, tmpfile
from distributed.utils import parse_ports, sync
from distributed.utils_test import (
gen_cluster,
popen,
Expand Down
3 changes: 2 additions & 1 deletion distributed/protocol/tests/test_h5py.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

h5py = pytest.importorskip("h5py")

from dask.utils import tmpfile

from distributed.protocol import deserialize, serialize
from distributed.utils import tmpfile


def silence_h5py_issue775(func):
Expand Down
3 changes: 2 additions & 1 deletion distributed/protocol/tests/test_netcdf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
netCDF4 = pytest.importorskip("netCDF4")
np = pytest.importorskip("numpy")

from dask.utils import tmpfile

from distributed.protocol import deserialize, serialize
from distributed.utils import tmpfile


def create_test_dataset(fn):
Expand Down
4 changes: 3 additions & 1 deletion distributed/protocol/tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

np = pytest.importorskip("numpy")

from dask.utils import tmpfile

from distributed.protocol import (
decompress,
deserialize,
Expand All @@ -18,7 +20,7 @@
from distributed.protocol.pickle import HIGHEST_PROTOCOL
from distributed.protocol.utils import BIG_BYTES_SHARD_SIZE
from distributed.system import MEMORY_LIMIT
from distributed.utils import ensure_bytes, nbytes, tmpfile
from distributed.utils import ensure_bytes, nbytes
from distributed.utils_test import gen_cluster


Expand Down
3 changes: 1 addition & 2 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

import dask
from dask.highlevelgraph import HighLevelGraph
from dask.utils import format_bytes, format_time, parse_bytes, parse_timedelta
from dask.utils import format_bytes, format_time, parse_bytes, parse_timedelta, tmpfile
from dask.widgets import get_template

from . import preloading, profile
Expand Down Expand Up @@ -76,7 +76,6 @@
key_split_group,
log_errors,
no_default,
tmpfile,
validate_key,
)
from .utils_comm import gather_from_workers, retry_operation, scatter_to_workers
Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import dask.bag as db
from dask import delayed
from dask.optimization import SubgraphCallable
from dask.utils import stringify
from dask.utils import stringify, tmpfile

from distributed import (
CancelledError,
Expand Down Expand Up @@ -67,7 +67,7 @@
Scheduler,
)
from distributed.sizeof import sizeof
from distributed.utils import is_valid_xml, mp_context, sync, tmp_text, tmpfile
from distributed.utils import is_valid_xml, mp_context, sync, tmp_text
from distributed.utils_test import (
TaskStateMetadataPlugin,
_UnhashableCallable,
Expand Down
3 changes: 2 additions & 1 deletion distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
from tornado.ioloop import IOLoop

import dask
from dask.utils import tmpfile

from distributed import Nanny, Scheduler, Worker, rpc, wait, worker
from distributed.compatibility import LINUX, WINDOWS
from distributed.core import CommClosedError, Status
from distributed.diagnostics import SchedulerPlugin
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.utils import TimeoutError, parse_ports, tmpfile
from distributed.utils import TimeoutError, parse_ports
from distributed.utils_test import captured_logger, gen_cluster, gen_test, inc

pytestmark = pytest.mark.ci1
Expand Down
4 changes: 2 additions & 2 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import dask
from dask import delayed
from dask.utils import apply, parse_timedelta, stringify, typename
from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename

from distributed import Client, Nanny, Worker, fire_and_forget, wait
from distributed.comm import Comm
Expand All @@ -27,7 +27,7 @@
from distributed.metrics import time
from distributed.protocol.pickle import dumps
from distributed.scheduler import MemoryState, Scheduler
from distributed.utils import TimeoutError, tmpfile
from distributed.utils import TimeoutError
from distributed.utils_test import (
captured_logger,
cluster,
Expand Down
6 changes: 6 additions & 0 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,12 @@ def test_typename_deprecated():
assert typename is dask.utils.typename


def test_tmpfile_deprecated():
with pytest.warns(FutureWarning, match="tmpfile is deprecated"):
from distributed.utils import tmpfile
assert tmpfile is dask.utils.tmpfile


def test_iscoroutinefunction_unhashable_input():
# Ensure iscoroutinefunction can handle unhashable callables
assert not iscoroutinefunction(_UnhashableCallable())
3 changes: 2 additions & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import dask
from dask import delayed
from dask.system import CPU_COUNT
from dask.utils import tmpfile

import distributed
from distributed import (
Expand All @@ -38,7 +39,7 @@
from distributed.diagnostics.plugin import PipInstall
from distributed.metrics import time
from distributed.scheduler import Scheduler
from distributed.utils import TimeoutError, tmpfile
from distributed.utils import TimeoutError
from distributed.utils_test import (
TaskStateMetadataPlugin,
_LockedCommPool,
Expand Down
21 changes: 1 addition & 20 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import pkgutil
import re
import shutil
import socket
import sys
import tempfile
Expand Down Expand Up @@ -837,25 +836,6 @@ def read_block(f, offset, length, delimiter=None):
return bytes


@contextmanager
def tmpfile(extension=""):
extension = "." + extension.lstrip(".")
handle, filename = tempfile.mkstemp(extension)
os.close(handle)
os.remove(filename)

yield filename

if os.path.exists(filename):
try:
if os.path.isdir(filename):
shutil.rmtree(filename)
else:
os.remove(filename)
except OSError: # sometimes we can't remove a generated temp file
pass


def ensure_bytes(s):
"""Attempt to turn `s` into bytes.
Expand Down Expand Up @@ -1435,6 +1415,7 @@ def clean_dashboard_address(addrs: AnyType, default_listen_ip: str = "") -> List
"parse_bytes": "dask.utils.parse_bytes",
"parse_timedelta": "dask.utils.parse_timedelta",
"typename": "dask.utils.typename",
"tmpfile": "dask.utils.tmpfile",
}


Expand Down

0 comments on commit 43d3866

Please sign in to comment.