diff --git a/net/openvswitch/conntrack.c b/net/openvswitch/conntrack.c index 7c5bb98c22c6..f989ccf38eab 100644 --- a/net/openvswitch/conntrack.c +++ b/net/openvswitch/conntrack.c @@ -73,6 +73,8 @@ struct ovs_conntrack_info { #endif }; +static bool labels_nonzero(const struct ovs_key_ct_labels *labels); + static void __ovs_ct_free_action(struct ovs_conntrack_info *ct_info); static u16 key_to_nfproto(const struct sw_flow_key *key) @@ -270,18 +272,32 @@ static int ovs_ct_init_labels(struct nf_conn *ct, struct sw_flow_key *key, const struct ovs_key_ct_labels *labels, const struct ovs_key_ct_labels *mask) { - struct nf_conn_labels *cl; - u32 *dst; - int i; + struct nf_conn_labels *cl, *master_cl; + bool have_mask = labels_nonzero(mask); + + /* Inherit master's labels to the related connection? */ + master_cl = ct->master ? nf_ct_labels_find(ct->master) : NULL; + + if (!master_cl && !have_mask) + return 0; /* Nothing to do. */ cl = ovs_ct_get_conn_labels(ct); if (!cl) return -ENOSPC; - dst = (u32 *)cl->bits; - for (i = 0; i < OVS_CT_LABELS_LEN_32; i++) - dst[i] = (dst[i] & ~mask->ct_labels_32[i]) | - (labels->ct_labels_32[i] & mask->ct_labels_32[i]); + /* Inherit the master's labels, if any. */ + if (master_cl) + *cl = *master_cl; + + if (have_mask) { + u32 *dst = (u32 *)cl->bits; + int i; + + for (i = 0; i < OVS_CT_LABELS_LEN_32; i++) + dst[i] = (dst[i] & ~mask->ct_labels_32[i]) | + (labels->ct_labels_32[i] + & mask->ct_labels_32[i]); + } memcpy(&key->ct.labels, cl->bits, OVS_CT_LABELS_LEN); @@ -909,13 +925,14 @@ static int ovs_ct_commit(struct net *net, struct sw_flow_key *key, if (err) return err; } - if (labels_nonzero(&info->labels.mask)) { - if (!nf_ct_is_confirmed(ct)) - err = ovs_ct_init_labels(ct, key, &info->labels.value, - &info->labels.mask); - else - err = ovs_ct_set_labels(ct, key, &info->labels.value, - &info->labels.mask); + if (!nf_ct_is_confirmed(ct)) { + err = ovs_ct_init_labels(ct, key, &info->labels.value, + &info->labels.mask); + if (err) + return err; + } else if (labels_nonzero(&info->labels.mask)) { + err = ovs_ct_set_labels(ct, key, &info->labels.value, + &info->labels.mask); if (err) return err; }