Skip to content
Merged
26 changes: 25 additions & 1 deletion medcat-v2/medcat/utils/download_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import importlib.metadata
import tempfile
import zipfile
import sys
from pathlib import Path
import requests
import logging
import argparse
import re


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -67,7 +69,7 @@ def _determine_url(overwrite_url: str | None,
else:
tag = _find_latest_scripts_tag(version)

logger.info("Fetching scripts for MedCAT %s → tag %s}",
logger.info("Fetching scripts for MedCAT %s → tag %s",
version, tag)

# Download the GitHub auto-generated zipball
Expand Down Expand Up @@ -110,6 +112,23 @@ def _extract_zip(dest: Path, zip_path: Path):
logger.info("Scripts extracted to: %s", dest)


def _fix_requirements(dest: Path, current_version: str):
    """Pin the ``medcat[...]`` requirement in ``requirements.txt``.

    Rewrites ``<dest>/requirements.txt`` in place so that every
    ``medcat[<extras>]<version specifier>`` entry becomes
    ``medcat[<extras>]~=<current_version>``. The whole specifier is
    replaced, including pre/dev/post suffixes, wildcards and any
    additional comma-separated clauses. Anything after the specifier
    (environment markers, comments) is kept. The file is not rewritten
    when it contains no such entry.

    Args:
        dest (Path): Directory containing the ``requirements.txt`` file.
        current_version (str): The version to pin the requirement to.

    Raises:
        FileNotFoundError: If ``dest`` does not contain ``requirements.txt``.
    """
    requirements_file = dest / "requirements.txt"
    original = requirements_file.read_text(encoding="utf-8")

    # A single version clause: comparison operator plus a version token.
    # The token allows letters, '*', '+', '!' and '-' so that versions such
    # as "2.0.0rc1" or "2.*" are consumed entirely instead of leaving a
    # stray suffix glued onto the new pin.
    clause = r"[><=!~]+[ \t]*[\w.*+!-]+"
    updated, count = re.subn(
        pattern=rf"(medcat\[.*?\])[ \t]*{clause}(?:[ \t]*,[ \t]*{clause})*",
        # A function replacement inserts the version literally; a template
        # string would interpret backslashes / group references in it.
        repl=lambda match: f"{match.group(1)}~={current_version}",
        string=original,
    )

    if count == 0:
        return

    requirements_file.write_text(updated, encoding="utf-8")



def fetch_scripts(destination: str | Path = ".",
overwrite_url: str | None = None,
overwrite_tag: str | None = None) -> Path:
Expand All @@ -130,6 +149,11 @@ def fetch_scripts(destination: str | Path = ".",
with tempfile.NamedTemporaryFile() as tmp:
_download_zip(zip_url, tmp)
_extract_zip(dest, Path(tmp.name))
_fix_requirements(dest, _get_medcat_version())
logger.info(
"You also need to install the requiements by doing:\n"
"%s -m pip install -r %s/requirements.txt",
sys.executable, str(destination))
return dest


Expand Down
34 changes: 34 additions & 0 deletions medcat-v2/tests/utils/test_download_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from medcat.utils import download_scripts

import os
import unittest
import unittest.mock
import tempfile


class ScriptsDownloadTest(unittest.TestCase):
    """Fetch the scripts once and check the resulting ``requirements.txt``.

    ``fetch_scripts`` is called for real in ``setUpClass`` with the MedCAT
    version patched to ``use_version``.
    NOTE(review): this presumably needs network access to download the
    scripts archive - confirm before running in an offline environment.
    """
    # Version reported by the patched ``_get_medcat_version``.
    use_version = "2.5"

    @classmethod
    def setUpClass(cls):
        cls._temp_dir = tempfile.TemporaryDirectory()
        # Register the cleanup immediately so the directory is removed even
        # when the fetch below fails (tearDownClass would not run then).
        cls.addClassCleanup(cls._temp_dir.cleanup)
        with unittest.mock.patch(
            "medcat.utils.download_scripts._get_medcat_version"
        ) as mock_get_version:
            mock_get_version.return_value = cls.use_version
            cls.scripts_path = download_scripts.fetch_scripts(
                cls._temp_dir.name)

    def test_can_download(self):
        self.assertTrue(os.path.exists(self.scripts_path))
        self.assertTrue(os.path.isdir(self.scripts_path))
        self.assertTrue(os.listdir(self.scripts_path))

    def test_has_requirements(self):
        self.assertIn('requirements.txt', os.listdir(self.scripts_path))

    def test_requirements_define_correct_version(self):
        req_path = os.path.join(self.scripts_path, 'requirements.txt')
        # The file is written as UTF-8, so read it back the same way.
        with open(req_path, encoding="utf-8") as f:
            medcat_lines = [line.strip() for line in f if "medcat" in line]
        # Fail with a clear message rather than an IndexError.
        self.assertTrue(
            medcat_lines, "No medcat requirement found in requirements.txt")
        medcat_line = medcat_lines[0]
        self.assertIn(self.use_version, medcat_line)
        self.assertTrue(medcat_line.endswith(self.use_version))
Loading