Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pgdialect): postgres syntax errors for pointers and slices #877 #1111

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dialect/pgdialect/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Elem().Kind() == reflect.Uint8 {
return appendBytesElemValue
}
case reflect.Ptr:
return schema.PtrAppender(d.arrayElemAppender(typ.Elem()))
}
return schema.Appender(d, typ)
}
Expand Down
50 changes: 38 additions & 12 deletions dialect/pgdialect/array_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,23 @@ type arrayParser struct {

elem []byte
err error

isJson bool
}

func newArrayParser(b []byte) *arrayParser {
p := new(arrayParser)

if len(b) < 2 || b[0] != '{' || b[len(b)-1] != '}' {
if b[0] == 'n' {
p.p.Reset(nil)
return p
}

if len(b) < 2 || (b[0] != '{' && b[0] != '[') || (b[len(b)-1] != '}' && b[len(b)-1] != ']') {
p.err = fmt.Errorf("pgdialect: can't parse array: %q", b)
return p
}
p.isJson = b[0] == '['

p.p.Reset(b[1 : len(b)-1])
return p
Expand Down Expand Up @@ -51,7 +59,7 @@ func (p *arrayParser) readNext() error {
}

switch ch {
case '}':
case '}', ']':
return io.EOF
case '"':
b, err := p.p.ReadSubstring(ch)
Expand All @@ -78,16 +86,34 @@ func (p *arrayParser) readNext() error {
p.elem = rng
return nil
default:
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
}

if p.p.Peek() == ',' {
p.p.Advance()
if ch == '{' && p.isJson {
json, err := p.p.ReadJSON()
if err != nil {
return err
}

for {
if p.p.Peek() == ',' || p.p.Peek() == ' ' {
p.p.Advance()
} else {
break
}
}

p.elem = json
return nil
} else {
lit := p.p.ReadLiteral(ch)
if bytes.Equal(lit, []byte("NULL")) {
lit = nil
}

if p.p.Peek() == ',' {
p.p.Advance()
}

p.elem = lit
return nil
}

p.elem = lit
return nil
}
}
4 changes: 4 additions & 0 deletions dialect/pgdialect/array_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ func TestArrayParser(t *testing.T) {
{`{"1","2"}`, []string{"1", "2"}},
{`{"{1}","{2}"}`, []string{"{1}", "{2}"}},
{`{[1,2),[3,4)}`, []string{"[1,2)", "[3,4)"}},

{`[]`, []string{}},
{`[{"'\"[]"}]`, []string{`{"'\"[]"}`}},
{`[{"id": 1}, {"id":2, "name":"bob"}]`, []string{"{\"id\": 1}", "{\"id\":2, \"name\":\"bob\"}"}},
}

for i, test := range tests {
Expand Down
54 changes: 54 additions & 0 deletions dialect/pgdialect/array_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package pgdialect

import (
"testing"

"github.com/uptrace/bun/schema"
)

func ptr[T any](v T) *T {
return &v
}

func TestArrayAppend(t *testing.T) {
tcases := []struct {
input interface{}
out string
}{
{
input: []byte{1, 2},
out: `'{1,2}'`,
},
{
input: []*byte{ptr(byte(1)), ptr(byte(2))},
out: `'{1,2}'`,
},
{
input: []int{1, 2},
out: `'{1,2}'`,
},
{
input: []*int{ptr(1), ptr(2)},
out: `'{1,2}'`,
},
{
input: []string{"foo", "bar"},
out: `'{"foo","bar"}'`,
},
{
input: []*string{ptr("foo"), ptr("bar")},
out: `'{"foo","bar"}'`,
},
}

for _, tcase := range tcases {
out, err := Array(tcase.input).AppendQuery(schema.NewFormatter(New()), []byte{})
if err != nil {
t.Fatal(err)
}

if string(out) != tcase.out {
t.Errorf("expected output to be %s, was %s", tcase.out, string(out))
}
}
}
36 changes: 36 additions & 0 deletions dialect/pgdialect/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,39 @@ func (p *pgparser) ReadRange(ch byte) ([]byte, error) {

return p.buf, nil
}

func (p *pgparser) ReadJSON() ([]byte, error) {
p.Unread()

c, err := p.ReadByte()
if err != nil {
return nil, err
}

p.buf = p.buf[:0]

depth := 0
for {
switch c {
case '{':
depth++
case '}':
depth--
}

p.buf = append(p.buf, c)

if depth == 0 {
break
}

next, err := p.ReadByte()
if err != nil {
return nil, err
}

c = next
}

return p.buf, nil
}
4 changes: 4 additions & 0 deletions dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ func fieldSQLType(field *schema.Field) string {
}

func sqlType(typ reflect.Type) string {
if typ.Kind() == reflect.Ptr {
typ = typ.Elem()
}

switch typ {
case nullStringType: // typ.Kind() == reflect.Struct, test for exact match
return sqltype.VarChar
Expand Down
93 changes: 93 additions & 0 deletions internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
"github.com/uptrace/bun/schema"
)

func TestPostgresArray(t *testing.T) {
Expand All @@ -25,16 +26,20 @@ func TestPostgresArray(t *testing.T) {
Array1 []string `bun:",array"`
Array2 *[]string `bun:",array"`
Array3 *[]string `bun:",array"`
Array4 []*string `bun:",array"`
}

db := pg(t)
t.Cleanup(func() { db.Close() })
mustResetModel(t, ctx, db, (*Model)(nil))

str1 := "hello"
str2 := "world"
model1 := &Model{
ID: 123,
Array1: []string{"one", "two", "three"},
Array2: &[]string{"hello", "world"},
Array4: []*string{&str1, &str2},
}
_, err := db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)
Expand All @@ -56,6 +61,12 @@ func TestPostgresArray(t *testing.T) {
Scan(ctx, pgdialect.Array(&strs))
require.NoError(t, err)
require.Nil(t, strs)

err = db.NewSelect().Model((*Model)(nil)).
Column("array4").
Scan(ctx, pgdialect.Array(&strs))
require.NoError(t, err)
require.Equal(t, []string{"hello", "world"}, strs)
}

func TestPostgresArrayQuote(t *testing.T) {
Expand Down Expand Up @@ -877,3 +888,85 @@ func TestPostgresMultiRange(t *testing.T) {
err = db.NewSelect().Model(out).Scan(ctx)
require.NoError(t, err)
}

type UserID struct {
ID string
}

func (u UserID) AppendQuery(fmter schema.Formatter, b []byte) ([]byte, error) {
v := []byte(`"` + u.ID + `"`)
return append(b, v...), nil
}

var _ schema.QueryAppender = (*UserID)(nil)

func (r *UserID) Scan(anySrc any) (err error) {
src, ok := anySrc.([]byte)
if !ok {
return fmt.Errorf("pgdialect: Range can't scan %T", anySrc)
}

r.ID = string(src)
return nil
}

var _ sql.Scanner = (*UserID)(nil)

func TestPostgresJSONB(t *testing.T) {
type Item struct {
Name string `json:"name"`
}
type Model struct {
ID int64 `bun:",pk,autoincrement"`
Item Item `bun:",type:jsonb"`
ItemPtr *Item `bun:",type:jsonb"`
Items []Item `bun:",type:jsonb"`
ItemsP []*Item `bun:",type:jsonb"`
ItemsNull []*Item `bun:",type:jsonb"`
TextItemA []UserID `bun:"type:text[]"`
}

db := pg(t)
t.Cleanup(func() { db.Close() })
mustResetModel(t, ctx, db, (*Model)(nil))

item1 := Item{Name: "one"}
item2 := Item{Name: "two"}
uid1 := UserID{ID: "1"}
uid2 := UserID{ID: "2"}
model1 := &Model{
ID: 123,
Item: item1,
ItemPtr: &item2,
Items: []Item{item1, item2},
ItemsP: []*Item{&item1, &item2},
ItemsNull: nil,
TextItemA: []UserID{uid1, uid2},
}
_, err := db.NewInsert().Model(model1).Exec(ctx)
require.NoError(t, err)

model2 := new(Model)
err = db.NewSelect().Model(model2).Scan(ctx)
require.NoError(t, err)
require.Equal(t, model1, model2)

var items []Item
err = db.NewSelect().Model((*Model)(nil)).
Column("items").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{item1, item2}, items)

err = db.NewSelect().Model((*Model)(nil)).
Column("itemsp").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{item1, item2}, items)

err = db.NewSelect().Model((*Model)(nil)).
Column("items_null").
Scan(ctx, pgdialect.Array(&items))
require.NoError(t, err)
require.Equal(t, []Item{}, items)
}
Loading