Skip to content

Commit

Permalink
Buffer only proxy header data
Browse files Browse the repository at this point in the history
Reverts and re-implements incorrect #116
which passed all connection data through the buffered reader.
  • Loading branch information
AlexanderYastrebov committed Oct 10, 2024
1 parent 3b89f41 commit 5795baa
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 11 deletions.
1 change: 1 addition & 0 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ type Header struct {
SourceAddr net.Addr
DestinationAddr net.Addr
rawTLVs []byte
length int
}

// HeaderProxyFromAddrs creates a new PROXY header from a source and a
Expand Down
31 changes: 20 additions & 11 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxyproto

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -51,7 +52,7 @@ type Conn struct {
once sync.Once
readErr error
conn net.Conn
bufReader *bufio.Reader
reader io.Reader
header *Header
ProxyHeaderPolicy Policy
Validate Validator
Expand Down Expand Up @@ -150,14 +151,8 @@ func (p *Listener) Addr() net.Addr {
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256

pConn := &Conn{
bufReader: bufio.NewReaderSize(conn, bufSize),
conn: conn,
conn: conn,
}

for _, opt := range opts {
Expand All @@ -178,7 +173,7 @@ func (p *Conn) Read(b []byte) (int, error) {
return 0, p.readErr
}

return p.bufReader.Read(b)
return p.reader.Read(b)
}

// Write wraps original conn.Write
Expand Down Expand Up @@ -294,7 +289,21 @@ func (p *Conn) readHeader() error {
}
}

header, err := Read(p.bufReader)
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256

bb := bytes.NewBuffer(make([]byte, 0, bufSize))
tr := io.TeeReader(p.conn, bb)
br := bufio.NewReaderSize(tr, bufSize)

header, err := Read(br)

if err == nil {
_, err = io.CopyN(io.Discard, bb, int64(header.length))
}
p.reader = io.MultiReader(bb, p.conn)

// If the connection's readHeaderTimeout is more than 0, undo the change to the
// deadline that we made above. Because we retain the readDeadline as part of our
Expand Down Expand Up @@ -360,5 +369,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
if p.readErr != nil {
return 0, p.readErr
}
return p.bufReader.WriteTo(w)
return io.Copy(w, p.reader)
}
1 change: 1 addition & 0 deletions v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func parseVersion1(reader *bufio.Reader) (*Header, error) {
// Command doesn't exist in v1 but set it for other parts of this library
// to rely on it for determining connection details.
header := initVersion1()
header.length = len(buf)

// Transport protocol has been processed already.
header.TransportProtocol = transportProtocol
Expand Down
1 change: 1 addition & 0 deletions v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) {
if !header.validateLength(length) {
return nil, ErrInvalidLength
}
header.length = 16 + int(length)

// Return early if the length is zero, which means that
// there's no address information and TLVs present for UNSPEC.
Expand Down

0 comments on commit 5795baa

Please sign in to comment.