Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@ module github.com/srahkmli/go-panic-proof

go 1.23.4

require google.golang.org/grpc v1.69.0
require (
go.uber.org/zap v1.27.0
google.golang.org/grpc v1.69.0
)

require (
go.uber.org/multierr v1.10.0 // indirect
golang.org/x/net v0.30.0 // indirect
golang.org/x/sys v0.26.0 // indirect
golang.org/x/text v0.19.0 // indirect
Expand Down
14 changes: 14 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
Expand All @@ -8,6 +10,10 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
go.opentelemetry.io/otel v1.31.0 h1:NsJcKPIW0D0H3NgzPDHmo0WW6SptzPdqg/L1zsIm2hY=
go.opentelemetry.io/otel v1.31.0/go.mod h1:O0C14Yl9FgkjqcCZAsE053C13OaddMYr/hz6clDkEJE=
go.opentelemetry.io/otel/metric v1.31.0 h1:FSErL0ATQAmYHUIzSezZibnyVlft1ybhy4ozRPcF2fE=
Expand All @@ -18,6 +24,12 @@ go.opentelemetry.io/otel/sdk/metric v1.31.0 h1:i9hxxLJF/9kkvfHppyLL55aW7iIJz4Jjx
go.opentelemetry.io/otel/sdk/metric v1.31.0/go.mod h1:CRInTMVvNhUKgSAMbKyTMxqOBC0zgyxzW55lZzX43Y8=
go.opentelemetry.io/otel/trace v1.31.0 h1:ffjsj1aRouKewfr85U2aGagJ46+MvodynlQ1HYdmJys=
go.opentelemetry.io/otel/trace v1.31.0/go.mod h1:TXZkRk7SM2ZQLtR6eoAWQFIHPvzQ06FJAsO1tJg480A=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
Expand All @@ -30,3 +42,5 @@ google.golang.org/grpc v1.69.0 h1:quSiOM1GJPmPH5XtU+BCoVXcDVJJAzNcoyfC2cCjGkI=
google.golang.org/grpc v1.69.0/go.mod h1:vyjdE6jLBI76dgpDojsFGNaHlxdjXN9ghpnd2o7JGZ4=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
37 changes: 27 additions & 10 deletions grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,30 @@ package panicrecovery

import (
"context"
"log"
"runtime/debug"
"time"

"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// Context key
const PanicDetailsKey = "panicDetails"

// RecoverInterceptor is a gRPC interceptor that recovers from panics in gRPC methods.
func RecoverInterceptor() grpc.UnaryServerInterceptor {
return func(
ctx context.Context, req interface{}, info *grpc.UnaryServerInfo,
handler grpc.UnaryHandler,
) (interface{}, error) {
) (resp any, err error) {
defer func() {
if err := recover(); err != nil {
// Log the panic error with stack trace.
log.Printf("Recovered from panic in gRPC method %s: %v\nStack Trace: %s", info.FullMethod, err, string(debug.Stack()))
if r := recover(); r != nil {
ctx, err = handlegRPCPanic(ctx, r, info.FullMethod)
}
}()
// Call the handler to execute the RPC method
// Call the handler to execute the RPC method and pass ctx to downstream
return handler(ctx, req)
}
}
Expand All @@ -30,14 +35,26 @@ func RecoverStreamInterceptor() grpc.StreamServerInterceptor {
return func(
srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo,
handler grpc.StreamHandler,
) error {
) (err error) {
defer func() {
if err := recover(); err != nil {
// Log the panic error with stack trace for stream-based methods.
log.Printf("Recovered from panic in gRPC streaming method %s: %v\nStack Trace: %s", info.FullMethod, err, string(debug.Stack()))
if r := recover(); r != nil {
_, err = handlegRPCPanic(stream.Context(), r, info.FullMethod)
}
}()
// Call the handler to execute the streaming RPC method
return handler(srv, stream)
}
}

func handlegRPCPanic(ctx context.Context, r any, method string) (context.Context, error) {
Logger.Error("Recovered from panic in gRPC",
zap.Any("error", r),
zap.String("method", method),
zap.String("stack_trace", string(debug.Stack())),
zap.Time("timestamp", time.Now()))

// Atach recovery info to ctx
ctx = context.WithValue(ctx, PanicDetailsKey, r)

return ctx, status.Errorf(codes.Internal, "panic occurred: %v", r)
}
106 changes: 106 additions & 0 deletions grpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package panicrecovery

import (
"context"
"testing"

"google.golang.org/grpc"
)

type mockHandler struct {
shouldPanic bool
}

type mockServerStream struct {
ctx context.Context
grpc.ServerStream
}

func (m *mockServerStream) Context() context.Context {
return m.ctx
}

func (m *mockHandler) Handle(ctx context.Context, req any) (any, error) {
if m.shouldPanic {
panic("test panic")
}

return "Hale Haji", nil
}

func TestRecoverInterceptor(t *testing.T) {
tests := []struct {
name string
shouldPanic bool
}{
{
name: "no panic",
shouldPanic: false,
},
{
name: "with panic",
shouldPanic: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
interceptor := RecoverInterceptor()
handler := &mockHandler{shouldPanic: tt.shouldPanic}

ctx := context.Background()
info := &grpc.UnaryServerInfo{FullMethod: "TestMethod"}

_, err := interceptor(ctx, "test request", info, handler.Handle)

if tt.shouldPanic && err == nil {
t.Error("Expected error from panic recovery, got nil")
}

if !tt.shouldPanic && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}

func TestRecoverStreamInterceptor(t *testing.T) {
tests := []struct {
name string
shouldPanic bool
}{
{
name: "no panic",
shouldPanic: false,
},
{
name: "with panic",
shouldPanic: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
interceptor := RecoverStreamInterceptor()
stream := &mockServerStream{ctx: context.Background()}
info := &grpc.StreamServerInfo{FullMethod: "TestStreamMethod"}

handler := func(srv any, stream grpc.ServerStream) error {
if tt.shouldPanic {
panic("test panic")
}
return nil
}

err := interceptor(nil, stream, info, handler)

if tt.shouldPanic && err == nil {
t.Error("Expected error from panic recovery, got nil")
}

if !tt.shouldPanic && err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
}
}
50 changes: 27 additions & 23 deletions http.go
Original file line number Diff line number Diff line change
@@ -1,43 +1,47 @@
package panicrecovery

import (
"log"
"net/http"
"runtime/debug"
"time"

"go.uber.org/zap"
)

type HTTPErrorHandler func(w http.ResponseWriter, r *http.Request, err interface{})

// HTTPRecover is an HTTP middleware that recovers from panics and logs the error.
func HTTPRecover(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
// Log the panic error with stack trace.
log.Printf("Recovered from panic: %v\nStack Trace: %s", err, string(debug.Stack()))

// Respond with 500 Internal Server Error.
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}()
defer handleHTTPPanic(w, r, nil)
// Call the next handler in the chain
next.ServeHTTP(w, r)
})
}

type HTTPErrorHandler func(w http.ResponseWriter, r *http.Request, err interface{})

// HTTPRecoverWithHandler adds customizable error responses
func HTTPRecoverWithHandler(next http.Handler, errorHandler HTTPErrorHandler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
if errorHandler != nil {
errorHandler(w, r, err)
} else {
// Default error response
log.Printf("Recovered from panic: %v\nStack Trace: %s", err, string(debug.Stack()))
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
}()
defer handleHTTPPanic(w, r, errorHandler)
// Call the next handler in the chain
next.ServeHTTP(w, r)
})
}

func handleHTTPPanic(w http.ResponseWriter, r *http.Request, errorHandler HTTPErrorHandler) {
if err := recover(); err != nil {
Logger.Error("Recovered from panic",
zap.Any("error", err),
zap.String("stack_trace", string(debug.Stack())),
zap.Time("timestamp", time.Now()),
zap.String("path", r.URL.Path),
zap.String("method", r.Method))

if errorHandler != nil {
errorHandler(w, r, err)
return
}
// Default error response
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
}
}
89 changes: 89 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package panicrecovery

import (
"net/http"
"net/http/httptest"
"testing"
)

func TestHTTPRecover(t *testing.T) {
tests := []struct {
name string
handler http.HandlerFunc
expectedStatus int
}{
{
name: "no panic",
handler: func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
},
expectedStatus: http.StatusOK,
},
{
name: "with panic",
handler: func(w http.ResponseWriter, r *http.Request) {
panic("test panic")
},
expectedStatus: http.StatusInternalServerError,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := HTTPRecover(tt.handler)
server := httptest.NewServer(handler)
defer server.Close()

resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("Fialed to make request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != tt.expectedStatus {
t.Errorf("Expected status %d, but got %d", tt.expectedStatus, resp.StatusCode)
}
})
}
}

func TestHTTPRecoverWithHandler(t *testing.T) {
customStatus := http.StatusServiceUnavailable
customHandler := func(w http.ResponseWriter, r *http.Request, err any) {
w.WriteHeader(customStatus)
}

tests := []struct {
name string
handler http.HandlerFunc
errorHandler HTTPErrorHandler
expectedStatus int
}{
{
name: "custom error handler",
handler: func(w http.ResponseWriter, r *http.Request) {
panic("test panic")
},
errorHandler: customHandler,
expectedStatus: customStatus,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := HTTPRecoverWithHandler(tt.handler, tt.errorHandler)
server := httptest.NewServer(handler)
defer server.Close()

resp, err := http.Get(server.URL)
if err != nil {
t.Fatalf("Fialed to make request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != tt.expectedStatus {
t.Errorf("Expected status %d, but got %d", tt.expectedStatus, resp.StatusCode)
}
})
}
}
Loading