Merge lp:~pedronis/ubuntu-push/try-hosts into lp:ubuntu-push
- try-hosts
- Merge into trunk
Proposed by
Samuele Pedroni
Status: | Superseded |
---|---|
Proposed branch: | lp:~pedronis/ubuntu-push/try-hosts |
Merge into: | lp:ubuntu-push |
Diff against target: |
610 lines (+374/-40) 2 files modified
client/session/session.go (+122/-26) client/session/session_test.go (+252/-14) |
To merge this branch: | bzr merge lp:~pedronis/ubuntu-push/try-hosts |
Related bugs: |
Reviewer | Review Type | Date Requested | Status |
---|---|---|---|
Ubuntu Push Hackers | Pending | ||
Review via email:
|
This proposal has been superseded by a proposal from 2014-03-27.
Commit message
Description of the change
Optionally retrieves the list of hosts to try connecting to from a URL.
Changes to the client config and to the interface between client and session are still needed; they are left for a second branch, since this one is already biggish.
Also found that session.Close() has a race on Connection; that will be fixed in a later branch as well.
To post a comment you must log in.
- 98. By Samuele Pedroni
-
clarity
- 99. By Samuele Pedroni
-
ordering
- 100. By Samuele Pedroni
-
formatting
Unmerged revisions
Preview Diff
[H/L] Next/Prev Comment, [J/K] Next/Prev File, [N/P] Next/Prev Hunk
1 | === modified file 'client/session/session.go' |
2 | --- client/session/session.go 2014-03-26 16:26:36 +0000 |
3 | +++ client/session/session.go 2014-03-27 21:20:49 +0000 |
4 | @@ -23,14 +23,17 @@ |
5 | "crypto/x509" |
6 | "errors" |
7 | "fmt" |
8 | + "math/rand" |
9 | + "net" |
10 | + "strings" |
11 | + "sync/atomic" |
12 | + "time" |
13 | + |
14 | + "launchpad.net/ubuntu-push/client/gethosts" |
15 | "launchpad.net/ubuntu-push/client/session/levelmap" |
16 | "launchpad.net/ubuntu-push/logger" |
17 | "launchpad.net/ubuntu-push/protocol" |
18 | "launchpad.net/ubuntu-push/util" |
19 | - "math/rand" |
20 | - "net" |
21 | - "sync/atomic" |
22 | - "time" |
23 | ) |
24 | |
25 | var wireVersionBytes = []byte{protocol.ProtocolWireVersion} |
26 | @@ -45,6 +48,15 @@ |
27 | protocol.NotificationsMsg |
28 | } |
29 | |
30 | +// parseServerAddrSpec recognizes whether spec is a HTTP URL to get |
31 | +// hosts from or a |-separated list of host:port pairs. |
32 | +func parseServerAddrSpec(spec string) (hostsEndpoint string, fallbackHosts []string) { |
33 | + if strings.HasPrefix(spec, "http") { |
34 | + return spec, nil |
35 | + } |
36 | + return "", strings.Split(spec, "|") |
37 | +} |
38 | + |
39 | // ClientSessionState is a way to broadly track the progress of the session |
40 | type ClientSessionState uint32 |
41 | |
42 | @@ -56,14 +68,27 @@ |
43 | Running |
44 | ) |
45 | |
46 | +type hostGetter interface { |
47 | + Get() ([]string, error) |
48 | +} |
49 | + |
50 | // ClienSession holds a client<->server session and its configuration. |
51 | type ClientSession struct { |
52 | // configuration |
53 | - DeviceId string |
54 | - ServerAddr string |
55 | - ExchangeTimeout time.Duration |
56 | - Levels levelmap.LevelMap |
57 | - Protocolator func(net.Conn) protocol.Protocol |
58 | + DeviceId string |
59 | + getHost hostGetter |
60 | + fallbackHosts []string |
61 | + ExchangeTimeout time.Duration |
62 | + HostsCachingExpiry time.Duration |
63 | + Levels levelmap.LevelMap |
64 | + Protocolator func(net.Conn) protocol.Protocol |
65 | + // hosts |
66 | + timeSince func(time.Time) time.Duration // hook for testing |
67 | + deliveryHostsTimestamp time.Time |
68 | + deliveryHosts []string |
69 | + lastAttemptTimestamp time.Time |
70 | + leftToTry int |
71 | + tryHost int |
72 | // connection |
73 | Connection net.Conn |
74 | Log logger.Logger |
75 | @@ -77,7 +102,7 @@ |
76 | MsgCh chan *Notification |
77 | } |
78 | |
79 | -func NewSession(serverAddr string, pem []byte, exchangeTimeout time.Duration, |
80 | +func NewSession(serverAddrSpec string, pem []byte, exchangeTimeout time.Duration, |
81 | deviceId string, levelmapFactory func() (levelmap.LevelMap, error), |
82 | log logger.Logger) (*ClientSession, error) { |
83 | state := uint32(Disconnected) |
84 | @@ -85,15 +110,23 @@ |
85 | if err != nil { |
86 | return nil, err |
87 | } |
88 | + var getHost hostGetter |
89 | + hostsEndpoint, fallbackHosts := parseServerAddrSpec(serverAddrSpec) |
90 | + if hostsEndpoint != "" { |
91 | + getHost = gethosts.New(deviceId, hostsEndpoint, exchangeTimeout) |
92 | + } |
93 | sess := &ClientSession{ |
94 | - ExchangeTimeout: exchangeTimeout, |
95 | - ServerAddr: serverAddr, |
96 | - DeviceId: deviceId, |
97 | - Log: log, |
98 | - Protocolator: protocol.NewProtocol0, |
99 | - Levels: levels, |
100 | - TLS: &tls.Config{InsecureSkipVerify: true}, // XXX |
101 | - stateP: &state, |
102 | + ExchangeTimeout: exchangeTimeout, |
103 | + HostsCachingExpiry: 12 * time.Hour, |
104 | + getHost: getHost, |
105 | + fallbackHosts: fallbackHosts, |
106 | + DeviceId: deviceId, |
107 | + Log: log, |
108 | + Protocolator: protocol.NewProtocol0, |
109 | + Levels: levels, |
110 | + TLS: &tls.Config{InsecureSkipVerify: true}, // XXX |
111 | + stateP: &state, |
112 | + timeSince: time.Since, |
113 | } |
114 | if pem != nil { |
115 | cp := x509.NewCertPool() |
116 | @@ -114,13 +147,72 @@ |
117 | atomic.StoreUint32(sess.stateP, uint32(state)) |
118 | } |
119 | |
120 | +// getHosts sets deliverHosts possibly querying a remote endpoint |
121 | +func (sess *ClientSession) getHosts() error { |
122 | + if sess.getHost != nil { |
123 | + if sess.timeSince(sess.deliveryHostsTimestamp) < sess.HostsCachingExpiry { |
124 | + return nil |
125 | + } |
126 | + hosts, err := sess.getHost.Get() |
127 | + if err != nil { |
128 | + sess.Log.Errorf("getHosts: %v", err) |
129 | + sess.setState(Error) |
130 | + return err |
131 | + } |
132 | + sess.deliveryHostsTimestamp = time.Now() |
133 | + sess.deliveryHosts = hosts |
134 | + } else { |
135 | + sess.deliveryHosts = sess.fallbackHosts |
136 | + } |
137 | + return nil |
138 | +} |
139 | + |
140 | +// startConnectionAttempt/nextHostToTry help connect iterating over candidate hosts |
141 | + |
142 | +func (sess *ClientSession) startConnectionAttempt() { |
143 | + if sess.timeSince(sess.lastAttemptTimestamp) > 10*sess.ExchangeTimeout { |
144 | + sess.tryHost = 0 |
145 | + } |
146 | + sess.leftToTry = len(sess.deliveryHosts) |
147 | + sess.lastAttemptTimestamp = time.Now() |
148 | +} |
149 | + |
150 | +func (sess *ClientSession) nextHostToTry() string { |
151 | + if sess.leftToTry == 0 { |
152 | + return "" |
153 | + } |
154 | + res := sess.deliveryHosts[sess.tryHost] |
155 | + sess.tryHost = (sess.tryHost + 1) % len(sess.deliveryHosts) |
156 | + sess.leftToTry-- |
157 | + return res |
158 | +} |
159 | + |
160 | +// we reached the Started state, we can retry with the same host if we |
161 | +// have to retry again |
162 | +func (sess *ClientSession) started() { |
163 | + sess.tryHost-- |
164 | + if sess.tryHost == -1 { |
165 | + sess.tryHost = len(sess.deliveryHosts) - 1 |
166 | + } |
167 | + sess.setState(Started) |
168 | +} |
169 | + |
170 | // connect to a server using the configuration in the ClientSession |
171 | // and set up the connection. |
172 | func (sess *ClientSession) connect() error { |
173 | - conn, err := net.DialTimeout("tcp", sess.ServerAddr, sess.ExchangeTimeout) |
174 | - if err != nil { |
175 | - sess.setState(Error) |
176 | - return fmt.Errorf("connect: %s", err) |
177 | + sess.startConnectionAttempt() |
178 | + var err error |
179 | + var conn net.Conn |
180 | + for { |
181 | + host := sess.nextHostToTry() |
182 | + if host == "" { |
183 | + sess.setState(Error) |
184 | + return fmt.Errorf("connect: %s", err) |
185 | + } |
186 | + conn, err = net.DialTimeout("tcp", host, sess.ExchangeTimeout) |
187 | + if err == nil { |
188 | + break |
189 | + } |
190 | } |
191 | sess.Connection = tls.Client(conn, sess.TLS) |
192 | sess.setState(Connected) |
193 | @@ -279,15 +371,19 @@ |
194 | sess.proto = proto |
195 | sess.pingInterval = pingInterval |
196 | sess.Log.Debugf("Connected %v.", conn.LocalAddr()) |
197 | - sess.setState(Started) |
198 | + sess.started() // deals with choosing which host to retry with as well |
199 | return nil |
200 | } |
201 | |
202 | // run calls connect, and if it works it calls start, and if it works |
203 | // it runs loop in a goroutine, and ships its return value over ErrCh. |
204 | -func (sess *ClientSession) run(closer func(), connecter, starter, looper func() error) error { |
205 | +func (sess *ClientSession) run(closer func(), hostGetter, connecter, starter, looper func() error) error { |
206 | closer() |
207 | - err := connecter() |
208 | + err := hostGetter() |
209 | + if err != nil { |
210 | + return err |
211 | + } |
212 | + err = connecter() |
213 | if err == nil { |
214 | err = starter() |
215 | if err == nil { |
216 | @@ -317,7 +413,7 @@ |
217 | // keep on trying. |
218 | panic("can't Dial() without a protocol constructor.") |
219 | } |
220 | - return sess.run(sess.doClose, sess.connect, sess.start, sess.loop) |
221 | + return sess.run(sess.doClose, sess.getHosts, sess.connect, sess.start, sess.loop) |
222 | } |
223 | |
224 | func init() { |
225 | |
226 | === modified file 'client/session/session_test.go' |
227 | --- client/session/session_test.go 2014-03-27 13:26:10 +0000 |
228 | +++ client/session/session_test.go 2014-03-27 21:20:49 +0000 |
229 | @@ -23,16 +23,21 @@ |
230 | "fmt" |
231 | "io" |
232 | "io/ioutil" |
233 | + "net" |
234 | + "net/http" |
235 | + "net/http/httptest" |
236 | + "reflect" |
237 | + "testing" |
238 | + "time" |
239 | + |
240 | . "launchpad.net/gocheck" |
241 | + |
242 | "launchpad.net/ubuntu-push/client/session/levelmap" |
243 | + //"launchpad.net/ubuntu-push/client/gethosts" |
244 | "launchpad.net/ubuntu-push/logger" |
245 | "launchpad.net/ubuntu-push/protocol" |
246 | helpers "launchpad.net/ubuntu-push/testing" |
247 | "launchpad.net/ubuntu-push/testing/condition" |
248 | - "net" |
249 | - "reflect" |
250 | - "testing" |
251 | - "time" |
252 | ) |
253 | |
254 | func TestSession(t *testing.T) { TestingT(t) } |
255 | @@ -181,18 +186,43 @@ |
256 | } |
257 | |
258 | /**************************************************************** |
259 | + parseServerAddrSpec() tests |
260 | +****************************************************************/ |
261 | + |
262 | +func (cs *clientSessionSuite) TestParseServerAddrSpec(c *C) { |
263 | + hEp, fallbackHosts := parseServerAddrSpec("http://foo/hosts") |
264 | + c.Check(hEp, Equals, "http://foo/hosts") |
265 | + c.Check(fallbackHosts, IsNil) |
266 | + |
267 | + hEp, fallbackHosts = parseServerAddrSpec("foo:443") |
268 | + c.Check(hEp, Equals, "") |
269 | + c.Check(fallbackHosts, DeepEquals, []string{"foo:443"}) |
270 | + |
271 | + hEp, fallbackHosts = parseServerAddrSpec("foo:443|bar:443") |
272 | + c.Check(hEp, Equals, "") |
273 | + c.Check(fallbackHosts, DeepEquals, []string{"foo:443", "bar:443"}) |
274 | +} |
275 | + |
276 | +/**************************************************************** |
277 | NewSession() tests |
278 | ****************************************************************/ |
279 | |
280 | func (cs *clientSessionSuite) TestNewSessionPlainWorks(c *C) { |
281 | - sess, err := NewSession("", nil, 0, "", cs.lvls, cs.log) |
282 | + sess, err := NewSession("foo:443", nil, 0, "", cs.lvls, cs.log) |
283 | c.Check(sess, NotNil) |
284 | c.Check(err, IsNil) |
285 | + c.Check(sess.fallbackHosts, DeepEquals, []string{"foo:443"}) |
286 | // but no root CAs set |
287 | c.Check(sess.TLS.RootCAs, IsNil) |
288 | c.Check(sess.State(), Equals, Disconnected) |
289 | } |
290 | |
291 | +func (cs *clientSessionSuite) TestNewSessionHostEndpointWorks(c *C) { |
292 | + sess, err := NewSession("http://foo/hosts", pem, 0, "wah", cs.lvls, cs.log) |
293 | + c.Assert(err, IsNil) |
294 | + c.Check(sess.getHost, NotNil) |
295 | +} |
296 | + |
297 | var certfile string = helpers.SourceRelative("../../server/acceptance/config/testing.cert") |
298 | var pem, _ = ioutil.ReadFile(certfile) |
299 | |
300 | @@ -218,12 +248,141 @@ |
301 | } |
302 | |
303 | /**************************************************************** |
304 | + getHosts() tests |
305 | +****************************************************************/ |
306 | + |
307 | +func (cs *clientSessionSuite) TestGetHostsFallback(c *C) { |
308 | + fallback := []string{"foo:443", "bar:443"} |
309 | + sess := &ClientSession{fallbackHosts: fallback} |
310 | + err := sess.getHosts() |
311 | + c.Assert(err, IsNil) |
312 | + c.Check(sess.deliveryHosts, DeepEquals, fallback) |
313 | +} |
314 | + |
315 | +type testHostGetter struct { |
316 | + hosts []string |
317 | + err error |
318 | +} |
319 | + |
320 | +func (thg *testHostGetter) Get() ([]string, error) { |
321 | + return thg.hosts, thg.err |
322 | +} |
323 | + |
324 | +func (cs *clientSessionSuite) TestGetHostsRemote(c *C) { |
325 | + hostGetter := &testHostGetter{[]string{"foo:443", "bar:443"}, nil} |
326 | + sess := &ClientSession{getHost: hostGetter, timeSince: time.Since} |
327 | + err := sess.getHosts() |
328 | + c.Assert(err, IsNil) |
329 | + c.Check(sess.deliveryHosts, DeepEquals, []string{"foo:443", "bar:443"}) |
330 | +} |
331 | + |
332 | +func (cs *clientSessionSuite) TestGetHostsRemoteError(c *C) { |
333 | + sess, err := NewSession("", nil, 0, "", cs.lvls, cs.log) |
334 | + c.Assert(err, IsNil) |
335 | + hostsErr := errors.New("failed") |
336 | + hostGetter := &testHostGetter{nil, hostsErr} |
337 | + sess.getHost = hostGetter |
338 | + err = sess.getHosts() |
339 | + c.Assert(err, Equals, hostsErr) |
340 | + c.Check(sess.deliveryHosts, IsNil) |
341 | + c.Check(sess.State(), Equals, Error) |
342 | +} |
343 | + |
344 | +func (cs *clientSessionSuite) TestGetHostsRemoteCaching(c *C) { |
345 | + hostGetter := &testHostGetter{[]string{"foo:443", "bar:443"}, nil} |
346 | + sess := &ClientSession{ |
347 | + getHost: hostGetter, |
348 | + HostsCachingExpiry: 2 * time.Hour, |
349 | + timeSince: time.Since, |
350 | + } |
351 | + err := sess.getHosts() |
352 | + c.Assert(err, IsNil) |
353 | + hostGetter.hosts = []string{"baz:443"} |
354 | + // cached |
355 | + err = sess.getHosts() |
356 | + c.Assert(err, IsNil) |
357 | + c.Check(sess.deliveryHosts, DeepEquals, []string{"foo:443", "bar:443"}) |
358 | + // expired |
359 | + sess.timeSince = func(ts time.Time) time.Duration { |
360 | + return 3 * time.Hour |
361 | + } |
362 | + err = sess.getHosts() |
363 | + c.Assert(err, IsNil) |
364 | + c.Check(sess.deliveryHosts, DeepEquals, []string{"baz:443"}) |
365 | +} |
366 | + |
367 | +/**************************************************************** |
368 | + startConnectionAttempt()/nextHostToTry()/started tests |
369 | +****************************************************************/ |
370 | + |
371 | +func (cs *clientSessionSuite) TestStartConnectionAttempt(c *C) { |
372 | + since := time.Since(time.Time{}) |
373 | + sess := &ClientSession{ |
374 | + ExchangeTimeout: 10 * time.Second, |
375 | + timeSince: func(ts time.Time) time.Duration { |
376 | + return since |
377 | + }, |
378 | + deliveryHosts: []string{"foo:443", "bar:443"}, |
379 | + } |
380 | + // start from first host |
381 | + sess.startConnectionAttempt() |
382 | + c.Check(sess.lastAttemptTimestamp, Not(Equals), 0) |
383 | + c.Check(sess.tryHost, Equals, 0) |
384 | + c.Check(sess.leftToTry, Equals, 2) |
385 | + since = 1 * time.Second |
386 | + sess.tryHost = 1 |
387 | + // just continue |
388 | + sess.startConnectionAttempt() |
389 | + c.Check(sess.tryHost, Equals, 1) |
390 | + sess.tryHost = 2 |
391 | +} |
392 | + |
393 | +func (cs *clientSessionSuite) TestNextHostToTry(c *C) { |
394 | + sess := &ClientSession{ |
395 | + deliveryHosts: []string{"foo:443", "bar:443", "baz:443"}, |
396 | + tryHost: 0, |
397 | + leftToTry: 3, |
398 | + } |
399 | + c.Check(sess.nextHostToTry(), Equals, "foo:443") |
400 | + c.Check(sess.nextHostToTry(), Equals, "bar:443") |
401 | + c.Check(sess.nextHostToTry(), Equals, "baz:443") |
402 | + c.Check(sess.nextHostToTry(), Equals, "") |
403 | + c.Check(sess.nextHostToTry(), Equals, "") |
404 | + c.Check(sess.tryHost, Equals, 0) |
405 | + |
406 | + sess.leftToTry = 3 |
407 | + sess.tryHost = 1 |
408 | + c.Check(sess.nextHostToTry(), Equals, "bar:443") |
409 | + c.Check(sess.nextHostToTry(), Equals, "baz:443") |
410 | + c.Check(sess.nextHostToTry(), Equals, "foo:443") |
411 | + c.Check(sess.nextHostToTry(), Equals, "") |
412 | + c.Check(sess.nextHostToTry(), Equals, "") |
413 | + c.Check(sess.tryHost, Equals, 1) |
414 | +} |
415 | + |
416 | +func (cs *clientSessionSuite) TestStarted(c *C) { |
417 | + sess, err := NewSession("", nil, 0, "", cs.lvls, cs.log) |
418 | + c.Assert(err, IsNil) |
419 | + |
420 | + sess.deliveryHosts = []string{"foo:443", "bar:443", "baz:443"} |
421 | + sess.tryHost = 1 |
422 | + |
423 | + sess.started() |
424 | + c.Check(sess.tryHost, Equals, 0) |
425 | + c.Check(sess.State(), Equals, Started) |
426 | + |
427 | + sess.started() |
428 | + c.Check(sess.tryHost, Equals, 2) |
429 | +} |
430 | + |
431 | +/**************************************************************** |
432 | connect() tests |
433 | ****************************************************************/ |
434 | |
435 | func (cs *clientSessionSuite) TestConnectFailsWithNoAddress(c *C) { |
436 | sess, err := NewSession("", nil, 0, "wah", cs.lvls, cs.log) |
437 | c.Assert(err, IsNil) |
438 | + sess.deliveryHosts = []string{"nowhere"} |
439 | err = sess.connect() |
440 | c.Check(err, ErrorMatches, ".*connect.*address.*") |
441 | c.Check(sess.State(), Equals, Error) |
442 | @@ -233,12 +392,27 @@ |
443 | srv, err := net.Listen("tcp", "localhost:0") |
444 | c.Assert(err, IsNil) |
445 | defer srv.Close() |
446 | - sess, err := NewSession(srv.Addr().String(), nil, 0, "wah", cs.lvls, cs.log) |
447 | - c.Assert(err, IsNil) |
448 | - err = sess.connect() |
449 | - c.Check(err, IsNil) |
450 | - c.Check(sess.Connection, NotNil) |
451 | - c.Check(sess.State(), Equals, Connected) |
452 | + sess, err := NewSession("", nil, 0, "wah", cs.lvls, cs.log) |
453 | + c.Assert(err, IsNil) |
454 | + sess.deliveryHosts = []string{srv.Addr().String()} |
455 | + err = sess.connect() |
456 | + c.Check(err, IsNil) |
457 | + c.Check(sess.Connection, NotNil) |
458 | + c.Check(sess.State(), Equals, Connected) |
459 | +} |
460 | + |
461 | +func (cs *clientSessionSuite) TestConnectSecondConnects(c *C) { |
462 | + srv, err := net.Listen("tcp", "localhost:0") |
463 | + c.Assert(err, IsNil) |
464 | + defer srv.Close() |
465 | + sess, err := NewSession("", nil, 0, "wah", cs.lvls, cs.log) |
466 | + c.Assert(err, IsNil) |
467 | + sess.deliveryHosts = []string{"nowhere", srv.Addr().String()} |
468 | + err = sess.connect() |
469 | + c.Check(err, IsNil) |
470 | + c.Check(sess.Connection, NotNil) |
471 | + c.Check(sess.State(), Equals, Connected) |
472 | + c.Check(sess.tryHost, Equals, 0) |
473 | } |
474 | |
475 | func (cs *clientSessionSuite) TestConnectConnectFail(c *C) { |
476 | @@ -247,6 +421,7 @@ |
477 | sess, err := NewSession(srv.Addr().String(), nil, 0, "wah", cs.lvls, cs.log) |
478 | srv.Close() |
479 | c.Assert(err, IsNil) |
480 | + sess.deliveryHosts = []string{srv.Addr().String()} |
481 | err = sess.connect() |
482 | c.Check(err, ErrorMatches, ".*connection refused") |
483 | c.Check(sess.State(), Equals, Error) |
484 | @@ -688,20 +863,34 @@ |
485 | run() tests |
486 | ****************************************************************/ |
487 | |
488 | -func (cs *clientSessionSuite) TestRunBailsIfConnectFails(c *C) { |
489 | +func (cs *clientSessionSuite) TestRunBailsIfHostGetterFails(c *C) { |
490 | sess, err := NewSession("", nil, 0, "wah", cs.lvls, cs.log) |
491 | c.Assert(err, IsNil) |
492 | - failure := errors.New("TestRunBailsIfConnectFails") |
493 | + failure := errors.New("TestRunBailsIfHostGetterFails") |
494 | has_closed := false |
495 | err = sess.run( |
496 | func() { has_closed = true }, |
497 | func() error { return failure }, |
498 | nil, |
499 | + nil, |
500 | nil) |
501 | c.Check(err, Equals, failure) |
502 | c.Check(has_closed, Equals, true) |
503 | } |
504 | |
505 | +func (cs *clientSessionSuite) TestRunBailsIfConnectFails(c *C) { |
506 | + sess, err := NewSession("", nil, 0, "wah", cs.lvls, cs.log) |
507 | + c.Assert(err, IsNil) |
508 | + failure := errors.New("TestRunBailsIfConnectFails") |
509 | + err = sess.run( |
510 | + func() {}, |
511 | + func() error { return nil }, |
512 | + func() error { return failure }, |
513 | + nil, |
514 | + nil) |
515 | + c.Check(err, Equals, failure) |
516 | +} |
517 | + |
518 | func (cs *clientSessionSuite) TestRunBailsIfStartFails(c *C) { |
519 | sess, err := NewSession("", nil, 0, "wah", cs.lvls, cs.log) |
520 | c.Assert(err, IsNil) |
521 | @@ -709,6 +898,7 @@ |
522 | err = sess.run( |
523 | func() {}, |
524 | func() error { return nil }, |
525 | + func() error { return nil }, |
526 | func() error { return failure }, |
527 | nil) |
528 | c.Check(err, Equals, failure) |
529 | @@ -727,6 +917,7 @@ |
530 | func() {}, |
531 | func() error { return nil }, |
532 | func() error { return nil }, |
533 | + func() error { return nil }, |
534 | func() error { sess.MsgCh <- notf; return <-failureCh }) |
535 | c.Check(err, Equals, nil) |
536 | // if run doesn't error it sets up the channels |
537 | @@ -794,7 +985,20 @@ |
538 | timeout := 100 * time.Millisecond |
539 | lst, err := tls.Listen("tcp", "localhost:0", tlsCfg) |
540 | c.Assert(err, IsNil) |
541 | - sess, err := NewSession(lst.Addr().String(), nil, timeout, "wah", cs.lvls, cs.log) |
542 | + // advertise |
543 | + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
544 | + b, err := json.Marshal(map[string]interface{}{ |
545 | + "hosts": []string{"nowhere", lst.Addr().String()}, |
546 | + }) |
547 | + if err != nil { |
548 | + panic(err) |
549 | + } |
550 | + w.Header().Set("Content-Type", "application/json") |
551 | + w.Write(b) |
552 | + })) |
553 | + defer ts.Close() |
554 | + |
555 | + sess, err := NewSession(ts.URL, nil, timeout, "wah", cs.lvls, cs.log) |
556 | c.Assert(err, IsNil) |
557 | tconn := &testConn{CloseCondition: condition.Fail2Work(10)} |
558 | sess.Connection = tconn |
559 | @@ -823,6 +1027,9 @@ |
560 | c.Assert(err, IsNil) |
561 | c.Assert(v, Equals, protocol.ProtocolWireVersion) |
562 | |
563 | + // if something goes wrong session would try the first/other host |
564 | + c.Check(sess.tryHost, Equals, 0) |
565 | + |
566 | // 2. "connect" (but on the fake protcol above! woo) |
567 | |
568 | c.Check(takeNext(downCh), Equals, "deadline 100ms") |
569 | @@ -843,6 +1050,9 @@ |
570 | c.Check(takeNext(downCh), Equals, protocol.PingPongMsg{Type: "pong"}) |
571 | upCh <- nil |
572 | |
573 | + // session would retry the same host |
574 | + c.Check(sess.tryHost, Equals, 1) |
575 | + |
576 | // and broadcasts... |
577 | b := &protocol.BroadcastMsg{ |
578 | Type: "broadcast", |
579 | @@ -870,3 +1080,31 @@ |
580 | upCh <- failure |
581 | c.Check(<-sess.ErrCh, Equals, failure) |
582 | } |
583 | + |
584 | +func (cs *clientSessionSuite) TestDialWorksDirect(c *C) { |
585 | + // happy path thoughts |
586 | + cert, err := tls.X509KeyPair(helpers.TestCertPEMBlock, helpers.TestKeyPEMBlock) |
587 | + c.Assert(err, IsNil) |
588 | + tlsCfg := &tls.Config{ |
589 | + Certificates: []tls.Certificate{cert}, |
590 | + SessionTicketsDisabled: true, |
591 | + } |
592 | + |
593 | + timeout := 100 * time.Millisecond |
594 | + lst, err := tls.Listen("tcp", "localhost:0", tlsCfg) |
595 | + c.Assert(err, IsNil) |
596 | + sess, err := NewSession(lst.Addr().String(), nil, timeout, "wah", cs.lvls, cs.log) |
597 | + c.Assert(err, IsNil) |
598 | + //defer sess.Close() xxx provokes a race, fix in a later branch |
599 | + |
600 | + upCh := make(chan interface{}, 5) |
601 | + downCh := make(chan interface{}, 5) |
602 | + proto := &testProtocol{up: upCh, down: downCh} |
603 | + sess.Protocolator = func(net.Conn) protocol.Protocol { return proto } |
604 | + |
605 | + go sess.Dial() |
606 | + |
607 | + _, err = lst.Accept() |
608 | + c.Assert(err, IsNil) |
609 | + // connect done |
610 | +} |