Skip to content
45 changes: 45 additions & 0 deletions pkg/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,21 @@ var (
// defaultSessionTimeout default: 24 hour
defaultSessionTimeout = 24 * time.Hour

// defaultNetReadTimeout default: 0 (no timeout for normal operations)
defaultNetReadTimeout = time.Duration(0)

// defaultNetWriteTimeout default: 0 (no timeout for normal operations)
defaultNetWriteTimeout = time.Duration(0)

// defaultLoadLocalReadTimeout default: 60 seconds
// Timeout for reading data from client during LOAD DATA LOCAL operations
// Used to detect F5/LoadBalancer idle timeout disconnections
defaultLoadLocalReadTimeout = 60 * time.Second

// defaultLoadLocalWriteTimeout default: 60 seconds
// Timeout for writing data to client during LOAD DATA LOCAL operations
defaultLoadLocalWriteTimeout = 60 * time.Second

// defaultOBShowStatsInterval default: 1min
defaultOBShowStatsInterval = time.Minute

Expand Down Expand Up @@ -266,6 +281,20 @@ type FrontendParameters struct {
//timeout of the session. the default is 10minutes
SessionTimeout toml.Duration `toml:"sessionTimeout"`

// NetReadTimeout is the timeout for reading from the network connection. Default is 0 (no timeout).
NetReadTimeout toml.Duration `toml:"netReadTimeout"`

// NetWriteTimeout is the timeout for writing to the network connection. Default is 60 seconds.
NetWriteTimeout toml.Duration `toml:"netWriteTimeout"`

// LoadLocalReadTimeout is the timeout for reading data from client during LOAD DATA LOCAL operations.
// Used to detect F5/LoadBalancer idle timeout disconnections. Default is 60 seconds.
LoadLocalReadTimeout toml.Duration `toml:"loadLocalReadTimeout"`

// LoadLocalWriteTimeout is the timeout for writing data to client during LOAD DATA LOCAL operations.
// Default is 60 seconds.
LoadLocalWriteTimeout toml.Duration `toml:"loadLocalWriteTimeout"`

// MaxMessageSize max size for read messages from dn. Default is 10M
MaxMessageSize uint64 `toml:"max-message-size"`

Expand Down Expand Up @@ -406,6 +435,22 @@ func (fp *FrontendParameters) SetDefaultValues() {
fp.SessionTimeout.Duration = defaultSessionTimeout
}

if fp.NetReadTimeout.Duration == 0 {
fp.NetReadTimeout.Duration = defaultNetReadTimeout
}

if fp.NetWriteTimeout.Duration == 0 {
fp.NetWriteTimeout.Duration = defaultNetWriteTimeout
}

if fp.LoadLocalReadTimeout.Duration == 0 {
fp.LoadLocalReadTimeout.Duration = defaultLoadLocalReadTimeout
}

if fp.LoadLocalWriteTimeout.Duration == 0 {
fp.LoadLocalWriteTimeout.Duration = defaultLoadLocalWriteTimeout
}

if fp.SaveQueryResult == "" {
fp.SaveQueryResult = "off"
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/frontend/internal_executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,11 @@ func (ip *internalProtocol) SetUserName(username string) {

func (ip *internalProtocol) Close() {}

// Disconnect for internal protocol does nothing since there's no real network connection
func (ip *internalProtocol) Disconnect() error {
return nil
}

// sendRows
// case 1: used in WriteResponse and WriteResultSetRow, which are 'copy' op
// case 2: used in Write, which is 'append' op. (deprecated)
Expand Down
84 changes: 68 additions & 16 deletions pkg/frontend/mysql_buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,25 +185,35 @@ type Conn struct {
packetInBuf int
allowedPacketSize int
timeout time.Duration
allocator *BufferAllocator
ses atomic.Pointer[holder[*Session]]
closeFunc sync.Once
service string
readTimeout time.Duration
writeTimeout time.Duration
// loadLocalReadTimeout is the timeout for reading data from client during LOAD DATA LOCAL operations
loadLocalReadTimeout time.Duration
// loadLocalWriteTimeout is the timeout for writing data to client during LOAD DATA LOCAL operations
loadLocalWriteTimeout time.Duration
allocator *BufferAllocator
ses atomic.Pointer[holder[*Session]]
closeFunc sync.Once
service string
}

// NewIOSession create a new io session
func NewIOSession(conn net.Conn, pu *config.ParameterUnit, service string) (_ *Conn, err error) {
c := &Conn{
conn: conn,
localAddr: conn.LocalAddr().String(),
remoteAddr: conn.RemoteAddr().String(),
fixBuf: MemBlock{},
dynamicWrBuf: list.New(),
allocator: &BufferAllocator{allocator: getSessionAlloc(service)},
timeout: pu.SV.SessionTimeout.Duration,
maxBytesToFlush: int(pu.SV.MaxBytesInOutbufToFlush * 1024),
allowedPacketSize: int(MaxPayloadSize),
service: service,
conn: conn,
localAddr: conn.LocalAddr().String(),
remoteAddr: conn.RemoteAddr().String(),
fixBuf: MemBlock{},
dynamicWrBuf: list.New(),
allocator: &BufferAllocator{allocator: getSessionAlloc(service)},
timeout: pu.SV.SessionTimeout.Duration,
readTimeout: pu.SV.NetReadTimeout.Duration,
writeTimeout: pu.SV.NetWriteTimeout.Duration,
loadLocalReadTimeout: pu.SV.LoadLocalReadTimeout.Duration,
loadLocalWriteTimeout: pu.SV.LoadLocalWriteTimeout.Duration,
maxBytesToFlush: int(pu.SV.MaxBytesInOutbufToFlush * 1024),
allowedPacketSize: int(MaxPayloadSize),
service: service,
}

defer func() {
Expand Down Expand Up @@ -312,7 +322,7 @@ func (c *Conn) ReadLoadLocalPacket() (_ []byte, err error) {
c.FreeLoadLocal()
}
}()
err = c.ReadNBytesIntoBuf(c.header[:], HeaderLengthOfTheProtocol)
err = c.ReadNBytesIntoBufWithTimeout(c.header[:], HeaderLengthOfTheProtocol, c.loadLocalReadTimeout)
if err != nil {
return
}
Expand All @@ -333,13 +343,29 @@ func (c *Conn) ReadLoadLocalPacket() (_ []byte, err error) {
}
}

err = c.ReadNBytesIntoBuf(c.loadLocalBuf.data, packetLength)
err = c.ReadNBytesIntoBufWithTimeout(c.loadLocalBuf.data, packetLength, c.loadLocalReadTimeout)
if err != nil {
return
}
return c.loadLocalBuf.data[:packetLength], err
}

// ReadNBytesIntoBufWithTimeout reads specified bytes from the network with a timeout.
// This is used for LOAD DATA LOCAL operations.
func (c *Conn) ReadNBytesIntoBufWithTimeout(buf []byte, n int, timeout time.Duration) error {
var err error
var read int
var readLength int
for readLength < n {
read, err = c.ReadFromConnWithTimeout(buf[readLength:n], timeout)
if err != nil {
return err
}
readLength += read
}
return err
}

func (c *Conn) FreeLoadLocal() {
c.loadLocalBuf.freeBuffUnsafe(c.allocator)
}
Expand Down Expand Up @@ -569,6 +595,23 @@ func (c *Conn) ReadFromConn(buf []byte) (int, error) {
return c.conn.Read(buf)
}

// ReadFromConnWithTimeout reads from the network with a specific timeout.
// This is used for LOAD DATA LOCAL operations where we need timeout detection.
func (c *Conn) ReadFromConnWithTimeout(buf []byte, timeout time.Duration) (int, error) {
var err error
if timeout > 0 {
err = c.conn.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
return 0, err
}
// Clear the deadline after reading
defer func() {
_ = c.conn.SetReadDeadline(time.Time{})
}()
}
return c.conn.Read(buf)
}

// Append Add bytes to buffer
func (c *Conn) Append(elems ...byte) (err error) {
defer func() {
Expand Down Expand Up @@ -805,6 +848,15 @@ func (c *Conn) closeConn() error {
return err
}

// Disconnect closes the underlying network connection without full cleanup.
// This is used to forcefully disconnect the client (e.g., on timeout during LOAD DATA).
func (c *Conn) Disconnect() error {
if c.conn != nil {
return c.conn.Close()
}
return nil
}

// Reset does not release fix buffer but release dynamical buffer
// and load data local buffer
func (c *Conn) Reset() {
Expand Down
99 changes: 98 additions & 1 deletion pkg/frontend/mysql_buffer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ func TestConn_ReadLoadLocalPacketErr(t *testing.T) {
{
resetFunc()
tConn.mod = testConnModSetReadDeadlineReturnErr
conn.timeout = time.Second * 2
conn.loadLocalReadTimeout = time.Second * 2
_, _ = tConn.Write(makePacket(payload1, 1))
read, err = conn.ReadLoadLocalPacket()
assert.NotNil(t, err)
Expand Down Expand Up @@ -1804,3 +1804,100 @@ func Test_BeginPacket(t *testing.T) {
_ = conn.Close()
assert.True(t, leakAlloc.CheckBalance())
}

// timeoutConn is a test connection that simulates read timeout
type timeoutConn struct {
testConn
readTimeout time.Duration
closed bool
}

func (tc *timeoutConn) Read(b []byte) (n int, err error) {
if tc.readTimeout > 0 {
// Simulate timeout by returning a timeout error
return 0, &timeoutError{msg: "read timeout"}
}
return tc.testConn.Read(b)
}

func (tc *timeoutConn) Close() error {
tc.closed = true
return nil
}

func (tc *timeoutConn) SetReadDeadline(t time.Time) error {
return nil
}

// timeoutError implements net.Error interface for timeout
type timeoutError struct {
msg string
}

func (e *timeoutError) Error() string { return e.msg }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }

func TestConn_ReadFromConn_Timeout(t *testing.T) {
leakAlloc := NewLeakCheckAllocator()
setSessionAlloc("", leakAlloc)

// Create a timeout connection
tConn := &timeoutConn{
readTimeout: 1 * time.Second,
}

sv, err := getSystemVariables("test/system_vars_config.toml")
assert.Nil(t, err)
sv.NetReadTimeout.Duration = 1 * time.Second
pu := config.NewParameterUnit(sv, nil, nil, nil)

conn, err := NewIOSession(tConn, pu, "")
assert.Nil(t, err)
assert.NotNil(t, conn)

// Test ReadFromConn with timeout
buf := make([]byte, 100)
n, err := conn.ReadFromConn(buf)

// Should return timeout error
assert.NotNil(t, err)
assert.Equal(t, 0, n)

// Check if it's a timeout error
netErr, ok := err.(net.Error)
assert.True(t, ok)
assert.True(t, netErr.Timeout())

_ = conn.Close()
}

func TestConn_ReadFromConn_NoTimeout(t *testing.T) {
leakAlloc := NewLeakCheckAllocator()
setSessionAlloc("", leakAlloc)

// Create a normal connection with data
tConn := &testConn{}
testData := []byte("hello world")
tConn.data = testData

sv, err := getSystemVariables("test/system_vars_config.toml")
assert.Nil(t, err)
sv.NetReadTimeout.Duration = 0 // No timeout
pu := config.NewParameterUnit(sv, nil, nil, nil)

conn, err := NewIOSession(tConn, pu, "")
assert.Nil(t, err)
assert.NotNil(t, conn)

// Test ReadFromConn without timeout
buf := make([]byte, 100)
n, err := conn.ReadFromConn(buf)

// Should succeed
assert.Nil(t, err)
assert.Equal(t, len(testData), n)
assert.Equal(t, testData, buf[:n])

_ = conn.Close()
}
Loading
Loading