-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdew.go
152 lines (123 loc) · 3.72 KB
/
dew.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package dew
import (
"context"
"errors"
"fmt"
"sync"
)
var (
	// ErrValidationFailed is wrapped into the error returned when an action's
	// Validate method fails during DispatchMulti. Detect it with
	// errors.Is(err, ErrValidationFailed).
	ErrValidationFailed = errors.New("validation failed")
)
// Dispatch runs a single action through the bus stored in ctx.
//
// The action is wrapped with NewAction and handed to DispatchMulti. The
// pointer that was passed in is always returned as the first result, paired
// with whatever error DispatchMulti produced — note that it is returned even
// when the error is non-nil.
func Dispatch[T Action](ctx context.Context, action *T) (*T, error) {
	err := DispatchMulti(ctx, NewAction(action))
	return action, err
}
// DispatchMulti executes all actions synchronously, in order, inside a single
// mDispatch middleware chain.
// It assumes that all handlers have been registered to the same mux.
//
// Every action has Resolve(bus) called on it before the middleware runs, so a
// resolution failure means no handler is invoked. Inside the middleware each
// action is validated and then dispatched before the next one is examined: an
// invalid or failing action stops the loop, but actions that already ran are
// not undone.
//
// It returns an error when ctx carries no bus, when resolution fails, when an
// action fails validation, or when a handler fails. Validation failures wrap
// both ErrValidationFailed and the underlying validation error, so callers can
// use errors.Is(err, ErrValidationFailed) as well as errors.Is/errors.As on the
// original cause.
func DispatchMulti(ctx context.Context, actions ...CommandHandler[Action]) error {
	if len(actions) == 0 {
		return nil
	}
	bus, ok := FromContext(ctx)
	if !ok {
		return errors.New("bus not found in context")
	}
	for _, action := range actions {
		if err := action.Resolve(bus); err != nil {
			return err
		}
	}
	m := bus.(*mux)
	rctx := m.pool.Get().(*BusContext)
	rctx.Reset()
	rctx.ctx = context.WithValue(ctx, busKey{}, m)
	defer m.pool.Put(rctx)
	return m.mHandlers[mDispatch](rctx, func(ctx Context) error {
		for _, action := range actions {
			if err := action.Command().(Action).Validate(ctx.Context()); err != nil {
				// Wrap both errors (multiple %w, Go 1.20+) so the validation
				// cause stays inspectable; the message text is unchanged.
				return fmt.Errorf("%w: %w", ErrValidationFailed, err)
			}
			if err := action.Mux().dispatch(ACTION, ctx, action); err != nil {
				return err
			}
		}
		return nil
	})
}
// Query runs a single query through the bus stored in ctx and, on success,
// returns the query object as typed *T after its handler has run.
//
// The query is wrapped with NewQuery, resolved against the bus, and then
// dispatched inside the mQuery middleware chain using a pooled BusContext.
// On any failure — no bus in ctx, a resolution error, or a middleware/handler
// error — the returned pointer is nil.
func Query[T QueryAction](ctx context.Context, query *T) (*T, error) {
	bus, found := FromContext(ctx)
	if !found {
		return nil, errors.New("bus not found in context")
	}

	q := NewQuery(query)
	if err := q.Resolve(bus); err != nil {
		return nil, err
	}

	m := bus.(*mux)
	bctx := m.pool.Get().(*BusContext)
	bctx.Reset()
	bctx.ctx = context.WithValue(ctx, busKey{}, m)
	defer m.pool.Put(bctx)

	handle := func(c Context) error {
		return q.Mux().dispatch(QUERY, c, q)
	}
	if err := m.mHandlers[mQuery](bctx, handle); err != nil {
		return nil, err
	}
	return q.Command().(*T), nil
}
// QueryAsync executes all queries concurrently and collects their errors.
// It assumes that all handlers have been registered to the same mux.
//
// Every query has Resolve(bus) called on it before anything runs; the first
// resolution error is returned immediately. The queries then run in one
// goroutine each, and QueryAsync waits for all of them before returning.
//
// The result is nil if every query succeeded, the error itself if exactly one
// query failed, and a single errors.Join of all failures (in the order of the
// queries argument) otherwise.
func QueryAsync(ctx context.Context, queries ...CommandHandler[Command]) error {
	if len(queries) == 0 {
		return nil
	}
	bus, ok := FromContext(ctx)
	if !ok {
		return errors.New("bus not found in context")
	}
	for _, query := range queries {
		if err := query.Resolve(bus); err != nil {
			return err
		}
	}
	m := bus.(*mux)
	rctx := m.pool.Get().(*BusContext) // Get a context from the pool.
	rctx.Reset()
	rctx.ctx = context.WithValue(ctx, busKey{}, m)
	defer m.pool.Put(rctx) // Return the context to the pool once all queries finish.
	// NOTE(review): the mQuery middleware wraps the whole batch here and is
	// applied again per query inside each goroutine below, so it runs
	// len(queries)+1 times per call — confirm this is intended.
	return m.mHandlers[mQuery](rctx, func(parent Context) error {
		// One slot per query: each goroutine writes only its own index, so no
		// locking is needed and the reported error order is deterministic.
		results := make([]error, len(queries))
		var wg sync.WaitGroup
		for i, query := range queries {
			wg.Add(1)
			go func(i int, query CommandHandler[Command]) {
				defer wg.Done()
				// Each goroutine uses its own pooled BusContext, filled via
				// Copy from the batch context.
				qctx := m.pool.Get().(*BusContext)
				qctx.Reset()
				qctx.Copy(parent.(*BusContext))
				defer m.pool.Put(qctx)
				results[i] = m.mHandlers[mQuery](qctx, func(ctx Context) error {
					return query.Mux().dispatch(QUERY, ctx, query)
				})
			}(i, query)
		}
		wg.Wait()

		var failed []error
		for _, err := range results {
			if err != nil {
				failed = append(failed, err)
			}
		}
		switch len(failed) {
		case 0:
			return nil
		case 1:
			// A lone failure is returned unwrapped so callers can compare it
			// directly.
			return failed[0]
		default:
			// One flat join instead of a nested chain of joins.
			return errors.Join(failed...)
		}
	})
}