Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support to get result response byte size #295

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo
transport thrift.Transport
pf thrift.ProtocolFactory
)
pf = cn.getProtocolFactory()

if useHTTP2 {
if sslConfig != nil {
transport, err = thrift.NewHTTPPostClientWithOptions("https://"+newAdd, thrift.HTTPClientOptions{
Expand Down Expand Up @@ -86,7 +88,6 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo
if err != nil {
return fmt.Errorf("failed to create a net.Conn-backed Transport,: %s", err.Error())
}
pf = thrift.NewBinaryProtocolFactoryDefault()
if httpHeader != nil {
client, ok := transport.(*thrift.HTTPClient)
if !ok {
Expand Down Expand Up @@ -118,7 +119,6 @@ func (cn *connection) open(hostAddress HostAddress, timeout time.Duration, sslCo
// Set transport
bufferedTranFactory := thrift.NewBufferedTransportFactory(bufferSize)
transport = thrift.NewHeaderTransport(bufferedTranFactory.GetTransport(sock))
pf = thrift.NewHeaderProtocolFactory()
}

cn.graph = graph.NewGraphServiceClientFactory(transport, pf)
Expand Down Expand Up @@ -224,3 +224,11 @@ func (cn *connection) release() {
func (cn *connection) close() {
cn.graph.Close()
}

func (cn *connection) getProtocolFactory() thrift.ProtocolFactory {
if cn.useHTTP2 {
return thrift.NewBinaryProtocolFactoryDefault()
} else {
return thrift.NewHeaderProtocolFactory()
}
}
35 changes: 33 additions & 2 deletions result_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strings"
"time"

"github.com/vesoft-inc/fbthrift/thrift/lib/go/thrift"
"github.com/vesoft-inc/nebula-go/v3/nebula"
"github.com/vesoft-inc/nebula-go/v3/nebula/graph"
)
Expand All @@ -26,6 +27,7 @@ type ResultSet struct {
columnNames []string
colNameIndexMap map[string]int
timezoneInfo timezoneInfo
factory thrift.ProtocolFactory
}

type Record struct {
Expand Down Expand Up @@ -96,18 +98,22 @@ const (

func GenResultSet(resp *graph.ExecutionResponse) (*ResultSet, error) {
var defaultTimezone timezoneInfo = timezoneInfo{0, []byte("UTC")}
return genResultSet(resp, defaultTimezone)
return genResultSet(resp, defaultTimezone, nil)
}

func genResultSet(resp *graph.ExecutionResponse, timezoneInfo timezoneInfo) (*ResultSet, error) {
func genResultSet(resp *graph.ExecutionResponse, timezoneInfo timezoneInfo, factory thrift.ProtocolFactory) (*ResultSet, error) {
var colNames []string
var colNameIndexMap = make(map[string]int)
if factory == nil {
factory = thrift.NewHeaderProtocolFactory()
}

if resp.Data == nil { // if resp.Data != nil then resp.Data.row and resp.Data.colNames wont be nil
return &ResultSet{
resp: resp,
columnNames: colNames,
colNameIndexMap: colNameIndexMap,
factory: factory,
}, nil
}
for i, name := range resp.Data.ColumnNames {
Expand All @@ -120,6 +126,7 @@ func genResultSet(resp *graph.ExecutionResponse, timezoneInfo timezoneInfo) (*Re
columnNames: colNames,
colNameIndexMap: colNameIndexMap,
timezoneInfo: timezoneInfo,
factory: factory,
}, nil
}

Expand Down Expand Up @@ -1381,3 +1388,27 @@ func (res ResultSet) MakePlanByTck() [][]interface{} {
}
return rows
}

func (res *ResultSet) GetByteSize() int {
if res.resp == nil {
return 0
}
if res.factory == nil {
return 0
}
var pf thrift.ProtocolFactory
buf := thrift.NewMemoryBuffer()
switch res.factory.(type) {
case *thrift.BinaryProtocolFactory:
pf = res.factory
case *thrift.HeaderProtocolFactory:
pf = thrift.NewCompactProtocolFactory()
}
protocal := pf.GetProtocol(buf)
if err := res.resp.Write(protocal); err != nil {
return 0
}
bs := make([]byte, buf.Len())
copy(bs, buf.Bytes())
return len(bs)
}
28 changes: 25 additions & 3 deletions result_set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/vesoft-inc/fbthrift/thrift/lib/go/thrift"
"github.com/vesoft-inc/nebula-go/v3/nebula"
"github.com/vesoft-inc/nebula-go/v3/nebula/graph"
)
Expand Down Expand Up @@ -550,7 +551,7 @@ func TestResultSet(t *testing.T) {
nil,
nil,
nil}
resultSetWithNil, err := genResultSet(respWithNil, testTimezone)
resultSetWithNil, err := genResultSet(respWithNil, testTimezone, nil)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -595,7 +596,7 @@ func TestResultSet(t *testing.T) {
&planDesc,
[]byte("test_comment")}

resultSet, err := genResultSet(resp, testTimezone)
resultSet, err := genResultSet(resp, testTimezone, nil)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -673,7 +674,7 @@ func TestAsStringTable(t *testing.T) {
[]byte("test"),
graph.NewPlanDescription(),
[]byte("test_comment")}
resultSet, err := genResultSet(resp, testTimezone)
resultSet, err := genResultSet(resp, testTimezone, nil)
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -878,3 +879,24 @@ func setIVal(ival int) *nebula.Value {
value.IVal = newNum
return value
}

func TestGetByteSize(t *testing.T) {
resp := &graph.ExecutionResponse{
nebula.ErrorCode_SUCCEEDED,
1000,
getDateset(),
[]byte("test_space"),
[]byte("test"),
graph.NewPlanDescription(),
[]byte("test_comment")}
resultSet, err := genResultSet(resp, testTimezone, thrift.NewBinaryProtocolFactoryDefault())
if err != nil {
t.Error(err)
}
assert.Equal(t, 2899, resultSet.GetByteSize())
resultSet, err = genResultSet(resp, testTimezone, thrift.NewHeaderProtocolFactory())
if err != nil {
t.Error(err)
}
assert.Equal(t, 1297, resultSet.GetByteSize())
}
3 changes: 2 additions & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ func (session *Session) ExecuteWithParameter(stmt string, params map[string]inte
if err != nil {
return nil, err
}
resSet, err := genResultSet(resp, session.timezoneInfo)
pf := session.connection.getProtocolFactory()
resSet, err := genResultSet(resp, session.timezoneInfo, pf)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion session_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,8 @@ func (session *pureSession) executeWithParameter(stmt string, params map[string]
if err != nil {
return nil, err
}
rs, err := genResultSet(resp, session.timezoneInfo)
pf := session.connection.getProtocolFactory()
rs, err := genResultSet(resp, session.timezoneInfo, pf)
if err != nil {
return nil, err
}
Expand Down