Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setup SQL server based on pgx frontend/backend protocol #4881

Closed
wants to merge 36 commits into from
Closed
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
7d186df
postgres wire protocol on runtime - initial commit
k-anshul May 8, 2024
18a70c1
postgres wire protocol on runtime
k-anshul May 9, 2024
3245765
connectivity with postgres server
k-anshul May 10, 2024
cf72180
handle complex types
k-anshul May 10, 2024
3b84741
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul May 13, 2024
ec85fd9
fix go.sum
k-anshul May 13, 2024
c5a0a78
interim commit
k-anshul May 21, 2024
2a9177d
superset working
k-anshul May 22, 2024
6cfb77f
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul May 22, 2024
965fa37
refactor
k-anshul May 23, 2024
652fded
lint and fmt fix
k-anshul May 27, 2024
c557035
revert version upgrade
k-anshul May 27, 2024
a516a6b
quick review
k-anshul May 27, 2024
c2d98cc
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul May 27, 2024
d42ced6
Update cli/cmd/start/start.go
k-anshul May 28, 2024
08594bd
Apply suggestions from code review
k-anshul May 29, 2024
b5b3d3b
interim commit
k-anshul May 29, 2024
3cd52c2
interim changes
k-anshul Jun 4, 2024
f9ef986
add some godoc
k-anshul Jun 4, 2024
29a665e
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul Jun 4, 2024
d1e11c3
fix new metric resolver
k-anshul Jun 4, 2024
86255bc
json deserialize fix
k-anshul Jun 5, 2024
5c36d4b
fix unit test
k-anshul Jun 5, 2024
cb97546
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul Jun 11, 2024
8668da1
fix unit test
k-anshul Jun 11, 2024
d2804bf
also handle last error while iterating rows
k-anshul Jun 11, 2024
d4b991a
interim review
k-anshul Jun 11, 2024
4953555
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul Jun 18, 2024
35876f5
handle rotating tls configs
k-anshul Jun 18, 2024
6cef73c
lint fix
k-anshul Jun 19, 2024
3b99d1e
use prepared stmt to get schema description
k-anshul Jun 19, 2024
7e16c82
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul Jun 19, 2024
cac35ae
main merge conflicts
k-anshul Jun 19, 2024
0e82228
dev tests
k-anshul Jun 20, 2024
f3f1803
fix ping query
k-anshul Jun 20, 2024
86c2db8
Merge remote-tracking branch 'origin/main' into postgres_proxy
k-anshul Jun 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions admin/server/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,22 @@ func (a *Authenticator) HTTPMiddlewareLenient(next http.Handler) http.Handler {
return a.httpMiddleware(next, true)
}

type PostgresPassword struct{}
k-anshul marked this conversation as resolved.
Show resolved Hide resolved

func (a *Authenticator) PostgresAuthHandler() func(ctx context.Context, username, password string) (context.Context, bool, error) {
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
return func(ctx context.Context, username, password string) (context.Context, bool, error) {
// Clients do not pass Bearer to avoid encoding the password
// The default token type is Bearer. In case of different token type, the type will be passed in postgres additional properties
password = fmt.Sprintf("bearer %s", password)
ctx = context.WithValue(ctx, PostgresPassword{}, password)
newCtx, err := a.parseClaimsFromBearer(ctx, password)
if err != nil {
newCtx = context.WithValue(ctx, claimsContextKey{}, anonClaims{})
}
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
return newCtx, true, nil
}
}

// httpMiddleware is the actual implementation of HTTPMiddleware and HTTPMiddlewareLenient.
func (a *Authenticator) httpMiddleware(next http.Handler, lenient bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
204 changes: 204 additions & 0 deletions admin/server/postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
package server

import (
"context"
"fmt"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/jackc/pgx/v5/pgxpool"
wire "github.com/jeroenrinzema/psql-wire"
"github.com/lib/pq/oid"
"github.com/rilldata/rill/admin/server/auth"
runtimeauth "github.com/rilldata/rill/runtime/server/auth"
"go.uber.org/zap"
)

func (s *Server) QueryHandler(ctx context.Context, query string) (wire.PreparedStatements, error) {
begelundmuller marked this conversation as resolved.
Show resolved Hide resolved
s.logger.Debug("query", zap.String("query", query))
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
if strings.Trim(query, " ") == "" {
return wire.Prepared(wire.NewStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
return writer.Empty()
})), nil
}

upperQuery := strings.ToUpper(query)
if strings.HasPrefix(upperQuery, "SET") {
return wire.Prepared(wire.NewStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
return writer.Complete("SET")
}, wire.WithColumns(nil))), nil
}

if strings.HasPrefix(upperQuery, "BEGIN") || strings.HasPrefix(upperQuery, "COMMIT") || strings.HasPrefix(upperQuery, "ROLLBACK") {
return wire.Prepared(wire.NewStatement(func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
return writer.Complete(strings.Trim(upperQuery, ";"))
}, wire.WithColumns(nil))), nil
}
k-anshul marked this conversation as resolved.
Show resolved Hide resolved

clientParams := wire.ClientParameters(ctx)
// database is org.password
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
tokens := strings.Split(clientParams[wire.ParamDatabase], ".")
if len(tokens) != 2 {
return nil, fmt.Errorf("invalid org or project")
}
org := tokens[0]
project := tokens[1]

// Find the production deployment for the project we're proxying to
proj, err := s.admin.DB.FindProjectByName(ctx, org, project)
if err != nil {
return nil, fmt.Errorf("invalid org or project")
}

if proj.ProdDeploymentID == nil {
return nil, fmt.Errorf("no prod deployment for project")
}
depl, err := s.admin.DB.FindDeployment(ctx, *proj.ProdDeploymentID)
if err != nil {
return nil, fmt.Errorf("no prod deployment for project")
}

var jwt string
claims := auth.GetClaims(ctx)
switch claims.OwnerType() {
case auth.OwnerTypeAnon:
// If the client is not authenticated with the admin service, we just proxy the contents of the password to the runtime (if any).
password := ctx.Value(auth.PostgresPassword{}).(string)
if len(password) >= 6 && strings.EqualFold(password[0:6], "bearer") {
jwt = strings.TrimSpace(password[6:])
}
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
case auth.OwnerTypeUser, auth.OwnerTypeService:
// If the client is authenticated with the admin service, we issue a new ephemeral runtime JWT.
// The JWT should have the same permissions/configuration as one they would get by calling AdminService.GetProject.
permissions := claims.ProjectPermissions(ctx, proj.OrganizationID, depl.ProjectID)
if !permissions.ReadProd {
return nil, fmt.Errorf("does not have permission to access the production deployment")
}

var attr map[string]any
if claims.OwnerType() == auth.OwnerTypeUser {
attr, err = s.jwtAttributesForUser(ctx, claims.OwnerID(), proj.OrganizationID, permissions)
if err != nil {
return nil, err
}
}

jwt, err = s.issuer.NewToken(runtimeauth.TokenOptions{
AudienceURL: depl.RuntimeAudience,
Subject: claims.OwnerID(),
TTL: runtimeAccessTokenDefaultTTL,
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
InstancePermissions: map[string][]runtimeauth.Permission{
depl.RuntimeInstanceID: {
// TODO: Remove ReadProfiling and ReadRepo (may require frontend changes)
runtimeauth.ReadObjects,
runtimeauth.ReadMetrics,
runtimeauth.ReadProfiling,
runtimeauth.ReadRepo,
runtimeauth.ReadAPI,
},
},
Attributes: attr,
})
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("runtime proxy not available for owner type %q", claims.OwnerType())
}

// Track usage of the deployment
s.admin.Used.Deployment(depl.ID)

hostURL, err := url.Parse(depl.RuntimeHost)
if err != nil {
return nil, err
}
hostURL.Scheme = "postgres"
hostURL.Host = hostURL.Hostname() + ":" + strconv.FormatInt(int64(15432), 10)
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
hostURL.User = url.UserPassword("postgres", fmt.Sprintf("Bearer %s", jwt))
hostURL.Path = depl.RuntimeInstanceID
conn, err := connectionPool(ctx, hostURL.String())
if err != nil {
s.logger.Info("error in get connection pool", zap.Error(err))
return nil, err
}
k-anshul marked this conversation as resolved.
Show resolved Hide resolved

rows, err := conn.Query(ctx, query) // query to underlying host
if err != nil {
s.logger.Info("error in query", zap.Error(err))
return nil, err
}
defer rows.Close()

// handle schema
var cols []wire.Column
fds := rows.FieldDescriptions()
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
for _, fd := range fds {
cols = append(cols, wire.Column{
Table: int32(fd.TableOID),
Name: fd.Name,
Oid: oid.Oid(fd.DataTypeOID),
Width: fd.DataTypeSize,
Attr: int16(fd.TableAttributeNumber),
})
}

// handle data
// NOTE :: This creates a copy of data and stores this till client starts reading data. This is required so that we
// can close runtime connection and not wait for client to complete reading whole data which can leak connection.
// We can improve this logic in future.
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
var data [][]any
for rows.Next() {
d, err := rows.Values()
if err != nil {
s.logger.Info("error in fetching next row", zap.Error(err))
return nil, err
}
data = append(data, d)
}
if rows.Err() != nil {
s.logger.Info("error in fetching rows", zap.Error(err))
return nil, err
}

handle := func(ctx context.Context, writer wire.DataWriter, parameters []wire.Parameter) error {
for i := 0; i < len(data); i++ {
if err := writer.Row(data[i]); err != nil {
return err
}
}
return writer.Complete("OK")
}
return wire.Prepared(wire.NewStatement(handle, wire.WithColumns(cols))), nil
}

var (
runtimePool map[string]*pgxpool.Pool = make(map[string]*pgxpool.Pool)
mu sync.Mutex
)

func connectionPool(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
mu.Lock()
defer mu.Unlock()

pool, ok := runtimePool[dsn]
if ok {
return pool, nil
}

config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, fmt.Errorf("failed to parse dsn: %w", err)
}

// Runtime JWts are valid for 30 minutes only
config.MaxConnLifetime = time.Minute * 29
// since runtimes get restarted more often than actual DB servers. Consider if this should be reduced to even less time
// also consider if we should add some health check on connection acquisition
config.HealthCheckPeriod = time.Minute

return pgxpool.NewWithConfig(ctx, config)
}
7 changes: 7 additions & 0 deletions admin/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
type Options struct {
HTTPPort int
GRPCPort int
PostgresPort int
ExternalURL string
FrontendURL string
AllowedOrigins []string
Expand Down Expand Up @@ -189,6 +190,12 @@ func (s *Server) ServeHTTP(ctx context.Context) error {
})
}

// ServePostgres Starts the postrges server.
func (s *Server) ServePostgres(ctx context.Context) error {
authHandler := s.authenticator.PostgresAuthHandler()
return graceful.ServePostgres(ctx, s.QueryHandler, authHandler, s.opts.PostgresPort, false, s.logger)
}

// HTTPHandler HTTP handler serving REST gateway.
func (s *Server) HTTPHandler(ctx context.Context) (http.Handler, error) {
// Create REST gateway
Expand Down
3 changes: 3 additions & 0 deletions cli/cmd/admin/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type Config struct {
TracesExporter observability.Exporter `default:"" split_words:"true"`
HTTPPort int `default:"8080" split_words:"true"`
GRPCPort int `default:"9090" split_words:"true"`
PostgresPort int `default:"25432" split_words:"true"`
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
DebugPort int `split_words:"true"`
ExternalURL string `default:"http://localhost:8080" split_words:"true"`
ExternalGRPCURL string `envconfig:"external_grpc_url"`
Expand Down Expand Up @@ -288,6 +289,7 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command {
srv, err := server.New(logger, adm, issuer, limiter, activityClient, &server.Options{
HTTPPort: conf.HTTPPort,
GRPCPort: conf.GRPCPort,
PostgresPort: conf.PostgresPort,
ExternalURL: conf.ExternalURL,
FrontendURL: conf.FrontendURL,
AllowedOrigins: conf.AllowedOrigins,
Expand All @@ -306,6 +308,7 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command {
}
group.Go(func() error { return srv.ServeGRPC(cctx) })
group.Go(func() error { return srv.ServeHTTP(cctx) })
group.Go(func() error { return srv.ServePostgres(cctx) })
if conf.DebugPort != 0 {
group.Go(func() error { return debugserver.ServeHTTP(cctx, conf.DebugPort) })
}
Expand Down
5 changes: 4 additions & 1 deletion cli/cmd/runtime/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
_ "github.com/rilldata/rill/runtime/drivers/s3"
_ "github.com/rilldata/rill/runtime/drivers/salesforce"
_ "github.com/rilldata/rill/runtime/drivers/slack"
_ "github.com/rilldata/rill/runtime/drivers/snowflake"
_ "github.com/rilldata/rill/runtime/drivers/sqlite"
_ "github.com/rilldata/rill/runtime/reconcilers"
_ "github.com/rilldata/rill/runtime/resolvers"
Expand All @@ -62,6 +61,7 @@ type Config struct {
HTTPPort int `default:"8080" split_words:"true"`
GRPCPort int `default:"9090" split_words:"true"`
DebugPort int `default:"6060" split_words:"true"`
PostgresPort int `default:"15432" split_words:"true"`
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
AllowedOrigins []string `default:"*" split_words:"true"`
SessionKeyPairs []string `split_words:"true"`
AuthEnable bool `default:"false" split_words:"true"`
Expand Down Expand Up @@ -237,12 +237,14 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command {
srvOpts := &server.Options{
HTTPPort: conf.HTTPPort,
GRPCPort: conf.GRPCPort,
PostgresPort: conf.PostgresPort,
AllowedOrigins: conf.AllowedOrigins,
ServePrometheus: conf.MetricsExporter == observability.PrometheusExporter,
SessionKeyPairs: keyPairs,
AuthEnable: conf.AuthEnable,
AuthIssuerURL: conf.AuthIssuerURL,
AuthAudienceURL: conf.AuthAudienceURL,
DataDir: conf.DataDir,
}
s, err := server.NewServer(ctx, srvOpts, rt, logger, limiter, activityClient)
if err != nil {
Expand All @@ -253,6 +255,7 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command {
group, cctx := errgroup.WithContext(ctx)
group.Go(func() error { return s.ServeGRPC(cctx) })
group.Go(func() error { return s.ServeHTTP(cctx, nil) })
group.Go(func() error { return s.ServePostgres(cctx, true) })
if conf.DebugPort != 0 {
group.Go(func() error { return debugserver.ServeHTTP(cctx, conf.DebugPort) })
}
Expand Down
4 changes: 3 additions & 1 deletion cli/cmd/start/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command {
var olapDSN string
var httpPort int
var grpcPort int
var postgresPort int
var verbose bool
var debug bool
var readonly bool
Expand Down Expand Up @@ -157,7 +158,7 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command {

userID, _ := ch.CurrentUserID(cmd.Context())

err = app.Serve(httpPort, grpcPort, !noUI, !noOpen, readonly, userID, tlsCertPath, tlsKeyPath)
err = app.Serve(httpPort, grpcPort, postgresPort, !noUI, !noOpen, readonly, userID, tlsCertPath, tlsKeyPath)
if err != nil {
return fmt.Errorf("serve: %w", err)
}
Expand All @@ -170,6 +171,7 @@ func StartCmd(ch *cmdutil.Helper) *cobra.Command {
startCmd.Flags().BoolVar(&noOpen, "no-open", false, "Do not open browser")
startCmd.Flags().IntVar(&httpPort, "port", 9009, "Port for HTTP")
startCmd.Flags().IntVar(&grpcPort, "port-grpc", 49009, "Port for gRPC (internal)")
startCmd.Flags().IntVar(&postgresPort, "port-postgres", 0, "Port for postgres server")
k-anshul marked this conversation as resolved.
Show resolved Hide resolved
startCmd.Flags().BoolVar(&readonly, "readonly", false, "Show only dashboards in UI")
startCmd.Flags().BoolVar(&noUI, "no-ui", false, "Serve only the backend")
startCmd.Flags().BoolVar(&verbose, "verbose", false, "Sets the log level to debug")
Expand Down
8 changes: 7 additions & 1 deletion cli/pkg/local/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ func (a *App) Close() error {
return nil
}

func (a *App) Serve(httpPort, grpcPort int, enableUI, openBrowser, readonly bool, userID, tlsCertPath, tlsKeyPath string) error {
func (a *App) Serve(httpPort, grpcPort, postgresPort int, enableUI, openBrowser, readonly bool, userID, tlsCertPath, tlsKeyPath string) error {
// Get analytics info
installID, enabled, err := dotrill.AnalyticsInfo()
if err != nil {
Expand Down Expand Up @@ -378,6 +378,7 @@ func (a *App) Serve(httpPort, grpcPort int, enableUI, openBrowser, readonly bool
opts := &runtimeserver.Options{
HTTPPort: httpPort,
GRPCPort: grpcPort,
PostgresPort: postgresPort,
TLSCertPath: tlsCertPath,
TLSKeyPath: tlsKeyPath,
AllowedOrigins: []string{"*"},
Expand Down Expand Up @@ -408,6 +409,11 @@ func (a *App) Serve(httpPort, grpcPort int, enableUI, openBrowser, readonly bool
if a.Debug {
group.Go(func() error { return debugserver.ServeHTTP(ctx, 6060) })
}
if postgresPort != 0 {
group.Go(func() error {
return runtimeServer.ServePostgres(ctx, false)
})
}

// Open the browser when health check succeeds
go a.pollServer(ctx, httpPort, enableUI && openBrowser, secure)
Expand Down
Loading
Loading