Skip to content

Commit 5f28ee6

Browse files
theGOTOguybbc2
authored andcommitted
Fix out of scope error when "dest" variable is undefined
Fixes theskumar#413 whereby the NamedTemporaryFile "dest" was out of scope in the error handling portion of rewrite. The problem was initially fixed in theskumar#414 but it got reverted because of a linter error. This new commit works around that linter error.
1 parent 718307b commit 5f28ee6

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

src/dotenv/main.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -125,19 +125,17 @@ def rewrite(
125125
path: Union[str, os.PathLike],
126126
encoding: Optional[str],
127127
) -> Iterator[Tuple[IO[str], IO[str]]]:
128-
try:
129-
if not os.path.isfile(path):
130-
with open(path, "w+", encoding=encoding) as source:
131-
source.write("")
132-
with tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding=encoding) as dest:
128+
if not os.path.isfile(path):
129+
with open(path, mode="w", encoding=encoding) as source:
130+
source.write("")
131+
with tempfile.NamedTemporaryFile(mode="w", encoding=encoding, delete=False) as dest:
132+
try:
133133
with open(path, encoding=encoding) as source:
134-
yield (source, dest) # type: ignore
135-
except BaseException:
136-
if os.path.isfile(dest.name):
134+
yield (source, dest)
135+
except BaseException:
137136
os.unlink(dest.name)
138-
raise
139-
else:
140-
shutil.move(dest.name, path)
137+
raise
138+
shutil.move(dest.name, path)
141139

142140

143141
def set_key(

tests/test_main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,13 @@ def test_unset_encoding(dotenv_file):
178178
assert f.read() == ""
179179

180180

181+
def test_set_key_unauthorized_file(dotenv_file):
182+
os.chmod(dotenv_file, 0o000)
183+
184+
with pytest.raises(PermissionError):
185+
dotenv.set_key(dotenv_file, "a", "x")
186+
187+
181188
def test_unset_non_existent_file(tmp_path):
182189
nx_file = str(tmp_path / "nx")
183190
logger = logging.getLogger("dotenv.main")

0 commit comments

Comments
 (0)