Skip to content

Commit

Permalink
GODRIVER-3284 Allow valid SRV hostnames with less than 3 parts. (#1898)
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyang-hu authored Jan 14, 2025
1 parent c79e929 commit 9afdc8c
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 4 deletions.
122 changes: 122 additions & 0 deletions x/mongo/driver/connstring/initial_dns_seedlist_discovery_prose_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (C) MongoDB, Inc. 2024-present.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0

package connstring

import (
"fmt"
"net"
"testing"

"go.mongodb.org/mongo-driver/v2/internal/assert"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/dns"
)

func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) {
newTestParser := func(record string) *parser {
return &parser{&dns.Resolver{
LookupSRV: func(_, _, _ string) (string, []*net.SRV, error) {
return "", []*net.SRV{
{
Target: record,
Port: 27017,
},
}, nil
},
LookupTXT: func(string) ([]string, error) {
return nil, nil
},
}}
}

t.Run("1. Allow SRVs with fewer than 3 . separated parts", func(t *testing.T) {
t.Parallel()

cases := []struct {
record string
uri string
}{
{"test_1.localhost", "mongodb+srv://localhost"},
{"test_1.mongo.local", "mongodb+srv://mongo.local"},
}
for _, c := range cases {
c := c
t.Run(c.uri, func(t *testing.T) {
t.Parallel()

_, err := newTestParser(c.record).parse(c.uri)
assert.NoError(t, err, "expected no URI parsing error, got %v", err)
})
}
})
t.Run("2. Throw when return address does not end with SRV domain", func(t *testing.T) {
t.Parallel()

cases := []struct {
record string
uri string
}{
{"localhost.mongodb", "mongodb+srv://localhost"},
{"test_1.evil.local", "mongodb+srv://mongo.local"},
{"blogs.evil.com", "mongodb+srv://blogs.mongodb.com"},
}
for _, c := range cases {
c := c
t.Run(c.uri, func(t *testing.T) {
t.Parallel()

_, err := newTestParser(c.record).parse(c.uri)
assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain")
})
}
})
t.Run("3. Throw when return address is identical to SRV hostname", func(t *testing.T) {
t.Parallel()

cases := []struct {
record string
uri string
labels int
}{
{"localhost", "mongodb+srv://localhost", 1},
{"mongo.local", "mongodb+srv://mongo.local", 2},
}
for _, c := range cases {
c := c
t.Run(c.uri, func(t *testing.T) {
t.Parallel()

_, err := newTestParser(c.record).parse(c.uri)
expected := fmt.Sprintf(
"Server record (%d levels) should have more domain levels than parent URI (%d levels)",
c.labels, c.labels,
)
assert.ErrorContains(t, err, expected)
})
}
})
t.Run("4. Throw when return address does not contain . separating shared part of domain", func(t *testing.T) {
t.Parallel()

cases := []struct {
record string
uri string
}{
{"test_1.cluster_1localhost", "mongodb+srv://localhost"},
{"test_1.my_hostmongo.local", "mongodb+srv://mongo.local"},
{"cluster.testmongodb.com", "mongodb+srv://blogs.mongodb.com"},
}
for _, c := range cases {
c := c
t.Run(c.uri, func(t *testing.T) {
t.Parallel()

_, err := newTestParser(c.record).parse(c.uri)
assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain")
})
}
})
}
11 changes: 7 additions & 4 deletions x/mongo/driver/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,18 @@ func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr b
func validateSRVResult(recordFromSRV, inputHostName string) error {
separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".")
separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".")
if len(separatedRecord) < 2 {
return errors.New("DNS name must contain at least 2 labels")
if l := len(separatedInputDomain); l < 3 && len(separatedRecord) <= l {
return fmt.Errorf("Server record (%d levels) should have more domain levels than parent URI (%d levels)", l, len(separatedRecord))
}
if len(separatedRecord) < len(separatedInputDomain) {
return errors.New("Domain suffix from SRV record not matched input domain")
}

inputDomainSuffix := separatedInputDomain[1:]
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
inputDomainSuffix := separatedInputDomain
if len(inputDomainSuffix) > 2 {
inputDomainSuffix = inputDomainSuffix[1:]
}
domainSuffixOffset := len(separatedRecord) - len(inputDomainSuffix)

recordDomainSuffix := separatedRecord[domainSuffixOffset:]
for ix, label := range inputDomainSuffix {
Expand Down

0 comments on commit 9afdc8c

Please sign in to comment.