summaryrefslogtreecommitdiff
blob: 5d5ef64333d34c8c6c7f4f163dbd9d7d0202bfbf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
From 156280bcb881915701b25ad57e1efe2dcef73c6b Mon Sep 17 00:00:00 2001
From: "Jason A. Donenfeld" <Jason@zx2c4.com>
Date: Tue, 6 Jun 2017 21:49:29 +0200
Subject: noise: fix race when replacing handshake

Replacing an entry that's already been replaced is something that could
happen when processing handshake messages in parallel, when starting up
multiple instances on the same machine.

Reported-by: Hubert Goisern <zweizweizwoelf@gmail.com>
---
 src/hashtables.c |  5 ++++-
 src/hashtables.h |  2 +-
 src/noise.c      | 28 +++++++++++++++++++---------
 3 files changed, 24 insertions(+), 11 deletions(-)

diff --git a/src/hashtables.c b/src/hashtables.c
index db97f7e..a01a899 100644
--- a/src/hashtables.c
+++ b/src/hashtables.c
@@ -97,13 +97,16 @@ search_unused_slot:
 	return entry->index;
 }
 
-void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new)
+bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new)
 {
+	if (unlikely(hlist_unhashed(&old->index_hash)))
+		return false;
 	spin_lock_bh(&table->lock);
 	new->index = old->index;
 	hlist_replace_rcu(&old->index_hash, &new->index_hash);
 	INIT_HLIST_NODE(&old->index_hash);
 	spin_unlock_bh(&table->lock);
+	return true;
 }
 
 void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry)
diff --git a/src/hashtables.h b/src/hashtables.h
index 9fa47d5..08a2a5d 100644
--- a/src/hashtables.h
+++ b/src/hashtables.h
@@ -40,7 +40,7 @@ struct index_hashtable_entry {
 };
 void index_hashtable_init(struct index_hashtable *table);
 __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry);
-void index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new);
+bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new);
 void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry);
 struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index);
 
diff --git a/src/noise.c b/src/noise.c
index 7ca2a67..9583ab1 100644
--- a/src/noise.c
+++ b/src/noise.c
@@ -59,16 +59,21 @@ bool noise_handshake_init(struct noise_handshake *handshake, struct noise_static
 	return noise_precompute_static_static(peer);
 }
 
-void noise_handshake_clear(struct noise_handshake *handshake)
+static void handshake_zero(struct noise_handshake *handshake)
 {
-	index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
-	down_write(&handshake->lock);
 	memset(&handshake->ephemeral_private, 0, NOISE_PUBLIC_KEY_LEN);
 	memset(&handshake->remote_ephemeral, 0, NOISE_PUBLIC_KEY_LEN);
 	memset(&handshake->hash, 0, NOISE_HASH_LEN);
 	memset(&handshake->chaining_key, 0, NOISE_HASH_LEN);
 	handshake->remote_index = 0;
 	handshake->state = HANDSHAKE_ZEROED;
+}
+
+void noise_handshake_clear(struct noise_handshake *handshake)
+{
+	index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
+	down_write(&handshake->lock);
+	handshake_zero(handshake);
 	up_write(&handshake->lock);
 	index_hashtable_remove(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
 }
@@ -371,8 +376,8 @@ bool noise_handshake_create_initiation(struct message_handshake_initiation *dst,
 
 	dst->sender_index = index_hashtable_insert(&handshake->entry.peer->device->index_hashtable, &handshake->entry);
 
-	ret = true;
 	handshake->state = HANDSHAKE_CREATED_INITIATION;
+	ret = true;
 
 out:
 	up_write(&handshake->lock);
@@ -548,6 +553,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
 
 	/* Success! Copy everything to peer */
 	down_write(&handshake->lock);
+	/* It's important to check that the state is still the same, while we have an exclusive lock */
+	if (handshake->state != state) {
+		up_write(&handshake->lock);
+		goto fail;
+	}
 	memcpy(handshake->remote_ephemeral, e, NOISE_PUBLIC_KEY_LEN);
 	memcpy(handshake->hash, hash, NOISE_HASH_LEN);
 	memcpy(handshake->chaining_key, chaining_key, NOISE_HASH_LEN);
@@ -573,7 +583,7 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi
 {
 	struct noise_keypair *new_keypair;
 
-	down_read(&handshake->lock);
+	down_write(&handshake->lock);
 	if (handshake->state != HANDSHAKE_CREATED_RESPONSE && handshake->state != HANDSHAKE_CONSUMED_RESPONSE)
 		goto fail;
 
@@ -587,16 +597,16 @@ bool noise_handshake_begin_session(struct noise_handshake *handshake, struct noi
 		derive_keys(&new_keypair->sending, &new_keypair->receiving, handshake->chaining_key);
 	else
 		derive_keys(&new_keypair->receiving, &new_keypair->sending, handshake->chaining_key);
-	up_read(&handshake->lock);
 
+	handshake_zero(handshake);
 	add_new_keypair(keypairs, new_keypair);
-	index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry);
-	noise_handshake_clear(handshake);
 	net_dbg_ratelimited("%s: Keypair %Lu created for peer %Lu\n", netdev_pub(new_keypair->entry.peer->device)->name, new_keypair->internal_id, new_keypair->entry.peer->internal_id);
+	WARN_ON(!index_hashtable_replace(&handshake->entry.peer->device->index_hashtable, &handshake->entry, &new_keypair->entry));
+	up_write(&handshake->lock);
 
 	return true;
 
 fail:
-	up_read(&handshake->lock);
+	up_write(&handshake->lock);
 	return false;
 }
-- 
cgit v1.1-9-ge9c1d