diff --git a/protocol_test.go b/protocol_test.go index 8c0834c..4221149 100644 --- a/protocol_test.go +++ b/protocol_test.go @@ -13,6 +13,7 @@ import ( "io" "net" "net/http" + "sync/atomic" "testing" "time" ) @@ -1280,15 +1281,16 @@ func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) { t.Fatalf("error creating listener: %v", err) } - times := 0 + var connectionCounter atomic.Int32 newLn := &Listener{ Listener: l, ConnPolicy: func(_ ConnPolicyOptions) (Policy, error) { // Return the invalid upstream error on the first call, the listener // should remain open and accepting. + times := connectionCounter.Load() if times == 0 { - times++ + connectionCounter.Store(times + 1) return REJECT, ErrInvalidUpstream } @@ -1296,23 +1298,41 @@ func Test_ConnectionHandlesInvalidUpstreamError(t *testing.T) { }, } - // Kick off the listener and capture any error. - var listenerErr error + // Kick off the listener and return any error via the chanel. + errCh := make(chan error) + defer close(errCh) go func(t *testing.T) { - _, listenerErr = newLn.Accept() + _, err := newLn.Accept() + errCh <- err }(t) // Make two calls to trigger the listener's accept, the first should experience // the ErrInvalidUpstream and keep the listener open, the second should experience // a different error which will cause the listener to close. _, _ = http.Get("http://localhost:8080") - if listenerErr != nil { - t.Fatalf("invalid upstream shouldn't return an error: %v", listenerErr) + // Wait a few seconds to ensure we didn't get anything back on our channel. + select { + case err := <-errCh: + if err != nil { + t.Fatalf("invalid upstream shouldn't return an error: %v", err) + } + case <-time.After(2 * time.Second): + // No error returned (as expected, we're still listening though) } _, _ = http.Get("http://localhost:8080") - if listenerErr == nil { - t.Fatalf("errors other than invalid upstream should error") + // Wait a few seconds before we fail the test as we should have received an + // error that was not invalid upstream. + select { + case err := <-errCh: + if err == nil { + t.Fatalf("errors other than invalid upstream should error") + } + if !errors.Is(ErrNoProxyProtocol, err) { + t.Fatalf("unexpected error type: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("timed out waiting for listener") } }