3
3
// See the LICENSE file in the project root for more information.
4
4
5
5
#include " stdafx.h"
6
- #include < thread>
7
6
#include < algorithm>
8
7
#include " constants.h"
9
8
#include " connection_impl.h"
17
16
#include < assert.h>
18
17
#include " signalrclient/websocket_client.h"
19
18
#include " default_websocket_client.h"
19
+ #include " signalr_default_scheduler.h"
20
20
21
21
namespace signalr
22
22
{
@@ -26,42 +26,25 @@ namespace signalr
26
26
}
27
27
28
28
std::shared_ptr<connection_impl> connection_impl::create (const std::string& url, trace_level trace_level, const std::shared_ptr<log_writer>& log_writer,
29
- std::shared_ptr<http_client> http_client , std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
29
+ std::function<std:: shared_ptr<http_client>( const signalr_client_config&)> http_client_factory , std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
30
30
{
31
31
return std::shared_ptr<connection_impl>(new connection_impl (url, trace_level,
32
- log_writer ? log_writer : std::make_shared<trace_log_writer>(), http_client , websocket_factory, skip_negotiation));
32
+ log_writer ? log_writer : std::make_shared<trace_log_writer>(), http_client_factory , websocket_factory, skip_negotiation));
33
33
}
34
34
35
35
connection_impl::connection_impl (const std::string& url, trace_level trace_level, const std::shared_ptr<log_writer>& log_writer,
36
- std::unique_ptr<http_client> http_client, std::unique_ptr<transport_factory> transport_factory, const bool skip_negotiation)
37
- : m_base_url(url), m_connection_state(connection_state::disconnected), m_logger(log_writer, trace_level), m_transport(nullptr ),
38
- m_transport_factory (std::move(transport_factory)), m_skip_negotiation(skip_negotiation), m_message_received([](const std::string&) noexcept {}), m_disconnected([]() noexcept {})
39
- {
40
- if (http_client != nullptr )
41
- {
42
- m_http_client = std::move (http_client);
43
- }
44
- else
45
- {
46
- #ifdef USE_CPPRESTSDK
47
- m_http_client = std::unique_ptr<class http_client >(new default_http_client ());
48
- #endif
49
- }
50
- }
51
-
52
- connection_impl::connection_impl (const std::string& url, trace_level trace_level, const std::shared_ptr<log_writer>& log_writer,
53
- std::shared_ptr<http_client> http_client, std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
36
+ std::function<std::shared_ptr<http_client>(const signalr_client_config&)> http_client_factory, std::function<std::shared_ptr<websocket_client>(const signalr_client_config&)> websocket_factory, const bool skip_negotiation)
54
37
: m_base_url(url), m_connection_state(connection_state::disconnected), m_logger(log_writer, trace_level), m_transport(nullptr ), m_skip_negotiation(skip_negotiation),
55
- m_message_received ([](const std::string&) noexcept {}), m_disconnected([]() noexcept {})
38
+ m_message_received ([](const std::string&) noexcept {}), m_disconnected([]() noexcept {}), m_disconnect_cts(std::make_shared<cancellation_token>())
56
39
{
57
- if (http_client != nullptr )
40
+ if (http_client_factory != nullptr )
58
41
{
59
- m_http_client = std::move (http_client );
42
+ m_http_client_factory = std::move (http_client_factory );
60
43
}
61
44
else
62
45
{
63
46
#ifdef USE_CPPRESTSDK
64
- m_http_client = std::unique_ptr<class http_client >(new default_http_client ());
47
+ m_http_client_factory = []( const signalr_client_config&) { return std::unique_ptr<class http_client >(new default_http_client ()); } ;
65
48
#endif
66
49
}
67
50
@@ -72,7 +55,7 @@ namespace signalr
72
55
#endif
73
56
}
74
57
75
- m_transport_factory = std::unique_ptr<transport_factory>(new transport_factory (m_http_client , websocket_factory));
58
+ m_transport_factory = std::unique_ptr<transport_factory>(new transport_factory (m_http_client_factory , websocket_factory));
76
59
}
77
60
78
61
connection_impl::~connection_impl ()
@@ -138,11 +121,18 @@ namespace signalr
138
121
// there should not be any active transport at this point
139
122
assert (!m_transport);
140
123
141
- m_disconnect_cts = std::make_shared<cancellation_token> ();
124
+ m_disconnect_cts-> reset ();
142
125
m_start_completed_event.reset ();
143
126
m_connection_id = " " ;
144
127
}
145
128
129
+ m_scheduler = m_signalr_client_config.get_scheduler ();
130
+ if (!m_scheduler)
131
+ {
132
+ m_scheduler = std::make_shared<signalr_default_scheduler>();
133
+ m_signalr_client_config.set_scheduler (m_scheduler);
134
+ }
135
+
146
136
start_negotiate (m_base_url, 0 , callback);
147
137
}
148
138
@@ -157,7 +147,7 @@ namespace signalr
157
147
}
158
148
159
149
std::weak_ptr<connection_impl> weak_connection = shared_from_this ();
160
- const auto & token = m_disconnect_cts;
150
+ const auto token = m_disconnect_cts;
161
151
162
152
const auto transport_started = [weak_connection, callback, token](std::shared_ptr<transport> transport, std::exception_ptr exception)
163
153
{
@@ -225,7 +215,8 @@ namespace signalr
225
215
return start_transport (url, transport_started);
226
216
}
227
217
228
- negotiate::negotiate (*m_http_client, url, m_signalr_client_config,
218
+ auto http_client = m_http_client_factory (m_signalr_client_config);
219
+ negotiate::negotiate (http_client, url, m_signalr_client_config,
229
220
[callback, weak_connection, redirect_count, token, url, transport_started](negotiation_response&& response, std::exception_ptr exception)
230
221
{
231
222
auto connection = weak_connection.lock ();
@@ -320,7 +311,7 @@ namespace signalr
320
311
std::shared_ptr<std::mutex> connect_request_lock = std::make_shared<std::mutex>();
321
312
322
313
auto weak_connection = std::weak_ptr<connection_impl>(connection);
323
- const auto & disconnect_cts = m_disconnect_cts;
314
+ const auto disconnect_cts = m_disconnect_cts;
324
315
const auto & logger = m_logger;
325
316
326
317
auto transport = connection->m_transport_factory ->create_transport (
@@ -406,39 +397,51 @@ namespace signalr
406
397
}
407
398
});
408
399
409
- std::thread ([disconnect_cts, connect_request_done, connect_request_lock, callback, weak_connection]()
410
- {
411
- disconnect_cts->wait (5000 );
400
+ disconnect_cts->register_callback ([connect_request_done, connect_request_lock, callback]()
401
+ {
402
+ bool run_callback = false ;
403
+ {
404
+ std::lock_guard<std::mutex> lock (*connect_request_lock);
412
405
406
+ // no op after connection started successfully
407
+ if (*connect_request_done == false )
408
+ {
409
+ *connect_request_done = true ;
410
+ run_callback = true ;
411
+ }
412
+ } // unlock
413
+
414
+ if (run_callback)
415
+ {
416
+ // The callback checks the disconnect_cts token and will handle it appropriately
417
+ callback ({}, nullptr );
418
+ }
419
+ });
420
+
421
+ timer (m_scheduler, [connect_request_done, connect_request_lock, callback](std::chrono::milliseconds duration) {
413
422
bool run_callback = false ;
414
423
{
415
424
std::lock_guard<std::mutex> lock (*connect_request_lock);
425
+
416
426
// no op after connection started successfully
417
427
if (*connect_request_done == false )
418
428
{
429
+ if (duration < std::chrono::seconds (5 ))
430
+ {
431
+ return false ;
432
+ }
419
433
*connect_request_done = true ;
420
434
run_callback = true ;
421
435
}
422
- }
436
+ } // unlock
423
437
424
- // if the disconnect_cts is canceled it means that the connection has been stopped or went out of scope in
425
- // which case we should not throw due to timeout.
426
- if (disconnect_cts->is_canceled ())
438
+ if (run_callback)
427
439
{
428
- if (run_callback)
429
- {
430
- // The callback checks the disconnect_cts token and will handle it appropriately
431
- callback ({}, nullptr );
432
- }
433
- }
434
- else
435
- {
436
- if (run_callback)
437
- {
438
- callback ({}, std::make_exception_ptr (signalr_exception (" transport timed out when trying to connect" )));
439
- }
440
+ callback ({}, std::make_exception_ptr (signalr_exception (" transport timed out when trying to connect" )));
440
441
}
441
- }).detach ();
442
+
443
+ return true ;
444
+ });
442
445
443
446
connection->send_connect_request (transport, url, [callback, connect_request_done, connect_request_lock, transport](std::exception_ptr exception)
444
447
{
@@ -597,6 +600,7 @@ namespace signalr
597
600
const auto current_state = get_connection_state ();
598
601
if (current_state == connection_state::disconnected)
599
602
{
603
+ m_disconnect_cts->cancel ();
600
604
callback (nullptr );
601
605
return ;
602
606
}
0 commit comments