Skip to content

Commit

Permalink
Merge branch 'main' into peteski22/dont-error-listener-on-invalid-ups…
Browse files Browse the repository at this point in the history
…tream
  • Loading branch information
peteski22 committed Oct 8, 2024
2 parents 6a9dc80 + cd8a402 commit d032c9c
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 14 deletions.
11 changes: 7 additions & 4 deletions .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ on:
pull_request:

permissions:
# Required: allow read access to the content for analysis.
contents: read
# Optional: allow read access to pull request. Use with `only-new-issues` option.
# pull-requests: read
pull-requests: read
# Optional: allow write access to checks to allow the action to annotate code in the PR.
checks: write

jobs:
golangci:
Expand All @@ -21,10 +24,10 @@ jobs:
matrix:
go: ['1.19', '1.20']
steps:
- uses: actions/setup-go@v3
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Format
run: go fmt
Expand All @@ -33,7 +36,7 @@ jobs:
run: go vet

- name: lint
uses: golangci/golangci-lint-action@v3
uses: golangci/golangci-lint-action@v6
#with:
# Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version
#version: v1.29
Expand Down
6 changes: 3 additions & 3 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,11 @@ func (header *Header) EqualsTo(otherHeader *Header) bool {
if otherHeader == nil {
return false
}
// TLVs only exist for version 2
if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) {
if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol {
return false
}
if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol {
// TLVs only exist for version 2
if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) {
return false
}
// Return early for header with LOCAL command, which contains no address information
Expand Down
47 changes: 43 additions & 4 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ type Conn struct {
once sync.Once
readErr error
conn net.Conn
Validate Validator
bufReader *bufio.Reader
reader io.Reader
header *Header
ProxyHeaderPolicy Policy
Validate Validator
readHeaderTimeout time.Duration
}

Expand All @@ -71,6 +72,15 @@ func ValidateHeader(v Validator) func(*Conn) {
}
}

// SetReadHeaderTimeout sets the readHeaderTimeout for a connection when passed as option to NewConn()
func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
return func(c *Conn) {
if t >= 0 {
c.readHeaderTimeout = t
}
}
}

// Accept waits for and returns the next valid connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
for {
Expand Down Expand Up @@ -141,8 +151,15 @@ func (p *Listener) Addr() net.Addr {
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256
br := bufio.NewReaderSize(conn, bufSize)

pConn := &Conn{
bufReader: bufio.NewReader(conn),
bufReader: br,
reader: io.MultiReader(br, conn),
conn: conn,
}

Expand All @@ -164,7 +181,7 @@ func (p *Conn) Read(b []byte) (int, error) {
return 0, p.readErr
}

return p.bufReader.Read(b)
return p.reader.Read(b)
}

// Write wraps original conn.Write
Expand Down Expand Up @@ -346,5 +363,27 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
if p.readErr != nil {
return 0, p.readErr
}
return p.bufReader.WriteTo(w)

b := make([]byte, p.bufReader.Buffered())
if _, err := p.bufReader.Read(b); err != nil {
return 0, err // this should never as we read buffered data
}

var n int64
{
nn, err := w.Write(b)
n += int64(nn)
if err != nil {
return n, err
}
}
{
nn, err := io.Copy(w, p.conn)
n += nn
if err != nil {
return n, err
}
}

return n, nil
}
3 changes: 1 addition & 2 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"sync/atomic"
Expand Down Expand Up @@ -1542,7 +1541,7 @@ type testConn struct {

func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
c.readFromCalledWith = r
b, err := ioutil.ReadAll(r)
b, err := io.ReadAll(r)
return int64(len(b)), err
}

Expand Down
3 changes: 2 additions & 1 deletion v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package proxyproto
import (
"bufio"
"bytes"
iorand "crypto/rand"
"encoding/binary"
"math/rand"
"reflect"
Expand Down Expand Up @@ -51,7 +52,7 @@ var (
fixtureUnixV2 = append(lengthUnixBytes, fixtureUnixAddress...)
fixtureTLV = func() []byte {
tlv := make([]byte, 2+rand.Intn(1<<12)) // Not enough to overflow, at least size two
rand.Read(tlv)
_, _ = iorand.Read(tlv)
return tlv
}()
fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV)
Expand Down

0 comments on commit d032c9c

Please sign in to comment.