Skip to content

Commit c8e7692

Browse files
authored
feat(bigquery): expose Apache Arrow data through ArrowIterator (#8506)
We have planned work to support Arrow data fetching on other query APIs, so we need an interface that will support all of those query paths and also work as a base for other Arrow projects like ADBC. This PR detaches the Storage API from the Arrow Decoder and creates a new ArrowIterator interface. The new interface is implemented by the Storage iterator and can later be implemented for other query interfaces that support Arrow. Resolves #8100
1 parent f8ba0b9 commit c8e7692

File tree

6 files changed

+241
-71
lines changed

6 files changed

+241
-71
lines changed

β€Žbigquery/arrow.go

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,49 +19,114 @@ import (
1919
"encoding/base64"
2020
"errors"
2121
"fmt"
22+
"io"
2223
"math/big"
2324

2425
"cloud.google.com/go/civil"
2526
"github.com/apache/arrow/go/v12/arrow"
2627
"github.com/apache/arrow/go/v12/arrow/array"
2728
"github.com/apache/arrow/go/v12/arrow/ipc"
29+
"github.com/apache/arrow/go/v12/arrow/memory"
30+
"google.golang.org/api/iterator"
2831
)
2932

30-
type arrowDecoder struct {
31-
tableSchema Schema
32-
rawArrowSchema []byte
33-
arrowSchema *arrow.Schema
33+
// ArrowRecordBatch represents an Arrow RecordBatch with the source PartitionID.
type ArrowRecordBatch struct {
	reader io.Reader
	// Serialized Arrow Record Batch.
	Data []byte
	// Serialized Arrow Schema.
	Schema []byte
	// Source partition ID. In the Storage API world, it represents the ReadStream.
	PartitionID string
}

// Read implements io.Reader: it streams the serialized schema followed by the
// serialized record batch, building the backing reader lazily on first call.
func (r *ArrowRecordBatch) Read(p []byte) (int, error) {
	if r.reader != nil {
		return r.reader.Read(p)
	}
	var buf bytes.Buffer
	buf.Write(r.Schema)
	buf.Write(r.Data)
	r.reader = &buf
	return r.reader.Read(p)
}
53+
54+
// ArrowIterator represents a way to iterate through a stream of arrow records.
// Experimental: this interface is experimental and may be modified or removed in future versions,
// regardless of any other documented package stability guarantees.
type ArrowIterator interface {
	// Next returns the next batch of serialized Arrow data, or iterator.Done
	// once the stream is exhausted.
	Next() (*ArrowRecordBatch, error)
	// Schema returns the BigQuery schema of the underlying data.
	Schema() Schema
	// SerializedArrowSchema returns the Arrow IPC-serialized schema bytes,
	// suitable as the prefix of an Arrow IPC stream.
	SerializedArrowSchema() []byte
}
3562

36-
func newArrowDecoderFromSession(session *readSession, schema Schema) (*arrowDecoder, error) {
37-
bqSession := session.bqSession
38-
if bqSession == nil {
39-
return nil, errors.New("read session not initialized")
63+
// NewArrowIteratorReader allows to consume an ArrowIterator as an io.Reader.
64+
// Experimental: this interface is experimental and may be modified or removed in future versions,
65+
// regardless of any other documented package stability guarantees.
66+
func NewArrowIteratorReader(it ArrowIterator) io.Reader {
67+
return &arrowIteratorReader{
68+
it: it,
4069
}
41-
arrowSerializedSchema := bqSession.GetArrowSchema().GetSerializedSchema()
70+
}
71+
72+
type arrowIteratorReader struct {
73+
buf *bytes.Buffer
74+
it ArrowIterator
75+
}
76+
77+
// Read makes ArrowIteratorReader implement io.Reader
78+
func (r *arrowIteratorReader) Read(p []byte) (int, error) {
79+
if r.it == nil {
80+
return -1, errors.New("bigquery: nil ArrowIterator")
81+
}
82+
if r.buf == nil { // init with schema
83+
buf := bytes.NewBuffer(r.it.SerializedArrowSchema())
84+
r.buf = buf
85+
}
86+
n, err := r.buf.Read(p)
87+
if err == io.EOF {
88+
batch, err := r.it.Next()
89+
if err == iterator.Done {
90+
return 0, io.EOF
91+
}
92+
r.buf.Write(batch.Data)
93+
return r.Read(p)
94+
}
95+
return n, err
96+
}
97+
98+
// arrowDecoder decodes serialized Arrow record batches coming from BigQuery
// into rows of values or retained arrow.Records.
type arrowDecoder struct {
	// allocator is used for Arrow buffer allocations when reading IPC data.
	allocator memory.Allocator
	// tableSchema is the BigQuery schema of the data being decoded.
	tableSchema Schema
	// arrowSchema is parsed from the serialized Arrow schema (see newArrowDecoder).
	arrowSchema *arrow.Schema
}
103+
104+
func newArrowDecoder(arrowSerializedSchema []byte, schema Schema) (*arrowDecoder, error) {
42105
buf := bytes.NewBuffer(arrowSerializedSchema)
43106
r, err := ipc.NewReader(buf)
44107
if err != nil {
45108
return nil, err
46109
}
47110
defer r.Release()
48111
p := &arrowDecoder{
49-
tableSchema: schema,
50-
rawArrowSchema: arrowSerializedSchema,
51-
arrowSchema: r.Schema(),
112+
tableSchema: schema,
113+
arrowSchema: r.Schema(),
114+
allocator: memory.DefaultAllocator,
52115
}
53116
return p, nil
54117
}
55118

56-
func (ap *arrowDecoder) createIPCReaderForBatch(serializedArrowRecordBatch []byte) (*ipc.Reader, error) {
57-
buf := bytes.NewBuffer(ap.rawArrowSchema)
58-
buf.Write(serializedArrowRecordBatch)
59-
return ipc.NewReader(buf, ipc.WithSchema(ap.arrowSchema))
119+
func (ap *arrowDecoder) createIPCReaderForBatch(arrowRecordBatch *ArrowRecordBatch) (*ipc.Reader, error) {
120+
return ipc.NewReader(
121+
arrowRecordBatch,
122+
ipc.WithSchema(ap.arrowSchema),
123+
ipc.WithAllocator(ap.allocator),
124+
)
60125
}
61126

62127
// decodeArrowRecords decodes BQ ArrowRecordBatch into rows of []Value.
63-
func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([][]Value, error) {
64-
r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
128+
func (ap *arrowDecoder) decodeArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([][]Value, error) {
129+
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
65130
if err != nil {
66131
return nil, err
67132
}
@@ -79,8 +144,8 @@ func (ap *arrowDecoder) decodeArrowRecords(serializedArrowRecordBatch []byte) ([
79144
}
80145

81146
// decodeRetainedArrowRecords decodes BQ ArrowRecordBatch into a list of retained arrow.Record.
82-
func (ap *arrowDecoder) decodeRetainedArrowRecords(serializedArrowRecordBatch []byte) ([]arrow.Record, error) {
83-
r, err := ap.createIPCReaderForBatch(serializedArrowRecordBatch)
147+
func (ap *arrowDecoder) decodeRetainedArrowRecords(arrowRecordBatch *ArrowRecordBatch) ([]arrow.Record, error) {
148+
r, err := ap.createIPCReaderForBatch(arrowRecordBatch)
84149
if err != nil {
85150
return nil, err
86151
}

β€Žbigquery/iterator.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ type RowIterator struct {
4444
ctx context.Context
4545
src *rowSource
4646

47-
arrowIterator *arrowIterator
47+
arrowIterator ArrowIterator
48+
arrowDecoder *arrowDecoder
4849

4950
pageInfo *iterator.PageInfo
5051
nextFunc func() error

β€Žbigquery/storage_bench_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func BenchmarkIntegration_StorageReadQuery(b *testing.B) {
7474
}
7575
}
7676
b.ReportMetric(float64(it.TotalRows), "rows")
77-
bqSession := it.arrowIterator.session.bqSession
77+
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
7878
b.ReportMetric(float64(len(bqSession.Streams)), "parallel_streams")
7979
b.ReportMetric(float64(maxStreamCount), "max_streams")
8080
}

β€Žbigquery/storage_integration_test.go

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ import (
2222
"time"
2323

2424
"cloud.google.com/go/internal/testutil"
25+
"github.com/apache/arrow/go/v12/arrow"
26+
"github.com/apache/arrow/go/v12/arrow/array"
27+
"github.com/apache/arrow/go/v12/arrow/ipc"
28+
"github.com/apache/arrow/go/v12/arrow/math"
29+
"github.com/apache/arrow/go/v12/arrow/memory"
2530
"github.com/google/go-cmp/cmp"
2631
"google.golang.org/api/iterator"
2732
)
@@ -250,11 +255,12 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
250255
}
251256
total++ // as we read the first value separately
252257

253-
bqSession := it.arrowIterator.session.bqSession
258+
session := it.arrowIterator.(*storageArrowIterator).session
259+
bqSession := session.bqSession
254260
if len(bqSession.Streams) == 0 {
255261
t.Fatalf("%s: expected to use at least one stream but found %d", tc.name, len(bqSession.Streams))
256262
}
257-
streamSettings := it.arrowIterator.session.settings.maxStreamCount
263+
streamSettings := session.settings.maxStreamCount
258264
if tc.maxExpectedStreams > 0 {
259265
if streamSettings > tc.maxExpectedStreams {
260266
t.Fatalf("%s: expected stream settings to be at most %d streams but found %d", tc.name, tc.maxExpectedStreams, streamSettings)
@@ -317,7 +323,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
317323
total++
318324
}
319325

320-
bqSession := it.arrowIterator.session.bqSession
326+
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
321327
if len(bqSession.Streams) == 0 {
322328
t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
323329
}
@@ -366,7 +372,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
366372
}
367373
total++ // as we read the first value separately
368374

369-
bqSession := it.arrowIterator.session.bqSession
375+
bqSession := it.arrowIterator.(*storageArrowIterator).session.bqSession
370376
if len(bqSession.Streams) == 0 {
371377
t.Fatalf("should use more than one stream but found %d", len(bqSession.Streams))
372378
}
@@ -418,11 +424,88 @@ func TestIntegration_StorageReadCancel(t *testing.T) {
418424
}
419425
// resources are cleaned asynchronously
420426
time.Sleep(time.Second)
421-
if !it.arrowIterator.isDone() {
427+
arrowIt := it.arrowIterator.(*storageArrowIterator)
428+
if !arrowIt.isDone() {
422429
t.Fatal("expected stream to be done")
423430
}
424431
}
425432

433+
func TestIntegration_StorageReadArrow(t *testing.T) {
434+
if client == nil {
435+
t.Skip("Integration tests skipped")
436+
}
437+
ctx := context.Background()
438+
table := "`bigquery-public-data.usa_names.usa_1910_current`"
439+
sql := fmt.Sprintf(`SELECT name, number, state FROM %s where state = "CA"`, table)
440+
441+
q := storageOptimizedClient.Query(sql)
442+
job, err := q.Run(ctx) // force usage of Storage API by skipping fast paths
443+
if err != nil {
444+
t.Fatal(err)
445+
}
446+
it, err := job.Read(ctx)
447+
if err != nil {
448+
t.Fatal(err)
449+
}
450+
451+
checkedAllocator := memory.NewCheckedAllocator(memory.DefaultAllocator)
452+
it.arrowDecoder.allocator = checkedAllocator
453+
defer checkedAllocator.AssertSize(t, 0)
454+
455+
arrowIt, err := it.ArrowIterator()
456+
if err != nil {
457+
t.Fatalf("expected iterator to be accelerated: %v", err)
458+
}
459+
arrowItReader := NewArrowIteratorReader(arrowIt)
460+
461+
records := []arrow.Record{}
462+
r, err := ipc.NewReader(arrowItReader, ipc.WithAllocator(checkedAllocator))
463+
numrec := 0
464+
for r.Next() {
465+
rec := r.Record()
466+
rec.Retain()
467+
defer rec.Release()
468+
records = append(records, rec)
469+
numrec += int(rec.NumRows())
470+
}
471+
r.Release()
472+
473+
arrowSchema := r.Schema()
474+
arrowTable := array.NewTableFromRecords(arrowSchema, records)
475+
defer arrowTable.Release()
476+
if arrowTable.NumRows() != int64(it.TotalRows) {
477+
t.Fatalf("should have a table with %d rows, but found %d", it.TotalRows, arrowTable.NumRows())
478+
}
479+
if arrowTable.NumCols() != 3 {
480+
t.Fatalf("should have a table with 3 columns, but found %d", arrowTable.NumCols())
481+
}
482+
483+
sumSQL := fmt.Sprintf(`SELECT sum(number) as total FROM %s where state = "CA"`, table)
484+
sumQuery := client.Query(sumSQL)
485+
sumIt, err := sumQuery.Read(ctx)
486+
if err != nil {
487+
t.Fatal(err)
488+
}
489+
sumValues := []Value{}
490+
err = sumIt.Next(&sumValues)
491+
if err != nil {
492+
t.Fatal(err)
493+
}
494+
totalFromSQL := sumValues[0].(int64)
495+
496+
tr := array.NewTableReader(arrowTable, arrowTable.NumRows())
497+
defer tr.Release()
498+
var totalFromArrow int64
499+
for tr.Next() {
500+
rec := tr.Record()
501+
vec := rec.Column(1).(*array.Int64)
502+
totalFromArrow += math.Int64.Sum(vec)
503+
}
504+
if totalFromArrow != totalFromSQL {
505+
t.Fatalf("expected total to be %d, but with arrow we got %d", totalFromSQL, totalFromArrow)
506+
}
507+
}
508+
426509
func countIteratorRows(it *RowIterator) (total uint64, err error) {
427510
for {
428511
var dst []Value

0 commit comments

Comments
 (0)