Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions server/dbutil/dbcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@ package dbutil
import (
"database/sql"
"fmt"
"log"
"strconv"
"strings"

"github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -118,7 +119,7 @@ func genericVals(nColumns int) []interface{} {
// fixSequences fixes the PostgreSQL sequences setting them to the max id found
// This should not be needed for readonly DBs, but his way the DB is ready in
// case new rows are added
func fixSequences(db DB, logger *log.Logger) {
func fixSequences(db DB, logger logrus.FieldLogger) {
if db.Driver != Postgres {
return
}
Expand Down
10 changes: 5 additions & 5 deletions server/dbutil/dbutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"crypto/md5"
"database/sql"
"fmt"
"log"
"os"
"regexp"
"strings"

"github.com/sirupsen/logrus"

// loads the driver
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
Expand Down Expand Up @@ -177,16 +178,15 @@ func Initialize(db DB) error {
// Options for the ImportFiles and Copy methods.
// Logger is optional, if it is not provided the default stderr will be used.
type Options struct {
Logger *log.Logger
Logger logrus.FieldLogger
}

func (opts *Options) getLogger() *log.Logger {
func (opts *Options) getLogger() logrus.FieldLogger {
if opts.Logger != nil {
return opts.Logger
}

return log.New(os.Stderr, "", log.LstdFlags) // Default log to stderr

return logrus.StandardLogger()
}

// ImportFiles imports pairs of files from the origin to the destination DB.
Expand Down
25 changes: 4 additions & 21 deletions server/handler/experiments_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package handler_test

import (
"database/sql"
"net/http"
"strings"
"testing"

"github.com/src-d/code-annotation/server/dbutil"
"github.com/src-d/code-annotation/server/handler"
"github.com/src-d/code-annotation/server/model"
"github.com/src-d/code-annotation/server/repository"
Expand All @@ -18,32 +16,17 @@ func TestCreateExperiment(t *testing.T) {
assert := assert.New(t)

db := testDB()
repo := repository.NewExperiments(db)
repo := repository.NewExperiments(db.DB)
handler := handler.CreateExperiment(repo)

json := `{"name": "new", "description": "test"}`
req, _ := http.NewRequest("POST", "/experiments", strings.NewReader(json))
res, err := handler(req)
assert.Nil(err)

assert.Equal(res, serializer.NewExperimentResponse(&model.Experiment{
ID: 1,
assert.Equal(serializer.NewExperimentResponse(&model.Experiment{
ID: 2,
Name: "new",
Description: "test",
}, 0))
}

func testDB() *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
err = dbutil.Bootstrap(dbutil.DB{
DB: db,
Driver: dbutil.Sqlite,
})
if err != nil {
panic(err)
}
return db
}, 0), res)
}
49 changes: 49 additions & 0 deletions server/handler/file_pairs.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package handler

import (
"fmt"
"io"
"io/ioutil"
"net/http"
"os"
"strings"

"github.com/pressly/lg"
"github.com/src-d/code-annotation/server/dbutil"
"github.com/src-d/code-annotation/server/repository"
"github.com/src-d/code-annotation/server/serializer"
)
Expand Down Expand Up @@ -53,3 +59,46 @@ func GetFilePairs(repo *repository.FilePairs) RequestProcessFunc {
return serializer.NewListFilePairsResponse(filePairs), nil
}
}

// UploadFilePairs returns a function that imports file pair from import db file to the experiment
func UploadFilePairs(db *dbutil.DB) RequestProcessFunc {
return func(r *http.Request) (*serializer.Response, error) {
experimentID, err := urlParamInt(r, "experimentId")
if err != nil {
return nil, err
}

file, _, err := r.FormFile("input_db")
if err != nil {
return nil, serializer.NewHTTPError(http.StatusBadRequest, err.Error())
}
defer file.Close()

// need to save on disk to be able to open using sql.Open
tmpfile, err := ioutil.TempFile("", "input_db")
if err != nil {
return nil, fmt.Errorf("can't open tmp file for db %s", err)
}
defer os.Remove(tmpfile.Name())

if _, err := io.Copy(tmpfile, file); err != nil {
return nil, fmt.Errorf("can't copy content to tmp db file %s", err)
}
if err := tmpfile.Close(); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

\n missing

return nil, fmt.Errorf("can't close tmp db file %s", err)
}

inputDB, err := dbutil.OpenSQLite(tmpfile.Name(), false)
if err != nil {
return nil, fmt.Errorf("can't open input db %s", err)
}

success, failures, err := dbutil.ImportFiles(
inputDB, *db, dbutil.Options{Logger: lg.RequestLog(r)}, experimentID)
if err != nil {
return nil, err
}

return serializer.NewFilePairsUploadResponse(success, failures), nil
}
}
91 changes: 91 additions & 0 deletions server/handler/file_pairs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package handler_test

import (
"bytes"
"database/sql"
"io/ioutil"
"mime/multipart"
"net/http"
"os"
"path/filepath"
"testing"

"github.com/src-d/code-annotation/server/handler"
"github.com/src-d/code-annotation/server/serializer"
"github.com/stretchr/testify/assert"
)

func TestUploadFilePairs(t *testing.T) {
assert := assert.New(t)

// create import db
dbPath := filepath.Join(os.TempDir(), "cat_test_upload_file_pairs_import_db.db")
defer os.Remove(dbPath)

importDB, err := sql.Open("sqlite3", dbPath)
defer importDB.Close()
if err != nil {
t.Fatalf("can't create import db for test %s", err)
}
sqlQuery, err := ioutil.ReadFile("testdata/import_db.sql")
if err != nil {
t.Fatalf("can't read sql fixture %s", err)
}
if _, err := importDB.Exec(string(sqlQuery)); err != nil {
t.Fatalf("can't apply sql fixture %s", err)
}
importDB.Close()

// create db & handler
db := testDB()
handler := handler.UploadFilePairs(db)

req, err := newFileUploadRequest("/experiments/1/file-pairs", nil, "input_db", dbPath)
req = chiRequest(req, map[string]string{"experimentId": "1"})
if err != nil {
t.Fatalf("can't create file upload request %s", err)
}
res, err := handler(req)
assert.Nil(err)

assert.Equal(serializer.NewFilePairsUploadResponse(2, 0), res)
}

func newFileUploadRequest(uri string, params map[string]string, paramName, path string) (*http.Request, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
fileContents, err := ioutil.ReadAll(file)
if err != nil {
return nil, err
}
fi, err := file.Stat()
if err != nil {
return nil, err
}
file.Close()

body := new(bytes.Buffer)
writer := multipart.NewWriter(body)
part, err := writer.CreateFormFile(paramName, fi.Name())
if err != nil {
return nil, err
}
part.Write(fileContents)

for key, val := range params {
_ = writer.WriteField(key, val)
}
err = writer.Close()
if err != nil {
return nil, err
}

r, err := http.NewRequest("POST", uri, body)
if err != nil {
return nil, err
}
r.Header.Set("Content-Type", writer.FormDataContentType())
return r, nil
}
43 changes: 43 additions & 0 deletions server/handler/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package handler_test

import (
"context"
"database/sql"
"net/http"

"github.com/go-chi/chi"
"github.com/pressly/lg"
"github.com/src-d/code-annotation/server/dbutil"
)

func testDB() *dbutil.DB {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
dbWrapper := dbutil.DB{
DB: db,
Driver: dbutil.Sqlite,
}
if err = dbutil.Bootstrap(dbWrapper); err != nil {
panic(err)
}
if dbutil.Initialize(dbWrapper); err != nil {
panic(err)
}
return &dbWrapper
}

func chiRequest(req *http.Request, params map[string]string) *http.Request {
ctx := lg.WithLoggerContext(req.Context(), lg.DefaultLogger)

c := chi.NewRouteContext()
if params != nil {
for name, value := range params {
c.URLParams.Add(name, value)
ctx = context.WithValue(ctx, chi.RouteCtxKey, c)
}
}

return req.WithContext(ctx)
}
Loading