| 
 | 1 | +import time  | 
 | 2 | +from io import BytesIO  | 
 | 3 | +from tarfile import TarFile, TarInfo  | 
 | 4 | +from typing import Optional  | 
 | 5 | + | 
 | 6 | +import bcrypt  | 
 | 7 | +from requests import Response, get  | 
 | 8 | +from requests.auth import HTTPBasicAuth  | 
 | 9 | +from requests.exceptions import ConnectionError, ReadTimeout  | 
 | 10 | +from testcontainers.core.container import DockerContainer  | 
 | 11 | +from testcontainers.core.waiting_utils import wait_container_is_ready  | 
 | 12 | + | 
 | 13 | +class DockerRegistryContainer(DockerContainer):  | 
 | 14 | +    # https://docs.docker.com/registry/  | 
 | 15 | +    credentials_path: str = "/htpasswd/credentials.txt"  | 
 | 16 | + | 
 | 17 | +    def __init__(  | 
 | 18 | +        self,  | 
 | 19 | +        image: str = "registry:2",  | 
 | 20 | +        port: int = 5000,  | 
 | 21 | +        username: str = None,  | 
 | 22 | +        password: str = None,  | 
 | 23 | +        **kwargs,  | 
 | 24 | +    ) -> None:  | 
 | 25 | +        super().__init__(image=image, **kwargs)  | 
 | 26 | +        self.port: int = port  | 
 | 27 | +        self.username: Optional[str] = username  | 
 | 28 | +        self.password: Optional[str] = password  | 
 | 29 | +        self.with_exposed_ports(self.port)  | 
 | 30 | + | 
 | 31 | +    def _copy_credentials(self) -> None:  | 
 | 32 | +        # Create credentials and write them to the container  | 
 | 33 | +        hashed_password: str = bcrypt.hashpw(  | 
 | 34 | +            self.password.encode("utf-8"),  | 
 | 35 | +            bcrypt.gensalt(rounds=12, prefix=b"2a"),  | 
 | 36 | +        ).decode("utf-8")  | 
 | 37 | +        content = f"{self.username}:{hashed_password}".encode("utf-8")  | 
 | 38 | + | 
 | 39 | +        with BytesIO() as tar_archive_object, TarFile(  | 
 | 40 | +            fileobj=tar_archive_object, mode="w"  | 
 | 41 | +        ) as tmp_tarfile:  | 
 | 42 | +            tarinfo: TarInfo = TarInfo(name=self.credentials_path)  | 
 | 43 | +            tarinfo.size = len(content)  | 
 | 44 | +            tarinfo.mtime = time.time()  | 
 | 45 | + | 
 | 46 | +            tmp_tarfile.addfile(tarinfo, BytesIO(content))  | 
 | 47 | +            tar_archive_object.seek(0)  | 
 | 48 | +            self.get_wrapped_container().put_archive("/", tar_archive_object)  | 
 | 49 | + | 
 | 50 | +    @wait_container_is_ready(ConnectionError, ReadTimeout)  | 
 | 51 | +    def _readiness_probe(self) -> None:  | 
 | 52 | +        url: str = f"http://{self.get_registry()}/v2"  | 
 | 53 | +        if self.username and self.password:  | 
 | 54 | +            response: Response = get(url, auth=HTTPBasicAuth(self.username, self.password), timeout=1)  | 
 | 55 | +        else:  | 
 | 56 | +            response: Response = get(url, timeout=1)  | 
 | 57 | +        response.raise_for_status()  | 
 | 58 | + | 
 | 59 | +    def start(self):  | 
 | 60 | +        if self.username and self.password:  | 
 | 61 | +            self.with_env("REGISTRY_AUTH_HTPASSWD_REALM", "local-registry")  | 
 | 62 | +            self.with_env("REGISTRY_AUTH_HTPASSWD_PATH", self.credentials_path)  | 
 | 63 | +            super().start()  | 
 | 64 | +            self._copy_credentials()  | 
 | 65 | +        else:  | 
 | 66 | +            super().start()  | 
 | 67 | + | 
 | 68 | +        self._readiness_probe()              | 
 | 69 | +        return self  | 
 | 70 | + | 
 | 71 | +    def get_registry(self) -> str:  | 
 | 72 | +        host: str = self.get_container_host_ip()  | 
 | 73 | +        port: str = self.get_exposed_port(self.port)  | 
 | 74 | +        return f"{host}:{port}"  | 
0 commit comments