Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion server/dbutil/dbcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func genericVals(nColumns int) []interface{} {
// This should not be needed for readonly DBs, but his way the DB is ready in
// case new rows are added
func fixSequences(db DB, logger *log.Logger) {
if db.driver != postgres {
if db.Driver != Postgres {
return
}

Expand Down
28 changes: 14 additions & 14 deletions server/dbutil/dbutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@ import (
"github.com/pmezard/go-difflib/difflib"
)

type driver int
type Driver int

const (
none driver = iota
sqlite
postgres
None Driver = iota
Sqlite
Postgres
)

// DB groups a sql.DB and the driver used to initialize it
type DB struct {
*sql.DB
driver driver
Driver Driver
}

// SQLDB returns the *sql.DB
Expand Down Expand Up @@ -98,7 +98,7 @@ var (
// sugar a sqlite://path string is also accepted
func OpenSQLite(filepath string, checkExisting bool) (DB, error) {
if psReg.MatchString(filepath) {
return DB{nil, none}, fmt.Errorf(
return DB{nil, None}, fmt.Errorf(
"Invalid PostgreSQL connection string %q, a path to an SQLite file was expected",
filepath)
}
Expand All @@ -116,20 +116,20 @@ func Open(connection string, checkExisting bool) (DB, error) {
if conn := sqliteReg.FindStringSubmatch(connection); conn != nil {
if checkExisting {
if _, err := os.Stat(conn[1]); os.IsNotExist(err) {
return DB{nil, none}, fmt.Errorf("File %q does not exist", conn[1])
return DB{nil, None}, fmt.Errorf("File %q does not exist", conn[1])
}
}

db, err := sql.Open("sqlite3", conn[1])
return DB{db, sqlite}, err
return DB{db, Sqlite}, err
}

if psReg.MatchString(connection) {
db, err := sql.Open("postgres", connection)
return DB{db, postgres}, err
return DB{db, Postgres}, err
}

return DB{nil, none}, fmt.Errorf(`Connection string %q is not valid. It must be on of
return DB{nil, None}, fmt.Errorf(`Connection string %q is not valid. It must be on of
sqlite:///path/to/db.db
postgresql://[user[:password]@][netloc][:port][,...][/dbname]`, connection)
}
Expand All @@ -142,10 +142,10 @@ func Bootstrap(db DB) error {

var colType string

switch db.driver {
case sqlite:
switch db.Driver {
case Sqlite:
colType = sqliteIncrementType
case postgres:
case Postgres:
colType = posgresIncrementType
default:
return fmt.Errorf("Unknown driver type")
Expand All @@ -166,7 +166,7 @@ func Bootstrap(db DB) error {
// DB that is already initialized
func Initialize(db DB) error {
_, err := db.Exec(insertExperiments, defaultExperimentID)
if db.driver == postgres && err == nil {
if db.Driver == Postgres && err == nil {
db.Exec(alterExperimentsSequence)
}

Expand Down
36 changes: 36 additions & 0 deletions server/handler/experiments.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package handler

import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"

"github.com/src-d/code-annotation/server/model"
"github.com/src-d/code-annotation/server/repository"
"github.com/src-d/code-annotation/server/serializer"
"github.com/src-d/code-annotation/server/service"
Expand Down Expand Up @@ -85,3 +88,36 @@ func experimentProgress(repo *repository.Assignments, experimentID int, userID i

return 100.0 * float32(countComplete) / float32(countAll), nil
}

type createExperimentReq struct {
Name string `json:"name"`
Description string `json:"description"`
}

// CreateExperiment returns a function that saves the experiment as passed in the body request
func CreateExperiment(repo *repository.Experiments) RequestProcessFunc {
return func(r *http.Request) (*serializer.Response, error) {
var createExperimentReq createExperimentReq
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, serializer.NewHTTPError(http.StatusBadRequest, err.Error())
}

err = json.Unmarshal(body, &createExperimentReq)
if err != nil {
return nil, serializer.NewHTTPError(http.StatusBadRequest, err.Error())
}

experiment := &model.Experiment{
Name: createExperimentReq.Name,
Description: createExperimentReq.Description,
}

err = repo.Create(experiment)
if err != nil {
return nil, err
}

return serializer.NewExperimentResponse(experiment, 0), nil
}
}
49 changes: 49 additions & 0 deletions server/handler/experiments_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package handler_test

import (
"database/sql"
"net/http"
"strings"
"testing"

"github.com/src-d/code-annotation/server/dbutil"
"github.com/src-d/code-annotation/server/handler"
"github.com/src-d/code-annotation/server/model"
"github.com/src-d/code-annotation/server/repository"
"github.com/src-d/code-annotation/server/serializer"
"github.com/stretchr/testify/assert"
)

func TestCreateExperiment(t *testing.T) {
assert := assert.New(t)

db := testDB()
repo := repository.NewExperiments(db)
handler := handler.CreateExperiment(repo)

json := `{"name": "new", "description": "test"}`
req, _ := http.NewRequest("POST", "/experiments", strings.NewReader(json))
res, err := handler(req)
assert.Nil(err)

assert.Equal(res, serializer.NewExperimentResponse(&model.Experiment{
ID: 1,
Name: "new",
Description: "test",
}, 0))
}

func testDB() *sql.DB {
db, err := sql.Open("sqlite3", ":memory:")
if err != nil {
panic(err)
}
err = dbutil.Bootstrap(dbutil.DB{
DB: db,
Driver: dbutil.Sqlite,
})
if err != nil {
panic(err)
}
return db
}
17 changes: 17 additions & 0 deletions server/repository/experiments.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func (repo *Experiments) getWithQuery(queryRow scannable) (*model.Experiment, er

const selectExperimentsWhereIDSQL = `SELECT * FROM experiments WHERE id=$1`
const selectExperimentsSQL = `SELECT * FROM experiments`
const insertExperimentSQL = `INSERT INTO experiments (name, description) VALUES ($1, $2)`

// GetByID returns the Experiment with the given ID. If the Experiment does not
// exist, it returns nil, nil
Expand Down Expand Up @@ -69,3 +70,19 @@ func (repo *Experiments) GetAll() ([]*model.Experiment, error) {

return results, nil
}

// Create experiment model in database. On success the assigned ID is set
func (repo *Experiments) Create(m *model.Experiment) error {
r, err := repo.db.Exec(insertExperimentSQL, m.Name, m.Description)
if err != nil {
return err
}
newID, err := r.LastInsertId()
if err != nil {
return err
}

m.ID = int(newID)

return nil
}
2 changes: 2 additions & 0 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ func Router(
r.Get("/me", handler.APIHandlerFunc(handler.Me(userRepo)))

r.Get("/experiments", handler.APIHandlerFunc(handler.GetExperiments(experimentRepo, assignmentRepo)))
r.With(requesterACL.Middleware).
Post("/experiments", handler.APIHandlerFunc(handler.CreateExperiment(experimentRepo)))

r.Route("/experiments/{experimentId}", func(r chi.Router) {

Expand Down