mpls: Protect net->mpls.platform_label with a per-netns mutex.
MPLS (re)uses RTNL to protect net->mpls.platform_label, but the lock does not need to be RTNL at all. Let's protect net->mpls.platform_label with a dedicated per-netns mutex. Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com> Reviewed-by: Guillaume Nault <gnault@redhat.com> Link: https://patch.msgid.link/20251029173344.2934622-13-kuniyu@google.com Signed-off-by: Jakub Kicinski <kuba@kernel.org>pull/1354/merge
parent
fb2b77b9b1
commit
e833eb2516
|
|
@ -16,6 +16,7 @@ struct netns_mpls {
|
|||
int default_ttl;
|
||||
size_t platform_labels;
|
||||
struct mpls_route __rcu * __rcu *platform_label;
|
||||
struct mutex platform_mutex;
|
||||
|
||||
struct ctl_table_header *ctl;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -79,8 +79,8 @@ static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
|
|||
{
|
||||
struct mpls_route __rcu **platform_label;
|
||||
|
||||
platform_label = rtnl_dereference(net->mpls.platform_label);
|
||||
return rtnl_dereference(platform_label[index]);
|
||||
platform_label = mpls_dereference(net, net->mpls.platform_label);
|
||||
return mpls_dereference(net, platform_label[index]);
|
||||
}
|
||||
|
||||
static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
|
||||
|
|
@ -578,10 +578,8 @@ static void mpls_route_update(struct net *net, unsigned index,
|
|||
struct mpls_route __rcu **platform_label;
|
||||
struct mpls_route *rt;
|
||||
|
||||
ASSERT_RTNL();
|
||||
|
||||
platform_label = rtnl_dereference(net->mpls.platform_label);
|
||||
rt = rtnl_dereference(platform_label[index]);
|
||||
platform_label = mpls_dereference(net, net->mpls.platform_label);
|
||||
rt = mpls_dereference(net, platform_label[index]);
|
||||
rcu_assign_pointer(platform_label[index], new);
|
||||
|
||||
mpls_notify_route(net, index, rt, new, info);
|
||||
|
|
@ -1472,8 +1470,6 @@ static struct mpls_dev *mpls_add_dev(struct net_device *dev)
|
|||
int err = -ENOMEM;
|
||||
int i;
|
||||
|
||||
ASSERT_RTNL();
|
||||
|
||||
mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
|
||||
if (!mdev)
|
||||
return ERR_PTR(err);
|
||||
|
|
@ -1633,6 +1629,8 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
|
|||
unsigned int flags;
|
||||
int err;
|
||||
|
||||
mutex_lock(&net->mpls.platform_mutex);
|
||||
|
||||
if (event == NETDEV_REGISTER) {
|
||||
mdev = mpls_add_dev(dev);
|
||||
if (IS_ERR(mdev)) {
|
||||
|
|
@ -1695,9 +1693,11 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
|
|||
}
|
||||
|
||||
out:
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
return NOTIFY_OK;
|
||||
|
||||
err:
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
return notifier_from_errno(err);
|
||||
}
|
||||
|
||||
|
|
@ -1973,6 +1973,7 @@ errout:
|
|||
static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
|
||||
struct netlink_ext_ack *extack)
|
||||
{
|
||||
struct net *net = sock_net(skb->sk);
|
||||
struct mpls_route_config *cfg;
|
||||
int err;
|
||||
|
||||
|
|
@ -1984,7 +1985,9 @@ static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
|
|||
if (err < 0)
|
||||
goto out;
|
||||
|
||||
mutex_lock(&net->mpls.platform_mutex);
|
||||
err = mpls_route_del(cfg, extack);
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
out:
|
||||
kfree(cfg);
|
||||
|
||||
|
|
@ -1995,6 +1998,7 @@ out:
|
|||
static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
|
||||
struct netlink_ext_ack *extack)
|
||||
{
|
||||
struct net *net = sock_net(skb->sk);
|
||||
struct mpls_route_config *cfg;
|
||||
int err;
|
||||
|
||||
|
|
@ -2006,7 +2010,9 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
|
|||
if (err < 0)
|
||||
goto out;
|
||||
|
||||
mutex_lock(&net->mpls.platform_mutex);
|
||||
err = mpls_route_add(cfg, extack);
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
out:
|
||||
kfree(cfg);
|
||||
|
||||
|
|
@ -2407,6 +2413,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
|
|||
u8 n_labels;
|
||||
int err;
|
||||
|
||||
mutex_lock(&net->mpls.platform_mutex);
|
||||
|
||||
err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
|
||||
if (err < 0)
|
||||
goto errout;
|
||||
|
|
@ -2450,7 +2458,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
|
|||
goto errout_free;
|
||||
}
|
||||
|
||||
return rtnl_unicast(skb, net, portid);
|
||||
err = rtnl_unicast(skb, net, portid);
|
||||
goto errout;
|
||||
}
|
||||
|
||||
if (tb[RTA_NEWDST]) {
|
||||
|
|
@ -2542,12 +2551,14 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
|
|||
|
||||
err = rtnl_unicast(skb, net, portid);
|
||||
errout:
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
return err;
|
||||
|
||||
nla_put_failure:
|
||||
nlmsg_cancel(skb, nlh);
|
||||
err = -EMSGSIZE;
|
||||
errout_free:
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
kfree_skb(skb);
|
||||
return err;
|
||||
}
|
||||
|
|
@ -2603,9 +2614,10 @@ static int resize_platform_label_table(struct net *net, size_t limit)
|
|||
lo->addr_len);
|
||||
}
|
||||
|
||||
rtnl_lock();
|
||||
mutex_lock(&net->mpls.platform_mutex);
|
||||
|
||||
/* Remember the original table */
|
||||
old = rtnl_dereference(net->mpls.platform_label);
|
||||
old = mpls_dereference(net, net->mpls.platform_label);
|
||||
old_limit = net->mpls.platform_labels;
|
||||
|
||||
/* Free any labels beyond the new table */
|
||||
|
|
@ -2636,7 +2648,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
|
|||
net->mpls.platform_labels = limit;
|
||||
rcu_assign_pointer(net->mpls.platform_label, labels);
|
||||
|
||||
rtnl_unlock();
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
|
||||
mpls_rt_free(rt2);
|
||||
mpls_rt_free(rt0);
|
||||
|
|
@ -2709,12 +2721,13 @@ static const struct ctl_table mpls_table[] = {
|
|||
},
|
||||
};
|
||||
|
||||
static int mpls_net_init(struct net *net)
|
||||
static __net_init int mpls_net_init(struct net *net)
|
||||
{
|
||||
size_t table_size = ARRAY_SIZE(mpls_table);
|
||||
struct ctl_table *table;
|
||||
int i;
|
||||
|
||||
mutex_init(&net->mpls.platform_mutex);
|
||||
net->mpls.platform_labels = 0;
|
||||
net->mpls.platform_label = NULL;
|
||||
net->mpls.ip_ttl_propagate = 1;
|
||||
|
|
@ -2740,7 +2753,7 @@ static int mpls_net_init(struct net *net)
|
|||
return 0;
|
||||
}
|
||||
|
||||
static void mpls_net_exit(struct net *net)
|
||||
static __net_exit void mpls_net_exit(struct net *net)
|
||||
{
|
||||
struct mpls_route __rcu **platform_label;
|
||||
size_t platform_labels;
|
||||
|
|
@ -2760,16 +2773,20 @@ static void mpls_net_exit(struct net *net)
|
|||
* As such no additional rcu synchronization is necessary when
|
||||
* freeing the platform_label table.
|
||||
*/
|
||||
rtnl_lock();
|
||||
platform_label = rtnl_dereference(net->mpls.platform_label);
|
||||
mutex_lock(&net->mpls.platform_mutex);
|
||||
|
||||
platform_label = mpls_dereference(net, net->mpls.platform_label);
|
||||
platform_labels = net->mpls.platform_labels;
|
||||
|
||||
for (index = 0; index < platform_labels; index++) {
|
||||
struct mpls_route *rt = rtnl_dereference(platform_label[index]);
|
||||
RCU_INIT_POINTER(platform_label[index], NULL);
|
||||
struct mpls_route *rt;
|
||||
|
||||
rt = mpls_dereference(net, platform_label[index]);
|
||||
mpls_notify_route(net, index, rt, NULL, NULL);
|
||||
mpls_rt_free(rt);
|
||||
}
|
||||
rtnl_unlock();
|
||||
|
||||
mutex_unlock(&net->mpls.platform_mutex);
|
||||
|
||||
kvfree(platform_label);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -185,6 +185,11 @@ static inline struct mpls_entry_decoded mpls_entry_decode(struct mpls_shim_hdr *
|
|||
return result;
|
||||
}
|
||||
|
||||
#define mpls_dereference(net, p) \
|
||||
rcu_dereference_protected( \
|
||||
(p), \
|
||||
lockdep_is_held(&(net)->mpls.platform_mutex))
|
||||
|
||||
static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev)
|
||||
{
|
||||
return rcu_dereference(dev->mpls_ptr);
|
||||
|
|
@ -193,7 +198,7 @@ static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev)
|
|||
static inline struct mpls_dev *mpls_dev_get(const struct net *net,
|
||||
const struct net_device *dev)
|
||||
{
|
||||
return rcu_dereference_rtnl(dev->mpls_ptr);
|
||||
return mpls_dereference(net, dev->mpls_ptr);
|
||||
}
|
||||
|
||||
int nla_put_labels(struct sk_buff *skb, int attrtype, u8 labels,
|
||||
|
|
|
|||
Loading…
Reference in New Issue