Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 50 additions & 51 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"log/slog"
"slices"
"sync"
"time"

"github.com/cooldogedev/spectrum/protocol"
spectrumpacket "github.com/cooldogedev/spectrum/server/packet"
Expand Down Expand Up @@ -64,6 +63,8 @@ type Conn struct {
deferredPackets []any
expectedIds []uint32

onConnect func(err error)

connected chan struct{}
spawned chan struct{}
once sync.Once
Expand Down Expand Up @@ -96,7 +97,7 @@ func NewConn(conn io.ReadWriteCloser, client *minecraft.Conn, logger *slog.Logge
connected: make(chan struct{}),
spawned: make(chan struct{}),
}
c.ctx, c.cancelFunc = context.WithCancelCause(client.Context())
c.ctx, c.cancelFunc = context.WithCancelCause(context.Background())
c.expect(spectrumpacket.IDConnectionResponse)
return c
}
Expand Down Expand Up @@ -162,26 +163,45 @@ func (c *Conn) Write(p []byte) error {
return c.writer.Write(snappy.Encode(nil, p))
}

// Connect initiates the connection sequence with a default timeout of 1 minute.
func (c *Conn) Connect() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
return c.ConnectContext(ctx)
}
// DoConnect sends a ConnectionRequest packet to initiate the connection sequence.
func (c *Conn) DoConnect() error {
select {
case <-c.ctx.Done():
return context.Cause(c.ctx)
default:
}

// ConnectTimeout initiates the connection sequence with the specified timeout duration.
func (c *Conn) ConnectTimeout(duration time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
return c.ConnectContext(ctx)
}
clientData, err := json.Marshal(c.client.ClientData())
if err != nil {
return err
}

// ConnectContext initiates the connection sequence using the provided context for cancellation.
func (c *Conn) ConnectContext(ctx context.Context) error {
if err := c.sendConnectionRequest(); err != nil {
identityData, err := json.Marshal(c.client.IdentityData())
if err != nil {
return err
}

err = c.WritePacket(&spectrumpacket.ConnectionRequest{
Addr: c.client.RemoteAddr().String(),
ProtocolID: c.protocol.ID(),
ClientData: clientData,
IdentityData: identityData,
Cache: c.cache,
})
if err != nil {
return err
}
c.logger.Debug("sent connection_request, expecting connection_response")
return nil
}

// OnConnect invokes the provided function once the connection sequence is complete or has failed.
func (c *Conn) OnConnect(fn func(error)) {
c.onConnect = fn
}

// WaitConnect blocks until the connection sequence has completed or the provided context is canceled.
func (c *Conn) WaitConnect(ctx context.Context) error {
select {
case <-c.ctx.Done():
return context.Cause(c.ctx)
Expand All @@ -204,14 +224,6 @@ func (c *Conn) DoSpawn() error {
return c.WritePacket(&packet.SetLocalPlayerAsInitialised{EntityRuntimeID: c.runtimeID})
}

// Conn returns the underlying connection.
// Direct access to the underlying connection through this method is
// strongly discouraged due to the potential for unpredictable behavior.
// Use this method only when absolutely necessary.
func (c *Conn) Conn() io.ReadWriteCloser {
return c.conn
}

// GameData returns the game data set for the connection by the StartGame packet.
func (c *Conn) GameData() minecraft.GameData {
return c.gameData
Expand All @@ -237,6 +249,16 @@ func (c *Conn) Close() error {
// CloseWithError closes the underlying connection.
func (c *Conn) CloseWithError(err error) {
c.once.Do(func() {
var connected bool
select {
case <-c.connected:
connected = true
default:
}

if !connected && c.onConnect != nil {
c.onConnect(err)
}
c.cancelFunc(err)
_ = c.conn.Close()
})
Expand Down Expand Up @@ -296,32 +318,6 @@ func (c *Conn) expect(ids ...uint32) {
c.expectedIds = ids
}

// sendConnectionRequest initiates the connection sequence by sending a ConnectionRequest packet to the underlying connection.
func (c *Conn) sendConnectionRequest() error {
clientData, err := json.Marshal(c.client.ClientData())
if err != nil {
return err
}

identityData, err := json.Marshal(c.client.IdentityData())
if err != nil {
return err
}

err = c.WritePacket(&spectrumpacket.ConnectionRequest{
Addr: c.client.RemoteAddr().String(),
ProtocolID: c.protocol.ID(),
ClientData: clientData,
IdentityData: identityData,
Cache: c.cache,
})
if err != nil {
return err
}
c.logger.Debug("sent connection_request, expecting connection_response")
return nil
}

// handlePacket handles an expected packet that was received before the connection sequence finalization.
func (c *Conn) handlePacket(p packet.Packet) (err error) {
var pks []packet.Packet
Expand Down Expand Up @@ -441,6 +437,9 @@ func (c *Conn) handleChunkRadiusUpdated(pk *packet.ChunkRadiusUpdated) error {
func (c *Conn) handlePlayStatus(pk *packet.PlayStatus) error {
c.deferPacket(pk)
close(c.connected)
if c.onConnect != nil {
c.onConnect(nil)
}
c.logger.Debug("received play_status, finalizing connection sequence")
return nil
}
13 changes: 4 additions & 9 deletions session/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,13 @@ loop:
}

server := s.Server()
select {
case <-server.Context().Done():
if !s.fallbackInProcess.Load() {
go s.fallback()
}
continue loop
default:
}

pk, err := server.ReadPacket()
if err != nil {
server.CloseWithError(fmt.Errorf("failed to read packet from server: %w", err))
if err := s.fallback(); err != nil {
s.CloseWithError(fmt.Errorf("fallback failed: %w", err))
break loop
}
continue loop
}

Expand Down
99 changes: 47 additions & 52 deletions session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,10 @@ type Session struct {
processor Processor
processorMu sync.RWMutex

cache atomic.Value
latency atomic.Int64
transferring atomic.Bool
fallbackInProcess atomic.Bool
once sync.Once
cache atomic.Value
latency atomic.Int64
inFallback atomic.Bool
once sync.Once
}

// NewSession creates a new Session instance using the provided minecraft.Conn.
Expand Down Expand Up @@ -103,14 +102,16 @@ func (s *Session) LoginContext(ctx context.Context) (err error) {
return err
}

s.serverMu.Lock()
s.serverAddr = serverAddr
s.serverConn = conn
s.serverMu.Unlock()
go handleServer(s)
go handleClient(s)
go handleLatency(s, s.opts.LatencyInterval)
if err := conn.ConnectContext(ctx); err != nil {
if err := conn.DoConnect(); err != nil {
s.logger.Debug("connection sequence failed", "err", err)
return err
}

if err := conn.WaitConnect(s.ctx); err != nil {
conn.CloseWithError(fmt.Errorf("connection sequence failed: %w", err))
s.logger.Debug("connection sequence failed", "err", err)
return err
}
Expand Down Expand Up @@ -151,49 +152,45 @@ func (s *Session) TransferTimeout(addr string, duration time.Duration) (err erro
// occurs at a time, returning an error if another transfer is already in progress.
// The process is performed using the provided context for cancellation.
func (s *Session) TransferContext(ctx context.Context, addr string) (err error) {
if !s.transferring.CompareAndSwap(false, true) {
return errors.New("already transferring")
}

defer s.transferring.Store(false)
s.serverMu.RLock()
origin := s.serverAddr
s.serverMu.RUnlock()
processorCtx := NewContext()
s.Processor().ProcessPreTransfer(processorCtx, &s.serverAddr, &addr)
s.Processor().ProcessPreTransfer(processorCtx, &origin, &addr)
if processorCtx.Cancelled() {
return errors.New("processor failed")
}

s.serverMu.RLock()
origin := s.serverAddr
s.serverMu.RUnlock()
s.sendMetadata(true)
conn, err := s.dial(ctx, addr)
defer func() {
if err != nil {
s.sendMetadata(false)
s.Processor().ProcessTransferFailure(NewContext(), &s.serverAddr, &addr)
}
}()
if err != nil {
s.logger.Debug("dialer failed", "err", err)
return err
s.Processor().ProcessTransferFailure(NewContext(), &origin, &addr)
return fmt.Errorf("dialer failed: %w", err)
}

if err := conn.ConnectContext(ctx); err != nil {
conn.CloseWithError(fmt.Errorf("connection sequence failed: %w", err))
s.logger.Debug("connection sequence failed", "err", err)
return err
if err := conn.DoConnect(); err != nil {
s.Processor().ProcessTransferFailure(NewContext(), &origin, &addr)
return fmt.Errorf("connection sequence failed failed: %w", err)
}

gameData := conn.GameData()
s.animation.Play(s.client, gameData)
s.sendGameData(conn.GameData())
if err := conn.DoSpawn(); err != nil {
conn.CloseWithError(fmt.Errorf("spawn sequence failed: %w", err))
return err
}
s.animation.Clear(s.client, gameData)
s.Processor().ProcessPostTransfer(NewContext(), &origin, &addr)
s.logger.Debug("transferred session", "origin", origin, "target", addr)
conn.OnConnect(func(err error) {
if err != nil {
s.Processor().ProcessTransferFailure(NewContext(), &origin, &addr)
return
}

gameData := conn.GameData()
s.animation.Play(s.client, gameData)
s.sendGameData(conn.GameData())
if err := conn.DoSpawn(); err != nil {
s.Processor().ProcessTransferFailure(NewContext(), &origin, &addr)
return
}
s.inFallback.Store(false)
s.animation.Clear(s.client, gameData)
s.Processor().ProcessPostTransfer(NewContext(), &origin, &addr)
s.logger.Debug("transferred session", "origin", origin, "target", addr)
})
return nil
}

Expand Down Expand Up @@ -283,7 +280,7 @@ func (s *Session) CloseWithError(err error) {
}
s.serverMu.RUnlock()
s.registry.RemoveSession(s.client.IdentityData().XUID)
s.logger.Info("closed session")
s.logger.Info("closed session", "err", err)
})
}

Expand Down Expand Up @@ -313,29 +310,27 @@ func (s *Session) dial(ctx context.Context, addr string) (*server.Conn, error) {
}

// fallback attempts to transfer the session to a fallback server provided by the discovery.
func (s *Session) fallback() {
func (s *Session) fallback() error {
select {
case <-s.ctx.Done():
return
return context.Cause(s.ctx)
default:
}

if !s.fallbackInProcess.CompareAndSwap(false, true) {
return
if !s.inFallback.CompareAndSwap(false, true) {
return errors.New("already in fallback")
}

defer s.fallbackInProcess.Store(false)
addr, err := s.discovery.DiscoverFallback(s.client)
if err != nil {
s.CloseWithError(err)
return
return fmt.Errorf("discovery failed: %w", err)
}

s.logger.Debug("transferring session to a fallback server", "addr", addr)
if err := s.Transfer(addr); err != nil {
s.CloseWithError(fmt.Errorf("failed to transfer to fallback server: %w", err))
return
return fmt.Errorf("transfer failed: %w", err)
}
s.logger.Info("transferred session to a fallback server", "addr", addr)
return nil
}

func (s *Session) sendMetadata(noAI bool) {
Expand Down