On Thu, Aug 17, 2023 at 12:02:35PM +1000, David Gwynne wrote:
> there are links between the pcb/socket layer and pf as an optimisation,
> and links on mbufs between both sides of a forwarded connection.
> these links let pf skip an rb tree lookup for outgoing packets.
> 
> right now these links are between pf_state_key structs, which are the
> things that contain the actual addresses used by the connection, but you
> then have to iterate over a list in pf_state_keys to get to the pf_state
> structures.
> 
> i dont understand why we dont just link the actual pf_state structs.
> my best guess is there wasnt enough machinery (ie, refcnts and
> mtxes) on a pf_state struct to make it safe, so the compromise was
> the pf_state keys. it still got to avoid the tree lookup.

This linkage is much older than MP in pf.  There was no refcount
or mutex when henning@ added it.

> i wanted this to make it easier to look up information on pf states from
> the socket layer, but sashan@ said i should send it out. i do think it
> makes things a bit easier to understand.
> 
> the most worrying bit is the change to pf_state_find().
> 
> thoughts? ok?

It took me years to get the logic correct for all corner cases.
When using pf divert with UDP, strange things start to happen.  Most
cases are covered by regress/sys/net/pf_divert.  You have to set up
two machines to run it.

I don't know why it was written that way, but I know it works for
me.  What do you want to fix?

bluhm

> Index: kern/uipc_mbuf.c
> ===================================================================
> RCS file: /cvs/src/sys/kern/uipc_mbuf.c,v
> retrieving revision 1.287
> diff -u -p -r1.287 uipc_mbuf.c
> --- kern/uipc_mbuf.c  23 Jun 2023 04:36:49 -0000      1.287
> +++ kern/uipc_mbuf.c  17 Aug 2023 01:31:04 -0000
> @@ -308,7 +308,7 @@ m_clearhdr(struct mbuf *m)
>       /* delete all mbuf tags to reset the state */
>       m_tag_delete_chain(m);
>  #if NPF > 0
> -     pf_mbuf_unlink_state_key(m);
> +     pf_mbuf_unlink_state(m);
>       pf_mbuf_unlink_inpcb(m);
>  #endif       /* NPF > 0 */
>  
> @@ -440,7 +440,7 @@ m_free(struct mbuf *m)
>       if (m->m_flags & M_PKTHDR) {
>               m_tag_delete_chain(m);
>  #if NPF > 0
> -             pf_mbuf_unlink_state_key(m);
> +             pf_mbuf_unlink_state(m);
>               pf_mbuf_unlink_inpcb(m);
>  #endif       /* NPF > 0 */
>       }
> @@ -1398,8 +1398,8 @@ m_dup_pkthdr(struct mbuf *to, struct mbu
>       to->m_pkthdr = from->m_pkthdr;
>  
>  #if NPF > 0
> -     to->m_pkthdr.pf.statekey = NULL;
> -     pf_mbuf_link_state_key(to, from->m_pkthdr.pf.statekey);
> +     to->m_pkthdr.pf.st = NULL;
> +     pf_mbuf_link_state(to, from->m_pkthdr.pf.st);
>       to->m_pkthdr.pf.inp = NULL;
>       pf_mbuf_link_inpcb(to, from->m_pkthdr.pf.inp);
>  #endif       /* NPF > 0 */
> @@ -1526,8 +1526,8 @@ m_print(void *v,
>                   m->m_pkthdr.csum_flags, MCS_BITS);
>               (*pr)("m_pkthdr.ether_vtag: %u\tm_ptkhdr.ph_rtableid: %u\n",
>                   m->m_pkthdr.ether_vtag, m->m_pkthdr.ph_rtableid);
> -             (*pr)("m_pkthdr.pf.statekey: %p\tm_pkthdr.pf.inp %p\n",
> -                 m->m_pkthdr.pf.statekey, m->m_pkthdr.pf.inp);
> +             (*pr)("m_pkthdr.pf.st: %p\tm_pkthdr.pf.inp %p\n",
> +                 m->m_pkthdr.pf.st, m->m_pkthdr.pf.inp);
>               (*pr)("m_pkthdr.pf.qid: %u\tm_pkthdr.pf.tag: %u\n",
>                   m->m_pkthdr.pf.qid, m->m_pkthdr.pf.tag);
>               (*pr)("m_pkthdr.pf.flags: %b\n",
> Index: net/if_mpw.c
> ===================================================================
> RCS file: /cvs/src/sys/net/if_mpw.c,v
> retrieving revision 1.63
> diff -u -p -r1.63 if_mpw.c
> --- net/if_mpw.c      29 Aug 2022 07:51:45 -0000      1.63
> +++ net/if_mpw.c      17 Aug 2023 01:31:04 -0000
> @@ -620,7 +620,7 @@ mpw_input(struct mpw_softc *sc, struct m
>       m->m_pkthdr.ph_rtableid = ifp->if_rdomain;
>  
>       /* packet has not been processed by PF yet. */
> -     KASSERT(m->m_pkthdr.pf.statekey == NULL);
> +     KASSERT(m->m_pkthdr.pf.st == NULL);
>  
>       if_vinput(ifp, m);
>       return;
> Index: net/if_tpmr.c
> ===================================================================
> RCS file: /cvs/src/sys/net/if_tpmr.c,v
> retrieving revision 1.33
> diff -u -p -r1.33 if_tpmr.c
> --- net/if_tpmr.c     16 May 2023 14:32:54 -0000      1.33
> +++ net/if_tpmr.c     17 Aug 2023 01:31:04 -0000
> @@ -303,7 +303,7 @@ tpmr_pf(struct ifnet *ifp0, int dir, str
>               return (NULL);
>  
>       if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
> -             pf_mbuf_unlink_state_key(m);
> +             pf_mbuf_unlink_state(m);
>               pf_mbuf_unlink_inpcb(m);
>               (*fam->ip_input)(ifp0, m);
>               return (NULL);
> Index: net/if_veb.c
> ===================================================================
> RCS file: /cvs/src/sys/net/if_veb.c,v
> retrieving revision 1.31
> diff -u -p -r1.31 if_veb.c
> --- net/if_veb.c      16 May 2023 14:32:54 -0000      1.31
> +++ net/if_veb.c      17 Aug 2023 01:31:04 -0000
> @@ -654,7 +654,7 @@ veb_pf(struct ifnet *ifp0, int dir, stru
>               return (NULL);
>  
>       if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
> -             pf_mbuf_unlink_state_key(m);
> +             pf_mbuf_unlink_state(m);
>               pf_mbuf_unlink_inpcb(m);
>               (*fam->ip_input)(ifp0, m);
>               return (NULL);
> Index: net/pf.c
> ===================================================================
> RCS file: /cvs/src/sys/net/pf.c,v
> retrieving revision 1.1184
> diff -u -p -r1.1184 pf.c
> --- net/pf.c  31 Jul 2023 11:13:09 -0000      1.1184
> +++ net/pf.c  17 Aug 2023 01:31:04 -0000
> @@ -247,16 +247,17 @@ int                      pf_state_insert(struct pfi_kif *,
>                           struct pf_state_key **, struct pf_state_key **,
>                           struct pf_state *);
>  
> +int                   pf_state_isvalid(struct pf_state *);
>  int                   pf_state_key_isvalid(struct pf_state_key *);
>  struct pf_state_key  *pf_state_key_ref(struct pf_state_key *);
>  void                  pf_state_key_unref(struct pf_state_key *);
> -void                  pf_state_key_link_reverse(struct pf_state_key *,
> -                         struct pf_state_key *);
> -void                  pf_state_key_unlink_reverse(struct pf_state_key *);
> -void                  pf_state_key_link_inpcb(struct pf_state_key *,
> +void                  pf_state_link_reverse(struct pf_state *,
> +                         struct pf_state *);
> +void                  pf_state_unlink_reverse(struct pf_state *);
> +void                  pf_state_link_inpcb(struct pf_state *,
>                           struct inpcb *);
> -void                  pf_state_key_unlink_inpcb(struct pf_state_key *);
> -void                  pf_inpcb_unlink_state_key(struct inpcb *);
> +void                  pf_state_unlink_inpcb(struct pf_state *);
> +void                  pf_inpcb_unlink_state(struct inpcb *);
>  void                  pf_pktenqueue_delayed(void *);
>  int32_t                       pf_state_expires(const struct pf_state *, uint8_t);
>  
> @@ -852,8 +853,6 @@ pf_state_key_detach(struct pf_state *st,
>       if (TAILQ_EMPTY(&sk->sk_states)) {
>               RBT_REMOVE(pf_state_tree, &pf_statetbl, sk);
>               sk->sk_removed = 1;
> -             pf_state_key_unlink_reverse(sk);
> -             pf_state_key_unlink_inpcb(sk);
>               pf_state_key_unref(sk);
>       }
>  
> @@ -1115,13 +1114,41 @@ pf_compare_state_keys(struct pf_state_ke
>       }
>  }
>  
> +static inline struct pf_state *
> +pf_find_state_lookup(struct pf_pdesc *pd, const struct pf_state_key_cmp *key)
> +{
> +     struct pf_state_key     *sk;
> +     struct pf_state_item    *si;
> +     struct pf_state         *st;
> +     uint8_t                  dir = pd->dir;
> +
> +     sk = RBT_FIND(pf_state_tree, &pf_statetbl, (struct pf_state_key *)key);
> +     if (sk == NULL)
> +             return (NULL);
> +
> +     /* list is sorted, if-bound states before floating ones */
> +     TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
> +             st = si->si_st;
> +             if (st->timeout == PFTM_PURGE)
> +                     continue;
> +             if (st->kif != pfi_all && st->kif != pd->kif)
> +                     continue;
> +
> +             if (st->key[dir == PF_IN ? PF_SK_WIRE : PF_SK_STACK] == sk)
> +                     return (st);
> +     }
> +
> +     return (NULL);
> +}
> +
>  int
>  pf_find_state(struct pf_pdesc *pd, struct pf_state_key_cmp *key,
>      struct pf_state **stp)
>  {
> -     struct pf_state_key     *sk, *pkt_sk, *inp_sk;
> -     struct pf_state_item    *si;
>       struct pf_state         *st = NULL;
> +     struct pf_state         *strev = NULL;
> +     struct inpcb            *inp = NULL;
> +     int                      rv = PF_DROP;
>  
>       pf_status.fcounters[FCNT_STATE_SEARCH]++;
>       if (pf_status.debug >= LOG_DEBUG) {
> @@ -1131,80 +1158,67 @@ pf_find_state(struct pf_pdesc *pd, struc
>               addlog("\n");
>       }
>  
> -     inp_sk = NULL;
> -     pkt_sk = NULL;
> -     sk = NULL;
>       if (pd->dir == PF_OUT) {
> +             /* take the references */
> +             strev = pd->m->m_pkthdr.pf.st;
> +             inp = pd->m->m_pkthdr.pf.inp;
> +
>               /* first if block deals with outbound forwarded packet */
> -             pkt_sk = pd->m->m_pkthdr.pf.statekey;
> +             if (strev != NULL) {
> +                     pd->m->m_pkthdr.pf.st = NULL;
> +                     KASSERT(inp == NULL);
>  
> -             if (!pf_state_key_isvalid(pkt_sk)) {
> -                     pf_mbuf_unlink_state_key(pd->m);
> -                     pkt_sk = NULL;
> -             }
> +                     if (pf_state_isvalid(strev)) {
> +                             st = strev->reverse;
> +                             if (st != NULL && pf_state_isvalid(st))
> +                                     goto match;
> +                     }
>  
> -             if (pkt_sk && pf_state_key_isvalid(pkt_sk->sk_reverse))
> -                     sk = pkt_sk->sk_reverse;
> +                     /* this handles st not being valid too */
> +                     pf_state_unlink_reverse(strev);
>  
> -             if (pkt_sk == NULL) {
> +             } else if (inp != NULL) {
>                       /* here we deal with local outbound packet */
> -                     if (pd->m->m_pkthdr.pf.inp != NULL) {
> -                             inp_sk = pd->m->m_pkthdr.pf.inp->inp_pf_sk;
> -                             if (pf_state_key_isvalid(inp_sk))
> -                                     sk = inp_sk;
> -                             else
> -                                     pf_inpcb_unlink_state_key(
> -                                         pd->m->m_pkthdr.pf.inp);
> +                     pd->m->m_pkthdr.pf.inp = NULL;
> +
> +                     st = inp->inp_pf_st;
> +                     if (st != NULL) {
> +                             if (pf_state_isvalid(st))
> +                                     goto match;
> +
> +                             pf_inpcb_unlink_state(inp);
>                       }
>               }
>       }
>  
> -     if (sk == NULL) {
> -             if ((sk = RBT_FIND(pf_state_tree, &pf_statetbl,
> -                 (struct pf_state_key *)key)) == NULL)
> -                     return (PF_DROP);
> -             if (pd->dir == PF_OUT && pkt_sk &&
> -                 pf_compare_state_keys(pkt_sk, sk, pd->kif, pd->dir) == 0)
> -                     pf_state_key_link_reverse(sk, pkt_sk);
> -             else if (pd->dir == PF_OUT && pd->m->m_pkthdr.pf.inp &&
> -                 !pd->m->m_pkthdr.pf.inp->inp_pf_sk && !sk->sk_inp)
> -                     pf_state_key_link_inpcb(sk, pd->m->m_pkthdr.pf.inp);
> -     }
> -
> -     /* remove firewall data from outbound packet */
> -     if (pd->dir == PF_OUT)
> -             pf_pkt_addr_changed(pd->m);
> +     st = pf_find_state_lookup(pd, key);
> +     if (st == NULL || ISSET(st->state_flags, PFSTATE_INP_UNLINKED))
> +             goto drop;
>  
> -     /* list is sorted, if-bound states before floating ones */
> -     TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
> -             struct pf_state *sist = si->si_st;
> -             if (sist->timeout != PFTM_PURGE &&
> -                 (sist->kif == pfi_all || sist->kif == pd->kif) &&
> -                 ((sist->key[PF_SK_WIRE]->af == sist->key[PF_SK_STACK]->af &&
> -                   sk == (pd->dir == PF_IN ? sist->key[PF_SK_WIRE] :
> -                 sist->key[PF_SK_STACK])) ||
> -                 (sist->key[PF_SK_WIRE]->af != sist->key[PF_SK_STACK]->af
> -                 && pd->dir == PF_IN && (sk == sist->key[PF_SK_STACK] ||
> -                 sk == sist->key[PF_SK_WIRE])))) {
> -                     st = sist;
> -                     break;
> -             }
> +     if (pd->dir == PF_OUT) {
> +             if (strev != NULL)
> +                     pf_state_link_reverse(st, strev);
> +             else if (inp != NULL)
> +                     pf_state_link_inpcb(st, inp);
>       }
>  
> -     if (st == NULL)
> -             return (PF_DROP);
> -     if (ISSET(st->state_flags, PFSTATE_INP_UNLINKED))
> -             return (PF_DROP);
> -
> +match:
>       if (st->rule.ptr->pktrate.limit && pd->dir == st->direction) {
>               pf_add_threshold(&st->rule.ptr->pktrate);
>               if (pf_check_threshold(&st->rule.ptr->pktrate))
> -                     return (PF_DROP);
> +                     goto drop;
>       }
>  
>       *stp = st;
> +     rv = PF_MATCH;
> +
> +drop:
> +     if (strev != NULL)
> +             pf_state_unref(strev);
> +     else if (inp != NULL)
> +             in_pcbunref(inp);
>  
> -     return (PF_MATCH);
> +     return (rv);
>  }
>  
>  struct pf_state *
> @@ -1763,6 +1777,9 @@ pf_remove_state(struct pf_state *st)
>  
>       st->timeout = PFTM_UNLINKED;
>  
> +     pf_state_unlink_reverse(st);
> +     pf_state_unlink_inpcb(st);
> +
>       /* handle load balancing related tasks */
>       pf_postprocess_addr(st);
>  
> @@ -1792,38 +1809,32 @@ pf_remove_state(struct pf_state *st)
>  }
>  
>  void
> -pf_remove_divert_state(struct pf_state_key *sk)
> +pf_remove_divert_state(struct pf_state *st)
>  {
> -     struct pf_state_item    *si;
> -
>       PF_ASSERT_UNLOCKED();
>  
>       PF_LOCK();
>       PF_STATE_ENTER_WRITE();
> -     TAILQ_FOREACH(si, &sk->sk_states, si_entry) {
> -             struct pf_state *sist = si->si_st;
> -             if (sk == sist->key[PF_SK_STACK] && sist->rule.ptr &&
> -                 (sist->rule.ptr->divert.type == PF_DIVERT_TO ||
> -                  sist->rule.ptr->divert.type == PF_DIVERT_REPLY)) {
> -                     if (sist->key[PF_SK_STACK]->proto == IPPROTO_TCP &&
> -                         sist->key[PF_SK_WIRE] != sist->key[PF_SK_STACK]) {
> -                             /*
> -                              * If the local address is translated, keep
> -                              * the state for "tcp.closed" seconds to
> -                              * prevent its source port from being reused.
> -                              */
> -                             if (sist->src.state < TCPS_FIN_WAIT_2 ||
> -                                 sist->dst.state < TCPS_FIN_WAIT_2) {
> -                                     pf_set_protostate(sist, PF_PEER_BOTH,
> -                                         TCPS_TIME_WAIT);
> -                                     sist->timeout = PFTM_TCP_CLOSED;
> -                                     sist->expire = getuptime();
> -                             }
> -                             sist->state_flags |= PFSTATE_INP_UNLINKED;
> -                     } else
> -                             pf_remove_state(sist);
> -                     break;
> -             }
> +     if (st->rule.ptr &&
> +         (st->rule.ptr->divert.type == PF_DIVERT_TO ||
> +          st->rule.ptr->divert.type == PF_DIVERT_REPLY)) {
> +             if (st->key[PF_SK_STACK]->proto == IPPROTO_TCP &&
> +                 st->key[PF_SK_WIRE] != st->key[PF_SK_STACK]) {
> +                     /*
> +                      * If the local address is translated, keep
> +                      * the state for "tcp.closed" seconds to
> +                      * prevent its source port from being reused.
> +                      */
> +                     if (st->src.state < TCPS_FIN_WAIT_2 ||
> +                         st->dst.state < TCPS_FIN_WAIT_2) {
> +                             pf_set_protostate(st, PF_PEER_BOTH,
> +                                 TCPS_TIME_WAIT);
> +                             st->timeout = PFTM_TCP_CLOSED;
> +                             st->expire = getuptime();
> +                     }
> +                     st->state_flags |= PFSTATE_INP_UNLINKED;
> +             } else
> +                     pf_remove_state(st);
>       }
>       PF_STATE_EXIT_WRITE();
>       PF_UNLOCK();
> @@ -7836,17 +7847,22 @@ done:
>  
>       if (action == PF_PASS && qid)
>               pd.m->m_pkthdr.pf.qid = qid;
> -     if (pd.dir == PF_IN && st && st->key[PF_SK_STACK])
> -             pf_mbuf_link_state_key(pd.m, st->key[PF_SK_STACK]);
> -     if (pd.dir == PF_OUT &&
> -         pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk &&
> -         st && st->key[PF_SK_STACK] && !st->key[PF_SK_STACK]->sk_inp)
> -             pf_state_key_link_inpcb(st->key[PF_SK_STACK],
> -                 pd.m->m_pkthdr.pf.inp);
> -
> -     if (st != NULL && !ISSET(pd.m->m_pkthdr.csum_flags, M_FLOWID)) {
> -             pd.m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash;
> -             SET(pd.m->m_pkthdr.csum_flags, M_FLOWID);
> +     if (st != NULL) {
> +             struct mbuf *m = pd.m;
> +             struct inpcb *inp = m->m_pkthdr.pf.inp;
> +
> +             if (pd.dir == PF_IN) {
> +                     KASSERT(inp == NULL);
> +                     pf_mbuf_link_state(m, st);
> +             } else {
> +                     if (inp != NULL && inp->inp_pf_st == NULL)
> +                             pf_state_link_inpcb(st, inp);
> +             }
> +
> +             if (!ISSET(m->m_pkthdr.csum_flags, M_FLOWID)) {
> +                     m->m_pkthdr.ph_flowid = st->key[PF_SK_WIRE]->hash;
> +                     SET(m->m_pkthdr.csum_flags, M_FLOWID);
> +             }
>       }
>  
>       /*
> @@ -8004,14 +8020,14 @@ done:
>  int
>  pf_ouraddr(struct mbuf *m)
>  {
> -     struct pf_state_key     *sk;
> +     struct pf_state         *st;
>  
>       if (m->m_pkthdr.pf.flags & PF_TAG_DIVERTED)
>               return (1);
>  
> -     sk = m->m_pkthdr.pf.statekey;
> -     if (sk != NULL) {
> -             if (sk->sk_inp != NULL)
> +     st = m->m_pkthdr.pf.st;
> +     if (st != NULL) {
> +             if (st->inp != NULL)
>                       return (1);
>       }
>  
> @@ -8025,7 +8041,7 @@ pf_ouraddr(struct mbuf *m)
>  void
>  pf_pkt_addr_changed(struct mbuf *m)
>  {
> -     pf_mbuf_unlink_state_key(m);
> +     pf_mbuf_unlink_state(m);
>       pf_mbuf_unlink_inpcb(m);
>  }
>  
> @@ -8033,71 +8049,56 @@ struct inpcb *
>  pf_inp_lookup(struct mbuf *m)
>  {
>       struct inpcb *inp = NULL;
> -     struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
> +     struct pf_state *st;
>  
> -     if (!pf_state_key_isvalid(sk))
> -             pf_mbuf_unlink_state_key(m);
> -     else
> -             inp = m->m_pkthdr.pf.statekey->sk_inp;
> +     st = m->m_pkthdr.pf.st;
> +     if (st == NULL)
> +             return (NULL);
> +     if (!pf_state_isvalid(st)) {
> +             pf_mbuf_unlink_state(m);
> +             return (NULL);
> +     }
> +
> +     inp = st->inp;
> +     if (inp == NULL)
> +             return (NULL);
>  
> -     if (inp && inp->inp_pf_sk)
> -             KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
> +     KASSERT(inp->inp_pf_st == NULL || inp->inp_pf_st == st);
>  
> -     in_pcbref(inp);
> -     return (inp);
> +     return (in_pcbref(inp));
>  }
>  
> +/*
> + * This is called from the IP stack after it's found an inpcb for
> + * an mbuf so it can link the pf_state to that pcb.
> + */
>  void
>  pf_inp_link(struct mbuf *m, struct inpcb *inp)
>  {
> -     struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
> +     struct pf_state *st;
>  
> -     if (!pf_state_key_isvalid(sk)) {
> -             pf_mbuf_unlink_state_key(m);
> +     st = m->m_pkthdr.pf.st;
> +     if (st == NULL)
>               return;
> -     }
>  
>       /*
>        * we don't need to grab PF-lock here. At worst case we link inp to
>        * state, which might be just being marked as deleted by another
>        * thread.
>        */
> -     if (inp && !sk->sk_inp && !inp->inp_pf_sk)
> -             pf_state_key_link_inpcb(sk, inp);
> +     if (pf_state_isvalid(st)) {
> +             if (st->inp == NULL && inp->inp_pf_st == NULL)
> +                     pf_state_link_inpcb(st, inp);
> +     }
>  
>       /* The statekey has finished finding the inp, it is no longer needed. */
> -     pf_mbuf_unlink_state_key(m);
> +     pf_mbuf_unlink_state(m);
>  }
>  
>  void
>  pf_inp_unlink(struct inpcb *inp)
>  {
> -     pf_inpcb_unlink_state_key(inp);
> -}
> -
> -void
> -pf_state_key_link_reverse(struct pf_state_key *sk, struct pf_state_key *skrev)
> -{
> -     struct pf_state_key *old_reverse;
> -
> -     old_reverse = atomic_cas_ptr(&sk->sk_reverse, NULL, skrev);
> -     if (old_reverse != NULL)
> -             KASSERT(old_reverse == skrev);
> -     else {
> -             pf_state_key_ref(skrev);
> -
> -             /*
> -              * NOTE: if sk == skrev, then KASSERT() below holds true, we
> -              * still want to grab a reference in such case, because
> -              * pf_state_key_unlink_reverse() does not check whether keys
> -              * are identical or not.
> -              */
> -             old_reverse = atomic_cas_ptr(&skrev->sk_reverse, NULL, sk);
> -             if (old_reverse != NULL)
> -                     KASSERT(old_reverse == sk);
> -
> -             pf_state_key_ref(sk);
> -     }
> +     pf_inpcb_unlink_state(inp);
>  }
>  
>  #if NPFLOG > 0
> @@ -8132,10 +8133,6 @@ pf_state_key_unref(struct pf_state_key *
>       if (PF_REF_RELE(sk->sk_refcnt)) {
>               /* state key must be removed from tree */
>               KASSERT(!pf_state_key_isvalid(sk));
> -             /* state key must be unlinked from reverse key */
> -             KASSERT(sk->sk_reverse == NULL);
> -             /* state key must be unlinked from socket */
> -             KASSERT(sk->sk_inp == NULL);
>               pool_put(&pf_state_key_pl, sk);
>       }
>  }
> @@ -8146,21 +8143,28 @@ pf_state_key_isvalid(struct pf_state_key
>       return ((sk != NULL) && (sk->sk_removed == 0));
>  }
>  
> +int
> +pf_state_isvalid(struct pf_state *st)
> +{
> +     return (st->timeout < PFTM_MAX);
> +}
> +
>  void
> -pf_mbuf_link_state_key(struct mbuf *m, struct pf_state_key *sk)
> +pf_mbuf_link_state(struct mbuf *m, struct pf_state *st)
>  {
> -     KASSERT(m->m_pkthdr.pf.statekey == NULL);
> -     m->m_pkthdr.pf.statekey = pf_state_key_ref(sk);
> +     KASSERT(m->m_pkthdr.pf.st == NULL);
> +     m->m_pkthdr.pf.st = pf_state_ref(st);
>  }
>  
>  void
> -pf_mbuf_unlink_state_key(struct mbuf *m)
> +pf_mbuf_unlink_state(struct mbuf *m)
>  {
> -     struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
> +     struct pf_state *st;
>  
> -     if (sk != NULL) {
> -             m->m_pkthdr.pf.statekey = NULL;
> -             pf_state_key_unref(sk);
> +     st = m->m_pkthdr.pf.st;
> +     if (st != NULL) {
> +             m->m_pkthdr.pf.st = NULL;
> +             pf_state_unref(st);
>       }
>  }
>  
> @@ -8174,64 +8178,107 @@ pf_mbuf_link_inpcb(struct mbuf *m, struc
>  void
>  pf_mbuf_unlink_inpcb(struct mbuf *m)
>  {
> -     struct inpcb *inp = m->m_pkthdr.pf.inp;
> +     struct inpcb *inp;
>  
> +     inp = m->m_pkthdr.pf.inp;
>       if (inp != NULL) {
>               m->m_pkthdr.pf.inp = NULL;
>               in_pcbunref(inp);
>       }
>  }
>  
> +/* assumes caller has an exclusive lock around inp */
>  void
> -pf_state_key_link_inpcb(struct pf_state_key *sk, struct inpcb *inp)
> +pf_state_link_inpcb(struct pf_state *st, struct inpcb *inp)
>  {
> -     KASSERT(sk->sk_inp == NULL);
> -     sk->sk_inp = in_pcbref(inp);
> -     KASSERT(inp->inp_pf_sk == NULL);
> -     inp->inp_pf_sk = pf_state_key_ref(sk);
> +     KASSERT(inp->inp_pf_st == NULL);
> +     inp->inp_pf_st = pf_state_ref(st);
> +
> +     mtx_enter(&st->mtx);
> +     KASSERT(st->inp == NULL);
> +     st->inp = in_pcbref(inp);
> +     mtx_leave(&st->mtx);
>  }
>  
> +/* assumes caller has an exclusive lock around inp */
>  void
> -pf_inpcb_unlink_state_key(struct inpcb *inp)
> +pf_inpcb_unlink_state(struct inpcb *inp)
>  {
> -     struct pf_state_key *sk = inp->inp_pf_sk;
> +     struct pf_state *st;
>  
> -     if (sk != NULL) {
> -             KASSERT(sk->sk_inp == inp);
> -             sk->sk_inp = NULL;
> -             inp->inp_pf_sk = NULL;
> -             pf_state_key_unref(sk);
> +     st = inp->inp_pf_st;
> +     if (st != NULL) {
> +             inp->inp_pf_st = NULL;
> +
> +             mtx_enter(&st->mtx);
> +             KASSERT(st->inp == inp);
> +             st->inp = NULL;
> +             mtx_leave(&st->mtx);
>               in_pcbunref(inp);
> +
> +             pf_state_unref(st);
>       }
>  }
>  
>  void
> -pf_state_key_unlink_inpcb(struct pf_state_key *sk)
> +pf_state_unlink_inpcb(struct pf_state *st)
>  {
> -     struct inpcb *inp = sk->sk_inp;
> +     struct inpcb *inp;
> +
> +     mtx_enter(&st->mtx);
> +     inp = st->inp;
> +     if (inp != NULL)
> +             st->inp = NULL;
> +     mtx_leave(&st->mtx);
>  
> +     /* XXX wtf lock? */
>       if (inp != NULL) {
> -             KASSERT(inp->inp_pf_sk == sk);
> -             sk->sk_inp = NULL;
> -             inp->inp_pf_sk = NULL;
> -             pf_state_key_unref(sk);
> +             KASSERT(inp->inp_pf_st == st);
> +             inp->inp_pf_st = NULL;
> +
> +             pf_state_unref(st);
>               in_pcbunref(inp);
>       }
>  }
>  
>  void
> -pf_state_key_unlink_reverse(struct pf_state_key *sk)
> +pf_state_link_reverse(struct pf_state *st, struct pf_state *strev)
>  {
> -     struct pf_state_key *skrev = sk->sk_reverse;
> +     mtx_enter(&st->mtx);
> +     if (st->reverse == NULL)
> +             st->reverse = pf_state_ref(strev);
> +     mtx_leave(&st->mtx);
>  
> -     /* Note that sk and skrev may be equal, then we unref twice. */
> -     if (skrev != NULL) {
> -             KASSERT(skrev->sk_reverse == sk);
> -             sk->sk_reverse = NULL;
> -             skrev->sk_reverse = NULL;
> -             pf_state_key_unref(skrev);
> -             pf_state_key_unref(sk);
> -     }
> +     mtx_enter(&strev->mtx);
> +     if (strev->reverse == NULL)
> +             strev->reverse = pf_state_ref(st);
> +     mtx_leave(&strev->mtx);
> +}
> +
> +void
> +pf_state_unlink_reverse(struct pf_state *st)
> +{
> +     struct pf_state *strev;
> +
> +     mtx_enter(&st->mtx);
> +     strev = st->reverse;
> +     if (strev != NULL)
> +             st->reverse = NULL; /* take over strev reference */
> +     mtx_leave(&st->mtx);
> +
> +     if (strev == NULL)
> +             return;
> +
> +     mtx_enter(&strev->mtx);
> +     if (strev->reverse == st)
> +             strev->reverse = NULL;
> +     else
> +             st = NULL;
> +     mtx_leave(&strev->mtx);
> +
> +     pf_state_unref(strev); /* drop the reference we just inherited */
> +     if (st != NULL)
> +             pf_state_unref(st); /* drop the reference strev had */
>  }
>  
>  struct pf_state *
> @@ -8257,6 +8304,11 @@ pf_state_unref(struct pf_state *st)
>  
>               pf_state_key_unref(st->key[PF_SK_WIRE]);
>               pf_state_key_unref(st->key[PF_SK_STACK]);
> +
> +             /* state must be unlinked from reverse */
> +             KASSERT(st->reverse == NULL);
> +             /* state must be unlinked from socket */
> +             KASSERT(st->inp == NULL);
>  
>               pool_put(&pf_state_pl, st);
>       }
> Index: net/pfvar.h
> ===================================================================
> RCS file: /cvs/src/sys/net/pfvar.h,v
> retrieving revision 1.533
> diff -u -p -r1.533 pfvar.h
> --- net/pfvar.h       6 Jul 2023 04:55:05 -0000       1.533
> +++ net/pfvar.h       17 Aug 2023 01:31:04 -0000
> @@ -1606,7 +1606,7 @@ extern void                      pf_calc_skip_steps(struct
>  extern void                   pf_purge_expired_src_nodes(void);
>  extern void                   pf_purge_expired_rules(void);
>  extern void                   pf_remove_state(struct pf_state *);
> -extern void                   pf_remove_divert_state(struct pf_state_key *);
> +extern void                   pf_remove_divert_state(struct pf_state *);
>  extern void                   pf_free_state(struct pf_state *);
>  int                           pf_insert_src_node(struct pf_src_node **,
>                                   struct pf_rule *, enum pf_sn_types,
> @@ -1860,9 +1860,8 @@ int                      pf_map_addr(sa_family_t, struct p
>                           struct pf_pool *, enum pf_sn_types);
>  int                   pf_postprocess_addr(struct pf_state *);
>  
> -void                  pf_mbuf_link_state_key(struct mbuf *,
> -                         struct pf_state_key *);
> -void                  pf_mbuf_unlink_state_key(struct mbuf *);
> +void                  pf_mbuf_link_state(struct mbuf *, struct pf_state *);
> +void                  pf_mbuf_unlink_state(struct mbuf *);
>  void                  pf_mbuf_link_inpcb(struct mbuf *, struct inpcb *);
>  void                  pf_mbuf_unlink_inpcb(struct mbuf *);
>  
> Index: net/pfvar_priv.h
> ===================================================================
> RCS file: /cvs/src/sys/net/pfvar_priv.h,v
> retrieving revision 1.34
> diff -u -p -r1.34 pfvar_priv.h
> --- net/pfvar_priv.h  6 Jul 2023 04:55:05 -0000       1.34
> +++ net/pfvar_priv.h  17 Aug 2023 01:31:04 -0000
> @@ -69,8 +69,6 @@ struct pf_state_key {
>  
>       RB_ENTRY(pf_state_key)   sk_entry;
>       struct pf_statelisthead  sk_states;
> -     struct pf_state_key     *sk_reverse;
> -     struct inpcb            *sk_inp;
>       pf_refcnt_t              sk_refcnt;
>       u_int8_t                 sk_removed;
>  };
> @@ -115,6 +113,8 @@ struct pf_state {
>       struct pf_sn_head        src_nodes;     /* [I] */
>       struct pf_state_key     *key[2];        /* [I] stack and wire */
>       struct pfi_kif          *kif;           /* [I] */
> +     struct pf_state         *reverse;       /* [M] */
> +     struct inpcb            *inp;           /* [M] */
>       struct mutex             mtx;
>       pf_refcnt_t              refcnt;
>       u_int64_t                packets[2];
> Index: netinet/in_pcb.c
> ===================================================================
> RCS file: /cvs/src/sys/netinet/in_pcb.c,v
> retrieving revision 1.277
> diff -u -p -r1.277 in_pcb.c
> --- netinet/in_pcb.c  24 Jun 2023 20:54:46 -0000      1.277
> +++ netinet/in_pcb.c  17 Aug 2023 01:31:04 -0000
> @@ -538,8 +538,8 @@ void
>  in_pcbdisconnect(struct inpcb *inp)
>  {
>  #if NPF > 0
> -     if (inp->inp_pf_sk) {
> -             pf_remove_divert_state(inp->inp_pf_sk);
> +     if (inp->inp_pf_st) {
> +             pf_remove_divert_state(inp->inp_pf_st);
>               /* pf_remove_divert_state() may have detached the state */
>               pf_inp_unlink(inp);
>       }
> @@ -588,8 +588,8 @@ in_pcbdetach(struct inpcb *inp)
>  #endif
>               ip_freemoptions(inp->inp_moptions);
>  #if NPF > 0
> -     if (inp->inp_pf_sk) {
> -             pf_remove_divert_state(inp->inp_pf_sk);
> +     if (inp->inp_pf_st) {
> +             pf_remove_divert_state(inp->inp_pf_st);
>               /* pf_remove_divert_state() may have detached the state */
>               pf_inp_unlink(inp);
>       }
> Index: netinet/in_pcb.h
> ===================================================================
> RCS file: /cvs/src/sys/netinet/in_pcb.h,v
> retrieving revision 1.136
> diff -u -p -r1.136 in_pcb.h
> --- netinet/in_pcb.h  24 Jun 2023 20:54:46 -0000      1.136
> +++ netinet/in_pcb.h  17 Aug 2023 01:31:04 -0000
> @@ -84,7 +84,7 @@
>   *   p       inpcb_mtx               pcb mutex
>   */
>  
> -struct pf_state_key;
> +struct pf_state;
>  
>  union inpaddru {
>       struct in6_addr iau_addr6;
> @@ -155,7 +155,7 @@ struct inpcb {
>  #define inp_csumoffset       inp_cksum6
>  #endif
>       struct  icmp6_filter *inp_icmp6filt;
> -     struct  pf_state_key *inp_pf_sk;
> +     struct  pf_state *inp_pf_st;
>       struct  mbuf *(*inp_upcall)(void *, struct mbuf *,
>                   struct ip *, struct ip6_hdr *, void *, int);
>       void    *inp_upcall_arg;
> Index: sys/mbuf.h
> ===================================================================
> RCS file: /cvs/src/sys/sys/mbuf.h,v
> retrieving revision 1.261
> diff -u -p -r1.261 mbuf.h
> --- sys/mbuf.h        16 Jul 2023 03:01:31 -0000      1.261
> +++ sys/mbuf.h        17 Aug 2023 01:31:04 -0000
> @@ -92,11 +92,11 @@ struct m_hdr {
>  };
>  
>  /* pf stuff */
> -struct pf_state_key;
> +struct pf_state;
>  struct inpcb;
>  
>  struct pkthdr_pf {
> -     struct pf_state_key *statekey;  /* pf stackside statekey */
> +     struct pf_state *st;            /* pf state */
>       struct inpcb    *inp;           /* connected pcb for outgoing packet */
>       u_int32_t        qid;           /* queue id */
>       u_int16_t        tag;           /* tag id */
> @@ -327,7 +327,7 @@ u_int mextfree_register(void (*)(caddr_t
>       (to)->m_pkthdr = (from)->m_pkthdr;                              \
>       (from)->m_flags &= ~M_PKTHDR;                                   \
>       SLIST_INIT(&(from)->m_pkthdr.ph_tags);                          \
> -     (from)->m_pkthdr.pf.statekey = NULL;                            \
> +     (from)->m_pkthdr.pf.st = NULL;                                  \
>  } while (/* CONSTCOND */ 0)
>  
>  /*

Reply via email to