Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Buffer only proxy header data #119

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 22 additions & 34 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package proxyproto

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -51,7 +52,6 @@ type Conn struct {
once sync.Once
readErr error
conn net.Conn
bufReader *bufio.Reader
reader io.Reader
header *Header
ProxyHeaderPolicy Policy
Expand Down Expand Up @@ -151,16 +151,8 @@ func (p *Listener) Addr() net.Addr {
// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256
br := bufio.NewReaderSize(conn, bufSize)

pConn := &Conn{
bufReader: br,
reader: io.MultiReader(br, conn),
conn: conn,
conn: conn,
}

for _, opt := range opts {
Expand Down Expand Up @@ -297,7 +289,25 @@ func (p *Conn) readHeader() error {
}
}

header, err := Read(p.bufReader)
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
// We use 256 bytes to be safe.
const bufSize = 256

bb := bytes.NewBuffer(make([]byte, 0, bufSize))
br := bufio.NewReaderSize(io.TeeReader(p.conn, bb), bufSize)

header, err := Read(br)

if err == nil {
_ = bb.Next(bb.Len() - br.Buffered()) // skip header
}

if bb.Len() == 0 {
p.reader = p.conn
} else {
p.reader = io.MultiReader(bb, p.conn)
}

// If the connection's readHeaderTimeout is more than 0, undo the change to the
// deadline that we made above. Because we retain the readDeadline as part of our
Expand Down Expand Up @@ -363,27 +373,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
if p.readErr != nil {
return 0, p.readErr
}

b := make([]byte, p.bufReader.Buffered())
if _, err := p.bufReader.Read(b); err != nil {
return 0, err // this should never as we read buffered data
}

var n int64
{
nn, err := w.Write(b)
n += int64(nn)
if err != nil {
return n, err
}
}
{
nn, err := io.Copy(w, p.conn)
n += nn
if err != nil {
return n, err
}
}

return n, nil
return io.Copy(w, p.reader)
}
Loading