Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions decompression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,35 @@ func Test_decompressTarXz_ErrorWhenFileToCopyToNotExists(t *testing.T) {

assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
}

func Test_decompressTarXz_ErrorWhenArchiveCorrupted(t *testing.T) {
tempDir, err := ioutil.TempDir("", "temp_tar_test")
if err != nil {
panic(err)
}

archive, cleanup := createTempXzArchive()

defer cleanup()

file, err := os.OpenFile(archive, os.O_WRONLY, 0664)
if err != nil {
panic(err)
}

if _, err := file.Seek(35, 0); err != nil {
panic(err)
}

if _, err := file.WriteString("someJunk"); err != nil {
panic(err)
}

if err := file.Close(); err != nil {
panic(err)
}

err = decompressTarXz(defaultTarReader, archive, tempDir)

assert.EqualError(t, err, "unable to extract postgres archive: xz: data is corrupt")
}
50 changes: 36 additions & 14 deletions remote_fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package embeddedpostgres
import (
"archive/zip"
"bytes"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io/ioutil"
"log"
Expand All @@ -19,7 +22,8 @@ type RemoteFetchStrategy func() error
func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionStrategy, cacheLocator CacheLocator) RemoteFetchStrategy {
return func() error {
operatingSystem, architecture, version := versionStrategy()
downloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",

jarDownloadURL := fmt.Sprintf("%s/io/zonky/test/postgres/embedded-postgres-binaries-%s-%s/%s/embedded-postgres-binaries-%s-%s-%s.jar",
remoteFetchHost,
operatingSystem,
architecture,
Expand All @@ -28,32 +32,50 @@ func defaultRemoteFetchStrategy(remoteFetchHost string, versionStrategy VersionS
architecture,
version)

resp, err := http.Get(downloadURL)
jarDownloadResponse, err := http.Get(jarDownloadURL)
if err != nil {
return fmt.Errorf("unable to connect to %s", remoteFetchHost)
}

defer func() {
if err := resp.Body.Close(); err != nil {
log.Fatal(err)
}
}()
defer closeBody(jarDownloadResponse)()

if resp.StatusCode != http.StatusOK {
if jarDownloadResponse.StatusCode != http.StatusOK {
return fmt.Errorf("no version found matching %s", version)
}

return decompressResponse(resp, cacheLocator, downloadURL)
jarBodyBytes, err := ioutil.ReadAll(jarDownloadResponse.Body)
if err != nil {
return errorFetchingPostgres(err)
}

shaDownloadURL := fmt.Sprintf("%s.sha256", jarDownloadURL)
shaDownloadResponse, err := http.Get(shaDownloadURL)

defer closeBody(shaDownloadResponse)()

if err == nil && shaDownloadResponse.StatusCode == http.StatusOK {
if shaBodyBytes, err := ioutil.ReadAll(shaDownloadResponse.Body); err == nil {
jarChecksum := sha256.Sum256(jarBodyBytes)
if !bytes.Equal(shaBodyBytes, []byte(hex.EncodeToString(jarChecksum[:]))) {
return errors.New("downloaded checksums do not match")
}
}
}

return decompressResponse(jarBodyBytes, jarDownloadResponse.ContentLength, cacheLocator, jarDownloadURL)
}
}

func decompressResponse(resp *http.Response, cacheLocator CacheLocator, downloadURL string) error {
bodyBytes, err := ioutil.ReadAll(resp.Body)
if err != nil {
return errorFetchingPostgres(err)
func closeBody(resp *http.Response) func() {
return func() {
if err := resp.Body.Close(); err != nil {
log.Fatal(err)
}
}
}

zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), resp.ContentLength)
func decompressResponse(bodyBytes []byte, contentLength int64, cacheLocator CacheLocator, downloadURL string) error {
zipReader, err := zip.NewReader(bytes.NewReader(bodyBytes), contentLength)
if err != nil {
return errorFetchingPostgres(err)
}
Expand Down
82 changes: 81 additions & 1 deletion remote_fetch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@ package embeddedpostgres

import (
"archive/zip"
"crypto/sha256"
"encoding/hex"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -54,7 +57,10 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenBodyReadIssue(t *testing.T) {

func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}
}))
defer server.Close()

Expand All @@ -69,6 +75,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzipSubFile(t *testing.T) {

func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(404)
return
}

if _, err := w.Write([]byte("lolz")); err != nil {
panic(err)
}
Expand All @@ -86,6 +97,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotUnzip(t *testing.T) {

func Test_defaultRemoteFetchStrategy_ErrorWhenNoSubTarArchive(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

MyZipWriter := zip.NewWriter(w)

if err := MyZipWriter.Close(); err != nil {
Expand Down Expand Up @@ -114,6 +130,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotExtractSubArchive(t *testing
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
Expand Down Expand Up @@ -148,6 +169,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateCacheDirectory(t *test
cacheLocation := filepath.Join(fileBlockingExtractDirectory, "cache_file.jar")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
Expand Down Expand Up @@ -181,6 +207,11 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(http.StatusNotFound)
return
}

bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
Expand All @@ -202,6 +233,44 @@ func Test_defaultRemoteFetchStrategy_ErrorWhenCannotCreateSubArchiveFile(t *test
assert.Regexp(t, "^unable to extract postgres archive:.+$", err)
}

func Test_defaultRemoteFetchStrategy_ErrorWhenSHA256NotMatch(t *testing.T) {
jarFile, cleanUp := createTempZipArchive()
defer cleanUp()

cacheLocation := filepath.Join(filepath.Dir(jarFile), "extract_location", "cache.jar")

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
bytes, err := ioutil.ReadFile(jarFile)
if err != nil {
panic(err)
}

if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(200)
if _, err := w.Write([]byte("literallyN3verGonnaWork")); err != nil {
panic(err)
}

return
}

if _, err := w.Write(bytes); err != nil {
panic(err)
}
}))
defer server.Close()

remoteFetchStrategy := defaultRemoteFetchStrategy(server.URL+"/maven2",
testVersionStrategy(),
func() (s string, b bool) {
return cacheLocation, false
})

err := remoteFetchStrategy()

assert.EqualError(t, err, "downloaded checksums do not match")
}

func Test_defaultRemoteFetchStrategy(t *testing.T) {
jarFile, cleanUp := createTempZipArchive()
defer cleanUp()
Expand All @@ -213,6 +282,17 @@ func Test_defaultRemoteFetchStrategy(t *testing.T) {
if err != nil {
panic(err)
}

if strings.HasSuffix(r.RequestURI, ".sha256") {
w.WriteHeader(200)
contentHash := sha256.Sum256(bytes)
if _, err := w.Write([]byte(hex.EncodeToString(contentHash[:]))); err != nil {
panic(err)
}

return
}

if _, err := w.Write(bytes); err != nil {
panic(err)
}
Expand Down