forked from CorentinB/warc
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdns.go
127 lines (103 loc) · 2.88 KB
/
dns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package warc
import (
"fmt"
"log/slog"
"net"
"sync"
"time"
"github.com/miekg/dns"
)
// cachedIP is one entry of the resolver cache: the address chosen for a host
// together with the instant after which it must no longer be served.
type cachedIP struct {
	// expiresAt is the moment the entry becomes stale; lookups compare it
	// against time.Now().
	expiresAt time.Time
	// ip is the resolved address handed back to the dialer.
	ip net.IP
}
// maxFallbackDNSServers caps how many servers after the first entry of
// DNSConfig.Servers may be tried when a lookup against the first one fails.
const (
	maxFallbackDNSServers = 3
)
// archiveDNS resolves the host part of address to a single IP address.
//
// address must be in "host:port" form; net.SplitHostPort rejects it otherwise
// and that error is returned unchanged. If the host is already an IP literal
// it is returned immediately, with no lookup and no caching.
//
// Otherwise a still-fresh entry in d.DNSRecords (keyed by host) is returned
// when present. On a cache miss, A and AAAA queries are sent concurrently via
// lookupIP (which also archives each DNS response), first to
// d.DNSConfig.Servers[0] and then to up to maxFallbackDNSServers further
// servers. The search stops at the first server that answers for an address
// family that is not disabled. IPv6 is preferred over IPv4 when both are
// resolved and enabled, and the chosen address is cached for d.DNSRecordsTTL.
func (d *customDialer) archiveDNS(address string) (resolvedIP net.IP, err error) {
	host, _, err := net.SplitHostPort(address)
	if err != nil {
		return nil, err
	}

	// IP literals need no resolution.
	if ip := net.ParseIP(host); ip != nil {
		return ip, nil
	}

	// Serve from the cache while the entry is still fresh.
	if cached, ok := d.DNSRecords.Load(host); ok {
		if entry, ok := cached.(cachedIP); ok && time.Now().Before(entry.expiresAt) {
			return entry.ip, nil
		}
		// Expired (or not a cachedIP): drop it and resolve again below.
		d.DNSRecords.Delete(host)
	}

	if len(d.DNSConfig.Servers) == 0 {
		return nil, fmt.Errorf("no DNS servers configured")
	}

	var (
		wg            sync.WaitGroup
		ipv4, ipv6    net.IP
		errA, errAAAA error
	)

	fallbackServers := min(maxFallbackDNSServers, len(d.DNSConfig.Servers)-1)
	for i := 0; i <= fallbackServers; i++ {
		wg.Add(2)
		go func(server int) {
			defer wg.Done()
			ipv4, errA = d.lookupIP(host, dns.TypeA, server)
		}(i)
		go func(server int) {
			defer wg.Done()
			ipv6, errAAAA = d.lookupIP(host, dns.TypeAAAA, server)
		}(i)
		// The goroutines write ipv4/ipv6/errA/errAAAA; wait before reading them.
		wg.Wait()

		// Only stop once we hold an answer we are allowed to dial. An AAAA
		// answer is useless when IPv6 is disabled (likewise A with IPv4), so
		// keep trying the fallback servers in that case.
		usableA := errA == nil && !d.disableIPv4
		usableAAAA := errAAAA == nil && !d.disableIPv6
		if usableA || usableAAAA {
			break
		}
		slog.Warn("Failed to resolve DNS", "DNS", d.DNSConfig.Servers[i], "address", host, "errA", errA, "errAAAA", errAAAA)
	}

	if errA != nil && errAAAA != nil {
		// %w keeps both causes inspectable with errors.Is / errors.As.
		return nil, fmt.Errorf("failed to resolve DNS: A error: %w, AAAA error: %w", errA, errAAAA)
	}

	// Prefer IPv6 when it is both resolved and enabled, otherwise IPv4.
	switch {
	case ipv6 != nil && !d.disableIPv6:
		resolvedIP = ipv6
	case ipv4 != nil && !d.disableIPv4:
		resolvedIP = ipv4
	default:
		return nil, fmt.Errorf("no suitable IP address found for %s", host)
	}

	d.DNSRecords.Store(host, cachedIP{
		ip:        resolvedIP,
		expiresAt: time.Now().Add(d.DNSRecordsTTL),
	})
	return resolvedIP, nil
}
// lookupIP sends one query of recordType for address to
// d.DNSConfig.Servers[DNSServer] on port d.DNSConfig.Port.
//
// Before the answer is inspected, the full response text is written through
// d.client.WriteRecord as a "resource" record of type "text/dns". The URI is
// "dns:<address>?TYPE=AAAA" for AAAA queries and "dns:<address>?TYPE=A" for
// any other type. The first answer whose type matches recordType (A or AAAA)
// is returned. An error is returned if the exchange fails or no such answer
// exists.
func (d *customDialer) lookupIP(address string, recordType uint16, DNSServer int) (net.IP, error) {
	server := net.JoinHostPort(d.DNSConfig.Servers[DNSServer], d.DNSConfig.Port)

	query := new(dns.Msg)
	query.SetQuestion(dns.Fqdn(address), recordType)

	resp, _, err := d.DNSClient.Exchange(query, server)
	if err != nil {
		return nil, err
	}

	var typeLabel string
	switch recordType {
	case dns.TypeAAAA:
		typeLabel = "TYPE=AAAA"
	default:
		typeLabel = "TYPE=A"
	}

	// Archive the raw response, whether or not it holds a usable answer.
	d.client.WriteRecord("dns:"+address+"?"+typeLabel, "resource", "text/dns", resp.String())

	for _, rr := range resp.Answer {
		switch rec := rr.(type) {
		case *dns.A:
			if recordType == dns.TypeA {
				return rec.A, nil
			}
		case *dns.AAAA:
			if recordType == dns.TypeAAAA {
				return rec.AAAA, nil
			}
		}
	}

	return nil, fmt.Errorf("no %s record found", typeLabel)
}