On Thu, Aug 17, 2023 at 12:02:35PM +1000, David Gwynne wrote: > there are links between the pcb/socket layer and pf as an optimisation, > and links on mbufs between both sides of a forwarded connection. > these links let pf skip an rb tree lookup for outgoing packets. > > right now these links are between pf_state_key structs, which are the > things that contain the actual addresses used by the connection, but you > then have to iterate over a list in pf_state_keys to get to the pf_state > structures. > > i dont understand why we dont just link the actual pf_state structs. > my best guess is there wasnt enough machinery (ie, refcnts and > mtxes) on a pf_state struct to make it safe, so the compromise was > the pf_state keys. it still got to avoid the tree lookup.
This linkage is much older than MP in pf. There was no refcount or mutex when henning@ added it. > i wanted this to make it easier to look up information on pf states from > the socket layer, but sashan@ said i should send it out. i do think it > makes things a bit easier to understand. > > the most worrying bit is the change to pf_state_find(). > > thoughts? ok? It took me years to get the logic correct for all corner cases. When using pf-divert with UDP, strange things start to happen. Most things are covered by regress/sys/net/pf_divert . You have to set up two machines to run it. I don't know why it was written that way, but I know it works for me. What do you want to fix? bluhm > Index: kern/uipc_mbuf.c > =================================================================== > RCS file: /cvs/src/sys/kern/uipc_mbuf.c,v > retrieving revision 1.287 > diff -u -p -r1.287 uipc_mbuf.c > --- kern/uipc_mbuf.c 23 Jun 2023 04:36:49 -0000 1.287 > +++ kern/uipc_mbuf.c 17 Aug 2023 01:31:04 -0000 > @@ -308,7 +308,7 @@ m_clearhdr(struct mbuf *m) > /* delete all mbuf tags to reset the state */ > m_tag_delete_chain(m); > #if NPF > 0 > - pf_mbuf_unlink_state_key(m); > + pf_mbuf_unlink_state(m); > pf_mbuf_unlink_inpcb(m); > #endif /* NPF > 0 */ > > @@ -440,7 +440,7 @@ m_free(struct mbuf *m) > if (m->m_flags & M_PKTHDR) { > m_tag_delete_chain(m); > #if NPF > 0 > - pf_mbuf_unlink_state_key(m); > + pf_mbuf_unlink_state(m); > pf_mbuf_unlink_inpcb(m); > #endif /* NPF > 0 */ > } > @@ -1398,8 +1398,8 @@ m_dup_pkthdr(struct mbuf *to, struct mbu > to->m_pkthdr = from->m_pkthdr; > > #if NPF > 0 > - to->m_pkthdr.pf.statekey = NULL; > - pf_mbuf_link_state_key(to, from->m_pkthdr.pf.statekey); > + to->m_pkthdr.pf.st = NULL; > + pf_mbuf_link_state(to, from->m_pkthdr.pf.st); > to->m_pkthdr.pf.inp = NULL; > pf_mbuf_link_inpcb(to, from->m_pkthdr.pf.inp); > #endif /* NPF > 0 */ > @@ -1526,8 +1526,8 @@ m_print(void *v, > m->m_pkthdr.csum_flags, MCS_BITS); > (*pr)("m_pkthdr.ether_vtag: 
%u\tm_ptkhdr.ph_rtableid: %u\n", > m->m_pkthdr.ether_vtag, m->m_pkthdr.ph_rtableid); > - (*pr)("m_pkthdr.pf.statekey: %p\tm_pkthdr.pf.inp %p\n", > - m->m_pkthdr.pf.statekey, m->m_pkthdr.pf.inp); > + (*pr)("m_pkthdr.pf.st: %p\tm_pkthdr.pf.inp %p\n", > + m->m_pkthdr.pf.st, m->m_pkthdr.pf.inp); > (*pr)("m_pkthdr.pf.qid: %u\tm_pkthdr.pf.tag: %u\n", > m->m_pkthdr.pf.qid, m->m_pkthdr.pf.tag); > (*pr)("m_pkthdr.pf.flags: %b\n", > Index: net/if_mpw.c > =================================================================== > RCS file: /cvs/src/sys/net/if_mpw.c,v > retrieving revision 1.63 > diff -u -p -r1.63 if_mpw.c > --- net/if_mpw.c 29 Aug 2022 07:51:45 -0000 1.63 > +++ net/if_mpw.c 17 Aug 2023 01:31:04 -0000 > @@ -620,7 +620,7 @@ mpw_input(struct mpw_softc *sc, struct m > m->m_pkthdr.ph_rtableid = ifp->if_rdomain; > > /* packet has not been processed by PF yet. */ > - KASSERT(m->m_pkthdr.pf.statekey == NULL); > + KASSERT(m->m_pkthdr.pf.st == NULL); > > if_vinput(ifp, m); > return; > Index: net/if_tpmr.c > =================================================================== > RCS file: /cvs/src/sys/net/if_tpmr.c,v > retrieving revision 1.33 > diff -u -p -r1.33 if_tpmr.c > --- net/if_tpmr.c 16 May 2023 14:32:54 -0000 1.33 > +++ net/if_tpmr.c 17 Aug 2023 01:31:04 -0000 > @@ -303,7 +303,7 @@ tpmr_pf(struct ifnet *ifp0, int dir, str > return (NULL); > > if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) { > - pf_mbuf_unlink_state_key(m); > + pf_mbuf_unlink_state(m); > pf_mbuf_unlink_inpcb(m); > (*fam->ip_input)(ifp0, m); > return (NULL); > Index: net/if_veb.c > =================================================================== > RCS file: /cvs/src/sys/net/if_veb.c,v > retrieving revision 1.31 > diff -u -p -r1.31 if_veb.c > --- net/if_veb.c 16 May 2023 14:32:54 -0000 1.31 > +++ net/if_veb.c 17 Aug 2023 01:31:04 -0000 > @@ -654,7 +654,7 @@ veb_pf(struct ifnet *ifp0, int dir, stru > return (NULL); > > if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) 
{ > - pf_mbuf_unlink_state_key(m); > + pf_mbuf_unlink_state(m); > pf_mbuf_unlink_inpcb(m); > (*fam->ip_input)(ifp0, m); > return (NULL); > Index: net/pf.c > =================================================================== > RCS file: /cvs/src/sys/net/pf.c,v > retrieving revision 1.1184 > diff -u -p -r1.1184 pf.c > --- net/pf.c 31 Jul 2023 11:13:09 -0000 1.1184 > +++ net/pf.c 17 Aug 2023 01:31:04 -0000 > @@ -247,16 +247,17 @@ int pf_state_insert(struct pfi_kif > *, > struct pf_state_key **, struct pf_state_key **, > struct pf_state *); > > +int pf_state_isvalid(struct pf_state *); > int pf_state_key_isvalid(struct pf_state_key *); > struct pf_state_key *pf_state_key_ref(struct pf_state_key *); > void pf_state_key_unref(struct pf_state_key *); > -void pf_state_key_link_reverse(struct pf_state_key *, > - struct pf_state_key *); > -void pf_state_key_unlink_reverse(struct pf_state_key *); > -void pf_state_key_link_inpcb(struct pf_state_key *, > +void pf_state_link_reverse(struct pf_state *, > + struct pf_state *); > +void pf_state_unlink_reverse(struct pf_state *); > +void pf_state_link_inpcb(struct pf_state *, > struct inpcb *); > -void pf_state_key_unlink_inpcb(struct pf_state_key *); > -void pf_inpcb_unlink_state_key(struct inpcb *); > +void pf_state_unlink_inpcb(struct pf_state *); > +void pf_inpcb_unlink_state(struct inpcb *); > void pf_pktenqueue_delayed(void *); > int32_t pf_state_expires(const struct pf_state *, > uint8_t); > > @@ -852,8 +853,6 @@ pf_state_key_detach(struct pf_state *st, > if (TAILQ_EMPTY(&sk->sk_states)) { > RBT_REMOVE(pf_state_tree, &pf_statetbl, sk); > sk->sk_removed = 1; > - pf_state_key_unlink_reverse(sk); > - pf_state_key_unlink_inpcb(sk); > pf_state_key_unref(sk); > } > > @@ -1115,13 +1114,41 @@ pf_compare_state_keys(struct pf_state_ke > } > } > > +static inline struct pf_state * > +pf_find_state_lookup(struct pf_pdesc *pd, const struct pf_state_key_cmp *key) > +{ > + struct pf_state_key *sk; > + struct pf_state_item *si; > + struct 
pf_state *st; > + uint8_t dir = pd->dir; > + > + sk = RBT_FIND(pf_state_tree, &pf_statetbl, (struct pf_state_key *)key); > + if (sk == NULL) > + return (NULL); > + > + /* list is sorted, if-bound states before floating ones */ > + TAILQ_FOREACH(si, &sk->sk_states, si_entry) { > + st = si->si_st; > + if (st->timeout == PFTM_PURGE) > + continue; > + if (st->kif != pfi_all && st->kif != pd->kif) > + continue; > + > + if (st->key[dir == PF_IN ? PF_SK_WIRE : PF_SK_STACK] == sk) > + return (st); > + } > + > + return (NULL); > +} > + > int > pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key, > struct pf_state **stp) > { > - struct pf_state_key *sk, *pkt_sk, *inp_sk; > - struct pf_state_item *si; > struct pf_state *st = NULL; > + struct pf_state *strev = NULL; > + struct inpcb *inp = NULL; > + int rv = PF_DROP; > > pf_status.fcounters[FCNT_STATE_SEARCH]++; > if (pf_status.debug >= LOG_DEBUG) { > @@ -1131,80 +1158,67 @@ pf_find_state(struct pf_pdesc *pd, struc > addlog("\n"); > } > > - inp_sk = NULL; > - pkt_sk = NULL; > - sk = NULL; > if (pd->dir == PF_OUT) { > + /* take the references */ > + strev = pd->m->m_pkthdr.pf.st; > + inp = pd->m->m_pkthdr.pf.inp; > + > /* first if block deals with outbound forwarded packet */ > - pkt_sk = pd->m->m_pkthdr.pf.statekey; > + if (strev != NULL) { > + pd->m->m_pkthdr.pf.st = NULL; > + KASSERT(inp == NULL); > > - if (!pf_state_key_isvalid(pkt_sk)) { > - pf_mbuf_unlink_state_key(pd->m); > - pkt_sk = NULL; > - } > + if (pf_state_isvalid(strev)) { > + st = strev->reverse; > + if (st != NULL && pf_state_isvalid(st)) > + goto match; > + } > > - if (pkt_sk && pf_state_key_isvalid(pkt_sk->sk_reverse)) > - sk = pkt_sk->sk_reverse; > + /* this handles st not being valid too */ > + pf_state_unlink_reverse(strev); > > - if (pkt_sk == NULL) { > + } else if (inp != NULL) { > /* here we deal with local outbound packet */ > - if (pd->m->m_pkthdr.pf.inp != NULL) { > - inp_sk = pd->m->m_pkthdr.pf.inp->inp_pf_sk; > - if 
(pf_state_key_isvalid(inp_sk)) > - sk = inp_sk; > - else > - pf_inpcb_unlink_state_key( > - pd->m->m_pkthdr.pf.inp); > + pd->m->m_pkthdr.pf.inp = NULL; > + > + st = inp->inp_pf_st; > + if (st != NULL) { > + if (pf_state_isvalid(st)) > + goto match; > + > + pf_inpcb_unlink_state(inp); > } > } > } > > - if (sk == NULL) { > - if ((sk = RBT_FIND(pf_state_tree, &pf_statetbl, > - (struct pf_state_key *)key)) == NULL) > - return (PF_DROP); > - if (pd->dir == PF_OUT && pkt_sk && > - pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0) > - pf_state_key_link_reverse(sk, pkt_sk); > - else if (pd->dir == PF_OUT && pd->m->m_pkthdr.pf.inp && > - !pd->m->m_pkthdr.pf.inp->inp_pf_sk && !sk->sk_inp) > - pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp); > - } > - > - /* remove firewall data from outbound packet */ > - if (pd->dir == PF_OUT) > - pf_pkt_addr_changed(pd->m); > + st = pf_find_state_lookup(pd, key); > + if (st == NULL || ISSET(st->state_flags, PFSTATE_INP_UNLINKED)) > + goto drop; > > - /* list is sorted, if-bound states before floating ones */ > - TAILQ_FOREACH(si, &sk->sk_states, si_entry) { > - struct pf_state *sist = si->si_st; > - if (sist->timeout != PFTM_PURGE && > - (sist->kif == pfi_all || sist->kif == pd->kif) && > - ((sist->key[PF_SK_WIRE]->af == sist->key[PF_SK_STACK]->af && > - sk == (pd->dir == PF_IN ? 
sist->key[PF_SK_WIRE] : > - sist->key[PF_SK_STACK])) || > - (sist->key[PF_SK_WIRE]->af != sist->key[PF_SK_STACK]->af > - && pd->dir == PF_IN && (sk == sist->key[PF_SK_STACK] || > - sk == sist->key[PF_SK_WIRE])))) { > - st = sist; > - break; > - } > + if (pd->dir == PF_OUT) { > + if (strev != NULL) > + pf_state_link_reverse(st, strev); > + else if (inp != NULL) > + pf_state_link_inpcb(st, inp); > } > > - if (st == NULL) > - return (PF_DROP); > - if (ISSET(st->state_flags, PFSTATE_INP_UNLINKED)) > - return (PF_DROP); > - > +match: > if (st->rule.ptr->pktrate.limit && pd->dir == st->direction) { > pf_add_threshold(&st->rule.ptr->pktrate); > if (pf_check_threshold(&st->rule.ptr->pktrate)) > - return (PF_DROP); > + goto drop; > } > > *stp = st; > + rv = PF_MATCH; > + > +drop: > + if (strev != NULL) > + pf_state_unref(strev); > + else if (inp != NULL) > + in_pcbunref(inp); > > - return (PF_MATCH); > + return (rv); > } > > struct pf_state * > @@ -1763,6 +1777,9 @@ pf_remove_state(struct pf_state *st) > > st->timeout = PFTM_UNLINKED; > > + pf_state_unlink_reverse(st); > + pf_state_unlink_inpcb(st); > + > /* handle load balancing related tasks */ > pf_postprocess_addr(st); > > @@ -1792,38 +1809,32 @@ pf_remove_state(struct pf_state *st) > } > > void > -pf_remove_divert_state(struct pf_state_key *sk) > +pf_remove_divert_state(struct pf_state *st) > { > - struct pf_state_item *si; > - > PF_ASSERT_UNLOCKED(); > > PF_LOCK(); > PF_STATE_ENTER_WRITE(); > - TAILQ_FOREACH(si, &sk->sk_states, si_entry) { > - struct pf_state *sist = si->si_st; > - if (sk == sist->key[PF_SK_STACK] && sist->rule.ptr && > - (sist->rule.ptr->divert.type == PF_DIVERT_TO || > - sist->rule.ptr->divert.type == PF_DIVERT_REPLY)) { > - if (sist->key[PF_SK_STACK]->proto == IPPROTO_TCP && > - sist->key[PF_SK_WIRE] != sist->key[PF_SK_STACK]) { > - /* > - * If the local address is translated, keep > - * the state for "tcp.closed" seconds to > - * prevent its source port from being reused. 
> - */ > - if (sist->src.state < TCPS_FIN_WAIT_2 || > - sist->dst.state < TCPS_FIN_WAIT_2) { > - pf_set_protostate(sist, PF_PEER_BOTH, > - TCPS_TIME_WAIT); > - sist->timeout = PFTM_TCP_CLOSED; > - sist->expire = getuptime(); > - } > - sist->state_flags |= PFSTATE_INP_UNLINKED; > - } else > - pf_remove_state(sist); > - break; > - } > + if (st->rule.ptr && > + (st->rule.ptr->divert.type == PF_DIVERT_TO || > + st->rule.ptr->divert.type == PF_DIVERT_REPLY)) { > + if (st->key[PF_SK_STACK]->proto == IPPROTO_TCP && > + st->key[PF_SK_WIRE] != st->key[PF_SK_STACK]) { > + /* > + * If the local address is translated, keep > + * the state for "tcp.closed" seconds to > + * prevent its source port from being reused. > + */ > + if (st->src.state < TCPS_FIN_WAIT_2 || > + st->dst.state < TCPS_FIN_WAIT_2) { > + pf_set_protostate(st, PF_PEER_BOTH, > + TCPS_TIME_WAIT); > + st->timeout = PFTM_TCP_CLOSED; > + st->expire = getuptime(); > + } > + st->state_flags |= PFSTATE_INP_UNLINKED; > + } else > + pf_remove_state(st); > } > PF_STATE_EXIT_WRITE(); > PF_UNLOCK(); > @@ -7836,17 +7847,22 @@ done: > > if (action == PF_PASS && qid) > pd.m->m_pkthdr.pf.qid = qid; > - if (pd.dir == PF_IN && st && st->key[PF_SK_STACK]) > - pf_mbuf_link_state_key(pd.m, st->key[PF_SK_STACK]); > - if (pd.dir == PF_OUT && > - pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk && > - st && st->key[PF_SK_STACK] && !st->key[PF_SK_STACK]->sk_inp) > - pf_state_key_link_inpcb(st->key[PF_SK_STACK], > - pd.m->m_pkthdr.pf.inp); > - > - if (st != NULL && !ISSET(pd.m->m_pkthdr.csum_flags, M_FLOWID)) { > - pd.m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash; > - SET(pd.m->m_pkthdr.csum_flags, M_FLOWID); > + if (st != NULL) { > + struct mbuf *m = pd.m; > + struct inpcb *inp = m->m_pkthdr.pf.inp; > + > + if (pd.dir == PF_IN) { > + KASSERT(inp == NULL); > + pf_mbuf_link_state(m, st); > + } else { > + if (inp != NULL && inp->inp_pf_st == NULL) > + pf_state_link_inpcb(st, inp); > + } > + > + if 
(!ISSET(m->m_pkthdr.csum_flags, M_FLOWID)) { > + m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash; > + SET(m->m_pkthdr.csum_flags, M_FLOWID); > + } > } > > /* > @@ -8004,14 +8020,14 @@ done: > int > pf_ouraddr(struct mbuf *m) > { > - struct pf_state_key *sk; > + struct pf_state *st; > > if (m->m_pkthdr.pf.flags & PF_TAG_DIVERTED) > return (1); > > - sk = m->m_pkthdr.pf.statekey; > - if (sk != NULL) { > - if (sk->sk_inp != NULL) > + st = m->m_pkthdr.pf.st; > + if (st != NULL) { > + if (st->inp != NULL) > return (1); > } > > @@ -8025,7 +8041,7 @@ pf_ouraddr(struct mbuf *m) > void > pf_pkt_addr_changed(struct mbuf *m) > { > - pf_mbuf_unlink_state_key(m); > + pf_mbuf_unlink_state(m); > pf_mbuf_unlink_inpcb(m); > } > > @@ -8033,71 +8049,56 @@ struct inpcb * > pf_inp_lookup(struct mbuf *m) > { > struct inpcb *inp = NULL; > - struct pf_state_key *sk = m->m_pkthdr.pf.statekey; > + struct pf_state *st; > > - if (!pf_state_key_isvalid(sk)) > - pf_mbuf_unlink_state_key(m); > - else > - inp = m->m_pkthdr.pf.statekey->sk_inp; > + st = m->m_pkthdr.pf.st; > + if (st == NULL) > + return (NULL); > + if (!pf_state_isvalid(st)) { > + pf_mbuf_unlink_state(m); > + return (NULL); > + } > + > + inp = st->inp; > + if (inp == NULL) > + return (NULL); > > - if (inp && inp->inp_pf_sk) > - KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk); > + KASSERT(inp->inp_pf_st == NULL || inp->inp_pf_st == st); > > - in_pcbref(inp); > - return (inp); > + return (in_pcbref(inp)); > } > > +/* > + * This is called from the IP stack after it's found an inpcb for > + * an mbuf so it can link the pf_state to that pcb. > + */ > void > pf_inp_link(struct mbuf *m, struct inpcb *inp) > { > - struct pf_state_key *sk = m->m_pkthdr.pf.statekey; > + struct pf_state *st; > > - if (!pf_state_key_isvalid(sk)) { > - pf_mbuf_unlink_state_key(m); > + st = m->m_pkthdr.pf.st; > + if (st == NULL) > return; > - } > > /* > * we don't need to grab PF-lock here. 
At worst case we link inp to > * state, which might be just being marked as deleted by another > * thread. > */ > - if (inp && !sk->sk_inp && !inp->inp_pf_sk) > - pf_state_key_link_inpcb(sk, inp); > + if (pf_state_isvalid(st)) { > + if (st->inp == NULL && inp->inp_pf_st == NULL) > + pf_state_link_inpcb(st, inp); > + } > > /* The statekey has finished finding the inp, it is no longer needed. */ > - pf_mbuf_unlink_state_key(m); > + pf_mbuf_unlink_state(m); > } > > void > pf_inp_unlink(struct inpcb *inp) > { > - pf_inpcb_unlink_state_key(inp); > -} > - > -void > -pf_state_key_link_reverse(struct pf_state_key *sk, struct pf_state_key > *skrev) > -{ > - struct pf_state_key *old_reverse; > - > - old_reverse = atomic_cas_ptr(&sk->sk_reverse, NULL, skrev); > - if (old_reverse != NULL) > - KASSERT(old_reverse == skrev); > - else { > - pf_state_key_ref(skrev); > - > - /* > - * NOTE: if sk == skrev, then KASSERT() below holds true, we > - * still want to grab a reference in such case, because > - * pf_state_key_unlink_reverse() does not check whether keys > - * are identical or not. 
> - */ > - old_reverse = atomic_cas_ptr(&skrev->sk_reverse, NULL, sk); > - if (old_reverse != NULL) > - KASSERT(old_reverse == sk); > - > - pf_state_key_ref(sk); > - } > + pf_inpcb_unlink_state(inp); > } > > #if NPFLOG > 0 > @@ -8132,10 +8133,6 @@ pf_state_key_unref(struct pf_state_key * > if (PF_REF_RELE(sk->sk_refcnt)) { > /* state key must be removed from tree */ > KASSERT(!pf_state_key_isvalid(sk)); > - /* state key must be unlinked from reverse key */ > - KASSERT(sk->sk_reverse == NULL); > - /* state key must be unlinked from socket */ > - KASSERT(sk->sk_inp == NULL); > pool_put(&pf_state_key_pl, sk); > } > } > @@ -8146,21 +8143,28 @@ pf_state_key_isvalid(struct pf_state_key > return ((sk != NULL) && (sk->sk_removed == 0)); > } > > +int > +pf_state_isvalid(struct pf_state *st) > +{ > + return (st->timeout < PFTM_MAX); > +} > + > void > -pf_mbuf_link_state_key(struct mbuf *m, struct pf_state_key *sk) > +pf_mbuf_link_state(struct mbuf *m, struct pf_state *st) > { > - KASSERT(m->m_pkthdr.pf.statekey == NULL); > - m->m_pkthdr.pf.statekey = pf_state_key_ref(sk); > + KASSERT(m->m_pkthdr.pf.st == NULL); > + m->m_pkthdr.pf.st = pf_state_ref(st); > } > > void > -pf_mbuf_unlink_state_key(struct mbuf *m) > +pf_mbuf_unlink_state(struct mbuf *m) > { > - struct pf_state_key *sk = m->m_pkthdr.pf.statekey; > + struct pf_state *st; > > - if (sk != NULL) { > - m->m_pkthdr.pf.statekey = NULL; > - pf_state_key_unref(sk); > + st = m->m_pkthdr.pf.st; > + if (st != NULL) { > + m->m_pkthdr.pf.st = NULL; > + pf_state_unref(st); > } > } > > @@ -8174,64 +8178,107 @@ pf_mbuf_link_inpcb(struct mbuf *m, struc > void > pf_mbuf_unlink_inpcb(struct mbuf *m) > { > - struct inpcb *inp = m->m_pkthdr.pf.inp; > + struct inpcb *inp; > > + inp = m->m_pkthdr.pf.inp; > if (inp != NULL) { > m->m_pkthdr.pf.inp = NULL; > in_pcbunref(inp); > } > } > > +/* assumes caller has an exclusive lock around inp */ > void > -pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp) > 
+pf_state_link_inpcb(struct pf_state *st, struct inpcb *inp) > { > - KASSERT(sk->sk_inp == NULL); > - sk->sk_inp = in_pcbref(inp); > - KASSERT(inp->inp_pf_sk == NULL); > - inp->inp_pf_sk = pf_state_key_ref(sk); > + KASSERT(inp->inp_pf_st == NULL); > + inp->inp_pf_st = pf_state_ref(st); > + > + mtx_enter(&st->mtx); > + KASSERT(st->inp == NULL); > + st->inp = in_pcbref(inp); > + mtx_leave(&st->mtx); > } > > +/* assumes caller has an exclusive lock around inp */ > void > -pf_inpcb_unlink_state_key(struct inpcb *inp) > +pf_inpcb_unlink_state(struct inpcb *inp) > { > - struct pf_state_key *sk = inp->inp_pf_sk; > + struct pf_state *st; > > - if (sk != NULL) { > - KASSERT(sk->sk_inp == inp); > - sk->sk_inp = NULL; > - inp->inp_pf_sk = NULL; > - pf_state_key_unref(sk); > + st = inp->inp_pf_st; > + if (st != NULL) { > + inp->inp_pf_st = NULL; > + > + mtx_enter(&st->mtx); > + KASSERT(st->inp == inp); > + st->inp = NULL; > + mtx_leave(&st->mtx); > in_pcbunref(inp); > + > + pf_state_unref(st); > } > } > > void > -pf_state_key_unlink_inpcb(struct pf_state_key *sk) > +pf_state_unlink_inpcb(struct pf_state *st) > { > - struct inpcb *inp = sk->sk_inp; > + struct inpcb *inp; > + > + mtx_enter(&st->mtx); > + inp = st->inp; > + if (inp != NULL) > + st->inp = NULL; > + mtx_leave(&st->mtx); > > + /* XXX wtf lock? */ > if (inp != NULL) { > - KASSERT(inp->inp_pf_sk == sk); > - sk->sk_inp = NULL; > - inp->inp_pf_sk = NULL; > - pf_state_key_unref(sk); > + KASSERT(inp->inp_pf_st == st); > + inp->inp_pf_st = NULL; > + > + pf_state_unref(st); > in_pcbunref(inp); > } > } > > void > -pf_state_key_unlink_reverse(struct pf_state_key *sk) > +pf_state_link_reverse(struct pf_state *st, struct pf_state *strev) > { > - struct pf_state_key *skrev = sk->sk_reverse; > + mtx_enter(&st->mtx); > + if (st->reverse == NULL) > + st->reverse = pf_state_ref(strev); > + mtx_leave(&st->mtx); > > - /* Note that sk and skrev may be equal, then we unref twice. 
*/ > - if (skrev != NULL) { > - KASSERT(skrev->sk_reverse == sk); > - sk->sk_reverse = NULL; > - skrev->sk_reverse = NULL; > - pf_state_key_unref(skrev); > - pf_state_key_unref(sk); > - } > + mtx_enter(&strev->mtx); > + if (strev->reverse == NULL) > + strev->reverse = pf_state_ref(st); > + mtx_leave(&strev->mtx); > +} > + > +void > +pf_state_unlink_reverse(struct pf_state *st) > +{ > + struct pf_state *strev; > + > + mtx_enter(&st->mtx); > + strev = st->reverse; > + if (strev != NULL) > + st->reverse = NULL; /* take over strev reference */ > + mtx_leave(&st->mtx); > + > + if (strev == NULL) > + return; > + > + mtx_enter(&strev->mtx); > + if (strev->reverse == st) > + strev->reverse = NULL; > + else > + st = NULL; > + mtx_leave(&strev->mtx); > + > + pf_state_unref(strev); /* drop the reference we just inherited */ > + if (st != NULL) > + pf_state_unref(st); /* drop the reference strev had */ > } > > struct pf_state * > @@ -8257,6 +8304,11 @@ pf_state_unref(struct pf_state *st) > > pf_state_key_unref(st->key[PF_SK_WIRE]); > pf_state_key_unref(st->key[PF_SK_STACK]); > + > + /* state must be unlinked from reverse */ > + KASSERT(st->reverse == NULL); > + /* state must be unlinked from socket */ > + KASSERT(st->inp == NULL); > > pool_put(&pf_state_pl, st); > } > Index: net/pfvar.h > =================================================================== > RCS file: /cvs/src/sys/net/pfvar.h,v > retrieving revision 1.533 > diff -u -p -r1.533 pfvar.h > --- net/pfvar.h 6 Jul 2023 04:55:05 -0000 1.533 > +++ net/pfvar.h 17 Aug 2023 01:31:04 -0000 > @@ -1606,7 +1606,7 @@ extern void > pf_calc_skip_steps(struct > extern void pf_purge_expired_src_nodes(void); > extern void pf_purge_expired_rules(void); > extern void pf_remove_state(struct pf_state *); > -extern void pf_remove_divert_state(struct pf_state_key *); > +extern void pf_remove_divert_state(struct pf_state *); > extern void pf_free_state(struct pf_state *); > int pf_insert_src_node(struct pf_src_node **, > struct pf_rule *, 
enum pf_sn_types, > @@ -1860,9 +1860,8 @@ int pf_map_addr(sa_family_t, > struct p > struct pf_pool *, enum pf_sn_types); > int pf_postprocess_addr(struct pf_state *); > > -void pf_mbuf_link_state_key(struct mbuf *, > - struct pf_state_key *); > -void pf_mbuf_unlink_state_key(struct mbuf *); > +void pf_mbuf_link_state(struct mbuf *, struct pf_state *); > +void pf_mbuf_unlink_state(struct mbuf *); > void pf_mbuf_link_inpcb(struct mbuf *, struct inpcb *); > void pf_mbuf_unlink_inpcb(struct mbuf *); > > Index: net/pfvar_priv.h > =================================================================== > RCS file: /cvs/src/sys/net/pfvar_priv.h,v > retrieving revision 1.34 > diff -u -p -r1.34 pfvar_priv.h > --- net/pfvar_priv.h 6 Jul 2023 04:55:05 -0000 1.34 > +++ net/pfvar_priv.h 17 Aug 2023 01:31:04 -0000 > @@ -69,8 +69,6 @@ struct pf_state_key { > > RB_ENTRY(pf_state_key) sk_entry; > struct pf_statelisthead sk_states; > - struct pf_state_key *sk_reverse; > - struct inpcb *sk_inp; > pf_refcnt_t sk_refcnt; > u_int8_t sk_removed; > }; > @@ -115,6 +113,8 @@ struct pf_state { > struct pf_sn_head src_nodes; /* [I] */ > struct pf_state_key *key[2]; /* [I] stack and wire */ > struct pfi_kif *kif; /* [I] */ > + struct pf_state *reverse; /* [M] */ > + struct inpcb *inp; /* [M] */ > struct mutex mtx; > pf_refcnt_t refcnt; > u_int64_t packets[2]; > Index: netinet/in_pcb.c > =================================================================== > RCS file: /cvs/src/sys/netinet/in_pcb.c,v > retrieving revision 1.277 > diff -u -p -r1.277 in_pcb.c > --- netinet/in_pcb.c 24 Jun 2023 20:54:46 -0000 1.277 > +++ netinet/in_pcb.c 17 Aug 2023 01:31:04 -0000 > @@ -538,8 +538,8 @@ void > in_pcbdisconnect(struct inpcb *inp) > { > #if NPF > 0 > - if (inp->inp_pf_sk) { > - pf_remove_divert_state(inp->inp_pf_sk); > + if (inp->inp_pf_st) { > + pf_remove_divert_state(inp->inp_pf_st); > /* pf_remove_divert_state() may have detached the state */ > pf_inp_unlink(inp); > } > @@ -588,8 +588,8 @@ 
in_pcbdetach(struct inpcb *inp) > #endif > ip_freemoptions(inp->inp_moptions); > #if NPF > 0 > - if (inp->inp_pf_sk) { > - pf_remove_divert_state(inp->inp_pf_sk); > + if (inp->inp_pf_st) { > + pf_remove_divert_state(inp->inp_pf_st); > /* pf_remove_divert_state() may have detached the state */ > pf_inp_unlink(inp); > } > Index: netinet/in_pcb.h > =================================================================== > RCS file: /cvs/src/sys/netinet/in_pcb.h,v > retrieving revision 1.136 > diff -u -p -r1.136 in_pcb.h > --- netinet/in_pcb.h 24 Jun 2023 20:54:46 -0000 1.136 > +++ netinet/in_pcb.h 17 Aug 2023 01:31:04 -0000 > @@ -84,7 +84,7 @@ > * p inpcb_mtx pcb mutex > */ > > -struct pf_state_key; > +struct pf_state; > > union inpaddru { > struct in6_addr iau_addr6; > @@ -155,7 +155,7 @@ struct inpcb { > #define inp_csumoffset inp_cksum6 > #endif > struct icmp6_filter *inp_icmp6filt; > - struct pf_state_key *inp_pf_sk; > + struct pf_state *inp_pf_st; > struct mbuf *(*inp_upcall)(void *, struct mbuf *, > struct ip *, struct ip6_hdr *, void *, int); > void *inp_upcall_arg; > Index: sys/mbuf.h > =================================================================== > RCS file: /cvs/src/sys/sys/mbuf.h,v > retrieving revision 1.261 > diff -u -p -r1.261 mbuf.h > --- sys/mbuf.h 16 Jul 2023 03:01:31 -0000 1.261 > +++ sys/mbuf.h 17 Aug 2023 01:31:04 -0000 > @@ -92,11 +92,11 @@ struct m_hdr { > }; > > /* pf stuff */ > -struct pf_state_key; > +struct pf_state; > struct inpcb; > > struct pkthdr_pf { > - struct pf_state_key *statekey; /* pf stackside statekey */ > + struct pf_state *st; /* pf state */ > struct inpcb *inp; /* connected pcb for outgoing packet */ > u_int32_t qid; /* queue id */ > u_int16_t tag; /* tag id */ > @@ -327,7 +327,7 @@ u_int mextfree_register(void (*)(caddr_t > (to)->m_pkthdr = (from)->m_pkthdr; \ > (from)->m_flags &= ~M_PKTHDR; \ > SLIST_INIT(&(from)->m_pkthdr.ph_tags); \ > - (from)->m_pkthdr.pf.statekey = NULL; \ > + (from)->m_pkthdr.pf.st = NULL; \ > } while 
(/* CONSTCOND */ 0) > > /*