Skip to content

Commit

Permalink
Add Conn.CopyFrom("COPY table1 FROM STDIN", io.Reader); Issue #14
Browse files Browse the repository at this point in the history
  • Loading branch information
temoto committed Jan 22, 2013
1 parent dace7ab commit c14ef12
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 1 deletion.
63 changes: 63 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ const (
StatusDisconnected ConnStatus = iota
StatusReady
StatusProcessingQuery
StatusCopy
)

func (s ConnStatus) String() string {
Expand All @@ -73,6 +74,9 @@ func (s ConnStatus) String() string {

case StatusProcessingQuery:
return "Processing Query"

case StatusCopy:
return "Bulk Copy"
}

return "Unknown"
Expand Down Expand Up @@ -348,6 +352,65 @@ func (conn *Conn) Close() (err error) {
})
}

func (conn *Conn) copyFrom(command string, r io.Reader) int64 {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.execute"))
}

conn.writeQuery(command)
conn.readBackendMessages(nil)
if stateCode := conn.state.code(); stateCode != StatusCopy {
panic("wrong state, expected: StatusCopy, have: " + stateCode.String())
return 0
}

// FIXME: magic number; wild guess without any reason.
const CopyBufferSize = 32 << 10
buf := make([]byte, CopyBufferSize)
var nr int
var err error
for {
nr, err = r.Read(buf)
if err != nil && err != io.EOF {
message := err.Error()
conn.writeFrontendMessageCode(_CopyFail)
conn.writeInt32(int32(5 + len(message)))
conn.writeString0(message)
panic(err)
}
if nr > 0 {
conn.writeFrontendMessageCode(_CopyData_FE)
conn.writeInt32(int32(4 + nr))
conn.write(buf[:nr])
conn.flush()
}
// TODO: peek backend message. Maybe there was error in data
// and we can stop sending early.
if err == io.EOF {
break
}
}
conn.writeFrontendMessageCode(_CopyDone_FE)
conn.writeInt32(4)
conn.flush()

rs := newResultSet(conn)
conn.readBackendMessages(rs)
rs.close()

return rs.rowsAffected
}

// CopyIn sends a `COPY table FROM STDIN` SQL command to the server and
// returns the number of rows affected.
func (conn *Conn) CopyFrom(command string, r io.Reader) (rowsAffected int64, err error) {
err = conn.withRecover("*Conn.CopyIn", func() {
rowsAffected = conn.copyFrom(command, r)
})

return
}

func getpgpassfilename() string {
var env string
env = os.Getenv("PGPASSFILE")
Expand Down
26 changes: 26 additions & 0 deletions conn_read.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,28 @@ func (conn *Conn) readCommandComplete(rs *ResultSet) {
}
}

// As of PostgreSQL 9.2 (protocol 3.0), CopyOutResponse and CopyBothResponse
// are exactly the same.
func (conn *Conn) readCopyInResponse() {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Conn.readCopyInResponse"))
}

// Just eat message length.
conn.readInt32()

// Just eat overall COPY format. 0 - textual, 1 - binary.
conn.readByte()

numColumns := conn.readInt16()
for i := int16(0); i < numColumns; i++ {
// Just eat column formats.
conn.readInt16()
}

conn.state = copyState{}
}

func (conn *Conn) readDataRow(rs *ResultSet) {
// Just eat message length.
conn.readInt32()
Expand Down Expand Up @@ -405,6 +427,10 @@ func (conn *Conn) readBackendMessages(rs *ResultSet) {
conn.readCommandComplete(rs)
return

case _CopyInResponse:
conn.readCopyInResponse()
return

case _DataRow:
rs.readRow()
return
Expand Down
30 changes: 29 additions & 1 deletion pgsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package pgsql

import (
"bytes"
"errors"
"fmt"
"math"
Expand Down Expand Up @@ -862,7 +863,7 @@ func Test_Query_Exception(t *testing.T) {
IF num != 1 THEN
RAISE EXCEPTION 'FAIL!';
END IF;
RETURN 1;
END;
$$ LANGUAGE plpgsql;
Expand Down Expand Up @@ -1031,3 +1032,30 @@ func Test_Issue2_Uint64_OutOfRange(t *testing.T) {
}
})
}

func Test_Issue14_CopyFrom(t *testing.T) {
const data = "1\ts1\t\\N\ttrue\t2\n"
dataBuf := bytes.NewBufferString(data)
withConnLog(t, LogNothing, func(conn *Conn) {
if _, err := conn.Execute("TRUNCATE table1;"); err != nil {
t.Error("failed to truncate table1:", err)
return
}

if n, err := conn.CopyFrom("COPY table1 FROM STDIN;", dataBuf); err != nil && n != 1 {
t.Error("COPY failed. err:", err, "n:", n)
}

var b1, b2, b3, b4, b5 bool
if _, err := conn.Scan("SELECT id = 1, strreq = 's1', stropt IS NULL, blnreq, i32req = 2 FROM table1;",
&b1, &b2, &b3, &b4, &b5); err != nil {
t.Error("failed to SELECT table1:", err)
return
} else {
if !(b1 && b2 && b3 && b4 && b5) {
t.Error("some columns have incorrect data:", b1, b2, b3, b4, b5)
return
}
}
})
}
10 changes: 10 additions & 0 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,16 @@ func (abstractState) query(conn *Conn, rs *ResultSet, sql string) {
panic(invalidOpForStateMsg)
}

// copyState is the state that is active when the connection is used
// to exchange CopyData messages for bulk import/export.
type copyState struct {
abstractState
}

func (copyState) code() ConnStatus {
return StatusCopy
}

// disconnectedState is the initial state before a connection is established.
type disconnectedState struct {
abstractState
Expand Down

0 comments on commit c14ef12

Please sign in to comment.