Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ var (
injectBanner = flag.String("inject-banner", "", "HTML snippet to inject in served webpages")
bannerHeight = flag.String("banner-height", "40px", "Height of the injected banner. This is ignored if no banner is set.")
shimWebsockets = flag.Bool("shim-websockets", false, "Whether or not to replace websockets with a shim")
websocketShimTimeout = flag.Duration("websocket-shim-timeout", 60*time.Minute, "Timeout for websocket shim connections to expire due to inactivity.")
shimPath = flag.String("shim-path", "", "Path under which to handle websocket shim requests")
healthCheckPath = flag.String("health-check-path", "/", "Path on backend host to issue health checks against. Defaults to the root.")
healthCheckFreq = flag.Int("health-check-interval-seconds", 0, "Wait time in seconds between health checks. Set to zero to disable health checks. Checks disabled by default.")
Expand Down Expand Up @@ -126,7 +127,8 @@ func hostProxy(ctx context.Context, host, shimPath string, injectShimCode, force
// restricted to a path prefix not equal to "/" will fail for websocket open requests. Passing in the
// sessionHandler twice allows the websocket handler to ensure that cookies are applied based on the
// correct, restored path.
h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler, metricHandler)
h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler,
metricHandler, *websocketShimTimeout)
if injectShimCode {
shimFunc, err := websockets.ShimBody(shimPath)
if err != nil {
Expand Down
45 changes: 32 additions & 13 deletions agent/websockets/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ limitations under the License.
package websockets

import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"sync"
"time"

"context"

"github.com/gorilla/websocket"
)

Expand Down Expand Up @@ -57,12 +57,14 @@ func (m *message) Serialize(version int) interface{} {
// and encapsulates it in an API that is a little more amenable to how the server side
// of our websocket shim is implemented.
type Connection struct {
done func() <-chan struct{}
cancel context.CancelFunc
clientMessages chan *message
serverMessages chan *message
protocolVersion int
subprotocol string
done func() <-chan struct{}
cancel context.CancelFunc
clientMessages chan *message
serverMessages chan *message
protocolVersion int
subprotocol string
mu sync.Mutex
lastActivityTime time.Time
}

// This map defines the set of headers that should be stripped from the WS request, as they
Expand All @@ -87,6 +89,20 @@ func stripWSHeader(header http.Header) http.Header {
return result
}

// updateActivity updates the last activity timestamp.
func (conn *Connection) updateActivity() {
conn.mu.Lock()
defer conn.mu.Unlock()
conn.lastActivityTime = time.Now()
}

// lastActivity returns the last activity timestamp.
func (conn *Connection) lastActivity() time.Time {
conn.mu.Lock()
defer conn.mu.Unlock()
return conn.lastActivityTime
}

// NewConnection creates and returns a new Connection.
func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) {
ctx, cancel := context.WithCancel(ctx)
Expand Down Expand Up @@ -162,11 +178,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er
}
}()
return &Connection{
done: ctx.Done,
cancel: cancel,
clientMessages: clientMessages,
serverMessages: serverMessages,
subprotocol: serverConn.Subprotocol(),
done: ctx.Done,
cancel: cancel,
clientMessages: clientMessages,
serverMessages: serverMessages,
subprotocol: serverConn.Subprotocol(),
lastActivityTime: time.Now(),
}, nil
}

Expand All @@ -184,6 +201,7 @@ func (conn *Connection) Close() {
//
// The returned error value is non-nill if the connection has been closed.
func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error {
conn.updateActivity()
var clientMessage *message
if textMsg, ok := msg.(string); ok {
clientMessage = &message{
Expand Down Expand Up @@ -244,6 +262,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) {
// The server messages channel has been closed.
return nil, fmt.Errorf("attempt to read a server message from a closed websocket connection")
}
conn.updateActivity()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also needs to be called at the beginning of this method (i.e. on line 258).

Otherwise, a connection that the client is continuously polling on, but where the server has not yet responded, will be closed as inactive.

msgs = append(msgs, serverMsg.Serialize(conn.protocolVersion))
for {
select {
Expand Down
39 changes: 32 additions & 7 deletions agent/websockets/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package websockets

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
Expand All @@ -31,8 +32,8 @@ import (
"sync"
"sync/atomic"
"text/template"
"time"

"context"
"github.com/google/inverting-proxy/agent/metrics"
)

Expand Down Expand Up @@ -320,9 +321,33 @@ func (c *connectionErrorHandler) ReportError(err error) {
}
}

func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler {
func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler {
var connections sync.Map
var sessionCount uint64

// Background goroutine to clean up inactive websocket shim connections.
go func() {
ticker := time.NewTicker(min(timeout, 30*time.Second))
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
connections.Range(func(key, value any) bool {
sessionID := key.(string)
conn := value.(*Connection)
if time.Since(conn.lastActivity()) > timeout {
log.Printf("Closing inactive websocket shim session %q after timeout", sessionID)
conn.Close()
connections.Delete(sessionID)
}
return true // Continue iteration
})
}
}
}()

mux := http.NewServeMux()
errorHandler := &connectionErrorHandler{}
openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -351,9 +376,9 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
}
}
resp := &sessionMessage{
ID: sessionID,
Message: targetURL.String(),
Version: conn.protocolVersion,
ID: sessionID,
Message: targetURL.String(),
Version: conn.protocolVersion,
Subprotocol: conn.Subprotocol(),
}
respBytes, err := json.Marshal(resp)
Expand Down Expand Up @@ -548,11 +573,11 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b
// openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original
// targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it
// is finished processing the request.
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) {
func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) {
mux := http.NewServeMux()
if shimPath != "" {
shimPath = path.Clean("/"+shimPath) + "/"
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler)
shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler, timeout)
mux.Handle(shimPath, shimServer)
}
mux.Handle("/", wrapped)
Expand Down
3 changes: 2 additions & 1 deletion agent/websockets/websockets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"strings"
"sync"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
Expand Down Expand Up @@ -239,7 +240,7 @@ func TestShimHandlers(t *testing.T) {
openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler {
return h
}
p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil)
p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil, 60*time.Second)
if err != nil {
t.Fatalf("Failure creating the websocket shim proxy: %+v", err)
}
Expand Down