Skip to content

Commit d440d75

Browse files
author
TOGASHI Tomoki
authored
feat(spanner/spansql): add support for aggregate functions (#8498)
1 parent 9874485 commit d440d75

6 files changed

Lines changed: 173 additions & 13 deletions

File tree

spanner/spansql/keywords.go

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,16 @@ var keywords = map[string]bool{
129129
// https://cloud.google.com/spanner/docs/functions-and-operators
130130
var funcs = make(map[string]bool)
131131
var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError))
132+
// aggregateFuncs marks the names in funcs that are aggregate functions. init
// fills it from aggregateFuncNames, and parseLit hands calls to these names to
// parseAggregateFunc instead of the generic function-call path.
var aggregateFuncs = make(map[string]bool)
132133

133134
func init() {
134-
for _, f := range allFuncs {
135+
for _, f := range funcNames {
135136
funcs[f] = true
136137
}
138+
for _, f := range aggregateFuncNames {
139+
funcs[f] = true
140+
aggregateFuncs[f] = true
141+
}
137142
// Special case for CAST, SAFE_CAST and EXTRACT
138143
funcArgParsers["CAST"] = typedArgParser
139144
funcArgParsers["SAFE_CAST"] = typedArgParser
@@ -150,19 +155,9 @@ func init() {
150155
funcArgParsers["GET_INTERNAL_SEQUENCE_STATE"] = sequenceArgParser
151156
}
152157

153-
var allFuncs = []string{
158+
var funcNames = []string{
154159
// TODO: many more
155160

156-
// Aggregate functions.
157-
"ANY_VALUE",
158-
"ARRAY_AGG",
159-
"AVG",
160-
"BIT_XOR",
161-
"COUNT",
162-
"MAX",
163-
"MIN",
164-
"SUM",
165-
166161
// Cast functions.
167162
"CAST",
168163
"SAFE_CAST",
@@ -295,3 +290,28 @@ var allFuncs = []string{
295290
// Utility functions.
296291
"GENERATE_UUID",
297292
}
293+
294+
// aggregateFuncNames lists the names of the aggregate functions. init
// registers every entry in both funcs and aggregateFuncs.
var aggregateFuncNames = []string{
295+
// Aggregate functions.
296+
"ANY_VALUE",
297+
"ARRAY_AGG",
298+
"ARRAY_CONCAT_AGG",
299+
"AVG",
300+
"BIT_AND",
301+
"BIT_OR",
302+
"BIT_XOR",
303+
"COUNT",
304+
"COUNTIF",
305+
"LOGICAL_AND",
306+
"LOGICAL_OR",
307+
"MAX",
308+
"MIN",
309+
"STRING_AGG",
310+
"SUM",
311+
312+
// Statistical aggregate functions.
313+
"STDDEV",
314+
"STDDEV_SAMP",
315+
"VAR_SAMP",
316+
"VARIANCE",
317+
}

spanner/spansql/parser.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3566,6 +3566,65 @@ var sequenceArgParser = func(p *parser) (Expr, *parseError) {
35663566
return p.parseExpr()
35673567
}
35683568

3569+
func (p *parser) parseAggregateFunc() (Func, *parseError) {
3570+
tok := p.next()
3571+
if tok.err != nil {
3572+
return Func{}, tok.err
3573+
}
3574+
name := strings.ToUpper(tok.value)
3575+
if err := p.expect("("); err != nil {
3576+
return Func{}, err
3577+
}
3578+
var distinct bool
3579+
if p.eat("DISTINCT") {
3580+
distinct = true
3581+
}
3582+
args, err := p.parseExprList()
3583+
if err != nil {
3584+
return Func{}, err
3585+
}
3586+
var nullsHandling NullsHandling
3587+
if p.eat("IGNORE", "NULLS") {
3588+
nullsHandling = IgnoreNulls
3589+
} else if p.eat("RESPECT", "NULLS") {
3590+
nullsHandling = RespectNulls
3591+
}
3592+
var having *AggregateHaving
3593+
if p.eat("HAVING") {
3594+
tok := p.next()
3595+
if tok.err != nil {
3596+
return Func{}, tok.err
3597+
}
3598+
var cond AggregateHavingCondition
3599+
switch tok.value {
3600+
case "MAX":
3601+
cond = HavingMax
3602+
case "MIN":
3603+
cond = HavingMin
3604+
default:
3605+
return Func{}, p.errorf("got %q, want MAX or MIN", tok.value)
3606+
}
3607+
expr, err := p.parseExpr()
3608+
if err != nil {
3609+
return Func{}, err
3610+
}
3611+
having = &AggregateHaving{
3612+
Condition: cond,
3613+
Expr: expr,
3614+
}
3615+
}
3616+
if err := p.expect(")"); err != nil {
3617+
return Func{}, err
3618+
}
3619+
return Func{
3620+
Name: name,
3621+
Args: args,
3622+
Distinct: distinct,
3623+
NullsHandling: nullsHandling,
3624+
Having: having,
3625+
}, nil
3626+
}
3627+
35693628
/*
35703629
Expressions
35713630
@@ -3918,6 +3977,10 @@ func (p *parser) parseLit() (Expr, *parseError) {
39183977
// this is a function invocation.
39193978
// The `funcs` map is keyed by upper case strings.
39203979
if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") {
3980+
if aggregateFuncs[name] {
3981+
p.back()
3982+
return p.parseAggregateFunc()
3983+
}
39213984
var list []Expr
39223985
var err *parseError
39233986
if f, ok := funcArgParsers[name]; ok {

spanner/spansql/parser_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,13 @@ func TestParseExpr(t *testing.T) {
419419
{`GET_NEXT_SEQUENCE_VALUE(SEQUENCE MySequence)`, Func{Name: "GET_NEXT_SEQUENCE_VALUE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}},
420420
{`GET_INTERNAL_SEQUENCE_STATE(SEQUENCE MySequence)`, Func{Name: "GET_INTERNAL_SEQUENCE_STATE", Args: []Expr{SequenceExpr{Name: ID("MySequence")}}}},
421421

422+
// Aggregate Functions
423+
{`COUNT(*)`, Func{Name: "COUNT", Args: []Expr{Star}}},
424+
{`COUNTIF(DISTINCT cname)`, Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true}},
425+
{`ARRAY_AGG(Foo IGNORE NULLS)`, Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls}},
426+
{`ANY_VALUE(Foo HAVING MAX Bar)`, Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}},
427+
{`STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`, Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}}},
428+
422429
// Conditional expressions
423430
{
424431
`CASE X WHEN 1 THEN "X" WHEN 2 THEN "Y" ELSE NULL END`,

spanner/spansql/sql.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,27 @@ func (f Func) SQL() string { return buildSQL(f) }
913913
func (f Func) addSQL(sb *strings.Builder) {
914914
sb.WriteString(f.Name)
915915
sb.WriteString("(")
916+
if f.Distinct {
917+
sb.WriteString("DISTINCT ")
918+
}
916919
addExprList(sb, f.Args, ", ")
920+
switch f.NullsHandling {
921+
case RespectNulls:
922+
sb.WriteString(" RESPECT NULLS")
923+
case IgnoreNulls:
924+
sb.WriteString(" IGNORE NULLS")
925+
}
926+
if ah := f.Having; ah != nil {
927+
sb.WriteString(" HAVING")
928+
switch ah.Condition {
929+
case HavingMax:
930+
sb.WriteString(" MAX")
931+
case HavingMin:
932+
sb.WriteString(" MIN")
933+
}
934+
sb.WriteString(" ")
935+
sb.WriteString(ah.Expr.SQL())
936+
}
917937
sb.WriteString(")")
918938
}
919939

spanner/spansql/sql_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,31 @@ func TestSQL(t *testing.T) {
970970
`SELECT SAFE_CAST(7 AS DATE)`,
971971
reparseQuery,
972972
},
973+
{
974+
Func{Name: "COUNT", Args: []Expr{Star}},
975+
`COUNT(*)`,
976+
reparseExpr,
977+
},
978+
{
979+
Func{Name: "COUNTIF", Args: []Expr{ID("cname")}, Distinct: true},
980+
`COUNTIF(DISTINCT cname)`,
981+
reparseExpr,
982+
},
983+
{
984+
Func{Name: "ARRAY_AGG", Args: []Expr{ID("Foo")}, NullsHandling: IgnoreNulls},
985+
`ARRAY_AGG(Foo IGNORE NULLS)`,
986+
reparseExpr,
987+
},
988+
{
989+
Func{Name: "ANY_VALUE", Args: []Expr{ID("Foo")}, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}},
990+
`ANY_VALUE(Foo HAVING MAX Bar)`,
991+
reparseExpr,
992+
},
993+
{
994+
Func{Name: "STRING_AGG", Args: []Expr{ID("Foo"), StringLiteral(",")}, Distinct: true, NullsHandling: IgnoreNulls, Having: &AggregateHaving{Condition: HavingMax, Expr: ID("Bar")}},
995+
`STRING_AGG(DISTINCT Foo, "," IGNORE NULLS HAVING MAX Bar)`,
996+
reparseExpr,
997+
},
973998
{
974999
ComparisonOp{LHS: ID("X"), Op: NotBetween, RHS: ID("Y"), RHS2: ID("Z")},
9751000
`X NOT BETWEEN Y AND Z`,

spanner/spansql/types.go

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,9 @@ type Func struct {
757757
Name string // not ID
758758
Args []Expr
759759

760-
// TODO: various functions permit as-expressions, which might warrant different types in here.
760+
Distinct bool
761+
NullsHandling NullsHandling
762+
Having *AggregateHaving
761763
}
762764

763765
func (Func) isBoolExpr() {} // possibly bool
@@ -804,6 +806,29 @@ type SequenceExpr struct {
804806

805807
func (SequenceExpr) isExpr() {}
806808

809+
// NullsHandling represents the method of dealing with NULL values in aggregate functions.
810+
type NullsHandling int
811+
812+
const (
813+
// NullsHandlingUnspecified is the zero value; no NULLS modifier is written for it.
NullsHandlingUnspecified NullsHandling = iota
814+
// RespectNulls corresponds to the RESPECT NULLS modifier.
RespectNulls
815+
// IgnoreNulls corresponds to the IGNORE NULLS modifier.
IgnoreNulls
816+
)
817+
818+
// AggregateHaving represents the HAVING clause specific to aggregate functions, restricting rows based on a maximal or minimal value.
819+
type AggregateHaving struct {
820+
// Condition selects between HAVING MAX and HAVING MIN.
Condition AggregateHavingCondition
821+
// Expr is the expression that follows the MAX or MIN keyword.
Expr Expr
822+
}
823+
824+
// AggregateHavingCondition represents the condition (MAX or MIN) for the AggregateHaving clause.
825+
type AggregateHavingCondition int
826+
827+
const (
828+
// HavingMax corresponds to HAVING MAX. It is the zero value of AggregateHavingCondition.
HavingMax AggregateHavingCondition = iota
829+
// HavingMin corresponds to HAVING MIN.
HavingMin
830+
)
831+
807832
// Paren represents a parenthesised expression.
808833
type Paren struct {
809834
Expr Expr

0 commit comments

Comments
 (0)