@@ -22,6 +22,11 @@ import (
22
22
"time"
23
23
24
24
"cloud.google.com/go/internal/testutil"
25
+ "github.com/apache/arrow/go/v12/arrow"
26
+ "github.com/apache/arrow/go/v12/arrow/array"
27
+ "github.com/apache/arrow/go/v12/arrow/ipc"
28
+ "github.com/apache/arrow/go/v12/arrow/math"
29
+ "github.com/apache/arrow/go/v12/arrow/memory"
25
30
"github.com/google/go-cmp/cmp"
26
31
"google.golang.org/api/iterator"
27
32
)
@@ -250,11 +255,12 @@ func TestIntegration_StorageReadQueryOrdering(t *testing.T) {
250
255
}
251
256
total ++ // as we read the first value separately
252
257
253
- bqSession := it .arrowIterator .session .bqSession
258
+ session := it .arrowIterator .(* storageArrowIterator ).session
259
+ bqSession := session .bqSession
254
260
if len (bqSession .Streams ) == 0 {
255
261
t .Fatalf ("%s: expected to use at least one stream but found %d" , tc .name , len (bqSession .Streams ))
256
262
}
257
- streamSettings := it . arrowIterator . session .settings .maxStreamCount
263
+ streamSettings := session .settings .maxStreamCount
258
264
if tc .maxExpectedStreams > 0 {
259
265
if streamSettings > tc .maxExpectedStreams {
260
266
t .Fatalf ("%s: expected stream settings to be at most %d streams but found %d" , tc .name , tc .maxExpectedStreams , streamSettings )
@@ -317,7 +323,7 @@ func TestIntegration_StorageReadQueryStruct(t *testing.T) {
317
323
total ++
318
324
}
319
325
320
- bqSession := it .arrowIterator .session .bqSession
326
+ bqSession := it .arrowIterator .( * storageArrowIterator ). session .bqSession
321
327
if len (bqSession .Streams ) == 0 {
322
328
t .Fatalf ("should use more than one stream but found %d" , len (bqSession .Streams ))
323
329
}
@@ -366,7 +372,7 @@ func TestIntegration_StorageReadQueryMorePages(t *testing.T) {
366
372
}
367
373
total ++ // as we read the first value separately
368
374
369
- bqSession := it .arrowIterator .session .bqSession
375
+ bqSession := it .arrowIterator .( * storageArrowIterator ). session .bqSession
370
376
if len (bqSession .Streams ) == 0 {
371
377
t .Fatalf ("should use more than one stream but found %d" , len (bqSession .Streams ))
372
378
}
@@ -418,11 +424,88 @@ func TestIntegration_StorageReadCancel(t *testing.T) {
418
424
}
419
425
// resources are cleaned asynchronously
420
426
time .Sleep (time .Second )
421
- if ! it .arrowIterator .isDone () {
427
+ arrowIt := it .arrowIterator .(* storageArrowIterator )
428
+ if ! arrowIt .isDone () {
422
429
t .Fatal ("expected stream to be done" )
423
430
}
424
431
}
425
432
433
+ func TestIntegration_StorageReadArrow (t * testing.T ) {
434
+ if client == nil {
435
+ t .Skip ("Integration tests skipped" )
436
+ }
437
+ ctx := context .Background ()
438
+ table := "`bigquery-public-data.usa_names.usa_1910_current`"
439
+ sql := fmt .Sprintf (`SELECT name, number, state FROM %s where state = "CA"` , table )
440
+
441
+ q := storageOptimizedClient .Query (sql )
442
+ job , err := q .Run (ctx ) // force usage of Storage API by skipping fast paths
443
+ if err != nil {
444
+ t .Fatal (err )
445
+ }
446
+ it , err := job .Read (ctx )
447
+ if err != nil {
448
+ t .Fatal (err )
449
+ }
450
+
451
+ checkedAllocator := memory .NewCheckedAllocator (memory .DefaultAllocator )
452
+ it .arrowDecoder .allocator = checkedAllocator
453
+ defer checkedAllocator .AssertSize (t , 0 )
454
+
455
+ arrowIt , err := it .ArrowIterator ()
456
+ if err != nil {
457
+ t .Fatalf ("expected iterator to be accelerated: %v" , err )
458
+ }
459
+ arrowItReader := NewArrowIteratorReader (arrowIt )
460
+
461
+ records := []arrow.Record {}
462
+ r , err := ipc .NewReader (arrowItReader , ipc .WithAllocator (checkedAllocator ))
463
+ numrec := 0
464
+ for r .Next () {
465
+ rec := r .Record ()
466
+ rec .Retain ()
467
+ defer rec .Release ()
468
+ records = append (records , rec )
469
+ numrec += int (rec .NumRows ())
470
+ }
471
+ r .Release ()
472
+
473
+ arrowSchema := r .Schema ()
474
+ arrowTable := array .NewTableFromRecords (arrowSchema , records )
475
+ defer arrowTable .Release ()
476
+ if arrowTable .NumRows () != int64 (it .TotalRows ) {
477
+ t .Fatalf ("should have a table with %d rows, but found %d" , it .TotalRows , arrowTable .NumRows ())
478
+ }
479
+ if arrowTable .NumCols () != 3 {
480
+ t .Fatalf ("should have a table with 3 columns, but found %d" , arrowTable .NumCols ())
481
+ }
482
+
483
+ sumSQL := fmt .Sprintf (`SELECT sum(number) as total FROM %s where state = "CA"` , table )
484
+ sumQuery := client .Query (sumSQL )
485
+ sumIt , err := sumQuery .Read (ctx )
486
+ if err != nil {
487
+ t .Fatal (err )
488
+ }
489
+ sumValues := []Value {}
490
+ err = sumIt .Next (& sumValues )
491
+ if err != nil {
492
+ t .Fatal (err )
493
+ }
494
+ totalFromSQL := sumValues [0 ].(int64 )
495
+
496
+ tr := array .NewTableReader (arrowTable , arrowTable .NumRows ())
497
+ defer tr .Release ()
498
+ var totalFromArrow int64
499
+ for tr .Next () {
500
+ rec := tr .Record ()
501
+ vec := rec .Column (1 ).(* array.Int64 )
502
+ totalFromArrow += math .Int64 .Sum (vec )
503
+ }
504
+ if totalFromArrow != totalFromSQL {
505
+ t .Fatalf ("expected total to be %d, but with arrow we got %d" , totalFromSQL , totalFromArrow )
506
+ }
507
+ }
508
+
426
509
func countIteratorRows (it * RowIterator ) (total uint64 , err error ) {
427
510
for {
428
511
var dst []Value
0 commit comments