On Tue Apr 2, 2024 at 7:16 PM CEST, Stefan Hanreich wrote: > Additionally we implement FromStr for all rule types and parts, which > can be used for parsing firewall config rules. Initial rule parsing > works by parsing the different options into a HashMap and only then > de-serializing a struct from the parsed options. > > This intermediate step makes rule parsing a lot easier, since we can > reuse the deserialization logic from serde. Also, we can split the > parsing/deserialization logic from the validation logic. > > Co-authored-by: Wolfgang Bumiller <w.bumil...@proxmox.com> > Signed-off-by: Stefan Hanreich <s.hanre...@proxmox.com> > --- > proxmox-ve-config/src/firewall/parse.rs | 185 ++++ > proxmox-ve-config/src/firewall/types/mod.rs | 3 + > proxmox-ve-config/src/firewall/types/rule.rs | 412 ++++++++ > .../src/firewall/types/rule_match.rs | 953 ++++++++++++++++++ > 4 files changed, 1553 insertions(+) > create mode 100644 proxmox-ve-config/src/firewall/types/rule.rs > create mode 100644 proxmox-ve-config/src/firewall/types/rule_match.rs > > diff --git a/proxmox-ve-config/src/firewall/parse.rs > b/proxmox-ve-config/src/firewall/parse.rs > index 669623b..227e045 100644 > --- a/proxmox-ve-config/src/firewall/parse.rs > +++ b/proxmox-ve-config/src/firewall/parse.rs > @@ -1,3 +1,5 @@ > +use std::fmt; > + > use anyhow::{bail, format_err, Error}; > > /// Parses out a "name" which can be alphanumeric and include dashes. > @@ -78,3 +80,186 @@ pub fn parse_bool(value: &str) -> Result<bool, Error> { > }, > ) > } > + > +/// `&str` deserializer which also accepts an `Option`. > +/// > +/// Serde's `StringDeserializer` does not. 
> +#[derive(Clone, Copy, Debug)] > +pub struct SomeStrDeserializer<'a, E>(serde::de::value::StrDeserializer<'a, > E>); > + > +impl<'de, 'a, E> serde::de::Deserializer<'de> for SomeStrDeserializer<'a, E> > +where > + E: serde::de::Error, > +{ > + type Error = E; > + > + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + self.0.deserialize_any(visitor) > + } > + > + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, > Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + visitor.visit_some(self.0) > + } > + > + fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + self.0.deserialize_str(visitor) > + } > + > + fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, > Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + self.0.deserialize_string(visitor) > + } > + > + fn deserialize_enum<V>( > + self, > + _name: &str, > + _variants: &'static [&'static str], > + visitor: V, > + ) -> Result<V::Value, Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + visitor.visit_enum(self.0) > + } > + > + serde::forward_to_deserialize_any! { > + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char > + bytes byte_buf unit unit_struct newtype_struct seq tuple > + tuple_struct map struct identifier ignored_any > + } > +} > + > +/// `&str` wrapper which implements `IntoDeserializer` via > `SomeStrDeserializer`. 
> +#[derive(Clone, Debug)] > +pub struct SomeStr<'a>(pub &'a str); > + > +impl<'a> From<&'a str> for SomeStr<'a> { > + fn from(s: &'a str) -> Self { > + Self(s) > + } > +} > + > +impl<'de, 'a, E> serde::de::IntoDeserializer<'de, E> for SomeStr<'a> > +where > + E: serde::de::Error, > +{ > + type Deserializer = SomeStrDeserializer<'a, E>; > + > + fn into_deserializer(self) -> Self::Deserializer { > + SomeStrDeserializer(self.0.into_deserializer()) > + } > +} > + > +/// `String` deserializer which also accepts an `Option`. > +/// > +/// Serde's `StringDeserializer` does not. > +#[derive(Clone, Debug)] > +pub struct > SomeStringDeserializer<E>(serde::de::value::StringDeserializer<E>); > + > +impl<'de, E> serde::de::Deserializer<'de> for SomeStringDeserializer<E> > +where > + E: serde::de::Error, > +{ > + type Error = E; > + > + fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + self.0.deserialize_any(visitor) > + } > + > + fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, > Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + visitor.visit_some(self.0) > + } > + > + fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + self.0.deserialize_str(visitor) > + } > + > + fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, > Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + self.0.deserialize_string(visitor) > + } > + > + fn deserialize_enum<V>( > + self, > + _name: &str, > + _variants: &'static [&'static str], > + visitor: V, > + ) -> Result<V::Value, Self::Error> > + where > + V: serde::de::Visitor<'de>, > + { > + visitor.visit_enum(self.0) > + } > + > + serde::forward_to_deserialize_any! 
{ > + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char > + bytes byte_buf unit unit_struct newtype_struct seq tuple > + tuple_struct map struct identifier ignored_any > + } > +} > + > +/// `&str` wrapper which implements `IntoDeserializer` via > `SomeStringDeserializer`. > +#[derive(Clone, Debug)] > +pub struct SomeString(pub String); > + > +impl From<&str> for SomeString { > + fn from(s: &str) -> Self { > + Self::from(s.to_string()) > + } > +} > + > +impl From<String> for SomeString { > + fn from(s: String) -> Self { > + Self(s) > + } > +} > + > +impl<'de, E> serde::de::IntoDeserializer<'de, E> for SomeString > +where > + E: serde::de::Error, > +{ > + type Deserializer = SomeStringDeserializer<E>; > + > + fn into_deserializer(self) -> Self::Deserializer { > + SomeStringDeserializer(self.0.into_deserializer()) > + } > +} > + > +#[derive(Debug)] > +pub struct SerdeStringError(String); > + > +impl std::error::Error for SerdeStringError {} > + > +impl fmt::Display for SerdeStringError { > + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { > + f.write_str(&self.0) > + } > +} > + > +impl serde::de::Error for SerdeStringError { > + fn custom<T: fmt::Display>(msg: T) -> Self { > + Self(msg.to_string()) > + } > +} > diff --git a/proxmox-ve-config/src/firewall/types/mod.rs > b/proxmox-ve-config/src/firewall/types/mod.rs > index 5833787..b4a6b12 100644 > --- a/proxmox-ve-config/src/firewall/types/mod.rs > +++ b/proxmox-ve-config/src/firewall/types/mod.rs > @@ -3,7 +3,10 @@ pub mod alias; > pub mod ipset; > pub mod log; > pub mod port; > +pub mod rule; > +pub mod rule_match; > > pub use address::Cidr; > pub use alias::Alias; > pub use ipset::Ipset; > +pub use rule::Rule; > diff --git a/proxmox-ve-config/src/firewall/types/rule.rs > b/proxmox-ve-config/src/firewall/types/rule.rs > new file mode 100644 > index 0000000..20deb3a > --- /dev/null > +++ b/proxmox-ve-config/src/firewall/types/rule.rs > @@ -0,0 +1,412 @@ > +use core::fmt::Display; > +use std::fmt; > 
+use std::str::FromStr; > + > +use anyhow::{bail, ensure, format_err, Error}; > + > +use crate::firewall::parse::match_name; > +use crate::firewall::types::rule_match::RuleMatch; > +use crate::firewall::types::rule_match::RuleOptions; > + > +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] > +pub enum Direction { > + #[default] > + In, > + Out, > +} > + > +impl std::str::FromStr for Direction { > + type Err = Error; > + > + fn from_str(s: &str) -> Result<Self, Error> { > + for (name, dir) in [("IN", Direction::In), ("OUT", Direction::Out)] { > + if s.eq_ignore_ascii_case(name) { > + return Ok(dir); > + } > + } > + > + bail!("invalid direction: {s:?}, expect 'IN' or 'OUT'"); > + } > +} > + > +serde_plain::derive_deserialize_from_fromstr!(Direction, "valid packet > direction"); > + > +impl fmt::Display for Direction { > + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { > + match self { > + Direction::In => f.write_str("in"), > + Direction::Out => f.write_str("out"), > + } > + } > +} > + > +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] > +pub enum Verdict { > + Accept, > + Reject, > + #[default] > + Drop, > +} > + > +impl std::str::FromStr for Verdict { > + type Err = Error; > + > + fn from_str(s: &str) -> Result<Self, Error> { > + for (name, verdict) in [ > + ("ACCEPT", Verdict::Accept), > + ("REJECT", Verdict::Reject), > + ("DROP", Verdict::Drop), > + ] { > + if s.eq_ignore_ascii_case(name) { > + return Ok(verdict); > + } > + } > + bail!("invalid verdict {s:?}, expected one of 'ACCEPT', 'REJECT' or > 'DROP'"); > + } > +} > + > +impl Display for Verdict { > + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { > + let string = match self { > + Verdict::Accept => "ACCEPT", > + Verdict::Drop => "DROP", > + Verdict::Reject => "REJECT", > + }; > + > + write!(f, "{string}") > + } > +} > + > +serde_plain::derive_deserialize_from_fromstr!(Verdict, "valid verdict"); > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub 
struct Rule { > + pub(crate) disabled: bool, > + pub(crate) kind: Kind, > + pub(crate) comment: Option<String>, > +} > + > +impl std::ops::Deref for Rule { > + type Target = Kind; > + > + fn deref(&self) -> &Self::Target { > + &self.kind > + } > +} > + > +impl std::ops::DerefMut for Rule { > + fn deref_mut(&mut self) -> &mut Self::Target { > + &mut self.kind > + } > +} > + > +impl FromStr for Rule { > + type Err = Error; > + > + fn from_str(input: &str) -> Result<Self, Self::Err> { > + if input.contains(['\n', '\r']) { > + bail!("rule must not contain any newlines!"); > + } > + > + let (line, comment) = match input.rsplit_once('#') { > + Some((line, comment)) if !comment.is_empty() => (line.trim(), > Some(comment.trim())), > + _ => (input.trim(), None), > + }; > + > + let (disabled, line) = match line.strip_prefix('|') { > + Some(line) => (true, line.trim_start()), > + None => (false, line), > + }; > + > + // todo: case insensitive? > + let kind = if line.starts_with("GROUP") { > + Kind::from(line.parse::<RuleGroup>()?) > + } else { > + Kind::from(line.parse::<RuleMatch>()?) 
> + }; > + > + Ok(Self { > + disabled, > + comment: comment.map(str::to_string), > + kind, > + }) > + } > +} > + > +impl Rule { > + pub fn iface(&self) -> Option<&str> { > + match &self.kind { > + Kind::Group(group) => group.iface(), > + Kind::Match(rule) => rule.iface(), > + } > + } > + > + pub fn disabled(&self) -> bool { > + self.disabled > + } > + > + pub fn kind(&self) -> &Kind { > + &self.kind > + } > + > + pub fn comment(&self) -> Option<&str> { > + self.comment.as_deref() > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub enum Kind { > + Group(RuleGroup), > + Match(RuleMatch), > +} > + > +impl Kind { > + pub fn is_group(&self) -> bool { > + matches!(self, Kind::Group(_)) > + } > + > + pub fn is_match(&self) -> bool { > + matches!(self, Kind::Match(_)) > + } > +} > + > +impl From<RuleGroup> for Kind { > + fn from(value: RuleGroup) -> Self { > + Kind::Group(value) > + } > +} > + > +impl From<RuleMatch> for Kind { > + fn from(value: RuleMatch) -> Self { > + Kind::Match(value) > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct RuleGroup { > + pub(crate) group: String, > + pub(crate) iface: Option<String>, > +} > + > +impl RuleGroup { > + pub(crate) fn from_options(group: String, options: RuleOptions) -> > Result<Self, Error> { > + ensure!( > + options.proto.is_none() > + && options.dport.is_none() > + && options.sport.is_none() > + && options.dest.is_none() > + && options.source.is_none() > + && options.log.is_none() > + && options.icmp_type.is_none(), > + "only interface parameter is permitted for group rules" > + ); > + > + Ok(Self { > + group, > + iface: options.iface, > + }) > + } > + > + pub fn group(&self) -> &str { > + &self.group > + } > + > + pub fn iface(&self) -> Option<&str> { > + self.iface.as_deref() > + } > +} > + > +impl FromStr for RuleGroup { > + type Err = Error; > + > + fn from_str(input: &str) -> Result<Self, Self::Err> { > + let (keyword, rest) = 
match_name(input) > + .ok_or_else(|| format_err!("expected a leading keyword in rule > group"))?; > + > + if !keyword.eq_ignore_ascii_case("group") { > + bail!("Expected keyword GROUP") > + } > + > + let (name, rest) = > + match_name(rest.trim()).ok_or_else(|| format_err!("expected a > name for rule group"))?; > + > + let options = rest.trim_start().parse()?; > + > + Self::from_options(name.to_string(), options) > + } > +} > + > +#[cfg(test)] > +mod tests { > + use crate::firewall::types::{ > + address::{IpEntry, IpList}, > + alias::{AliasName, AliasScope}, > + ipset::{IpsetName, IpsetScope}, > + log::LogLevel, > + rule_match::{Icmp, IcmpCode, IpAddrMatch, IpMatch, Ports, Protocol, > Udp}, > + Cidr, > + }; > + > + use super::*; > + > + #[test] > + fn test_parse_rule() { > + let mut rule: Rule = "|GROUP tgr -i eth0 # > acomm".parse().expect("valid rule"); > + > + assert_eq!( > + rule, > + Rule { > + disabled: true, > + comment: Some("acomm".to_string()), > + kind: Kind::Group(RuleGroup { > + group: "tgr".to_string(), > + iface: Some("eth0".to_string()), > + }), > + }, > + ); > + > + rule = "IN ACCEPT -p udp -dport 33 -sport 22 -log warning" > + .parse() > + .expect("valid rule"); > + > + assert_eq!( > + rule, > + Rule { > + disabled: false, > + comment: None, > + kind: Kind::Match(RuleMatch { > + dir: Direction::In, > + verdict: Verdict::Accept, > + proto: Some(Udp::new(Ports::from_u16(22, 33)).into()), > + log: Some(LogLevel::Warning), > + ..Default::default() > + }), > + } > + ); > + > + rule = "IN ACCEPT --proto udp -i eth0".parse().expect("valid rule"); > + > + assert_eq!( > + rule, > + Rule { > + disabled: false, > + comment: None, > + kind: Kind::Match(RuleMatch { > + dir: Direction::In, > + verdict: Verdict::Accept, > + proto: Some(Udp::new(Ports::new(None, None)).into()), > + iface: Some("eth0".to_string()), > + ..Default::default() > + }), > + } > + ); > + > + rule = " OUT DROP \ > + -source 10.0.0.0/24 -dest 20.0.0.0-20.255.255.255,192.168.0.0/16 \ > + -p 
icmp -log nolog -icmp-type port-unreachable " > + .parse() > + .expect("valid rule"); > + > + assert_eq!( > + rule, > + Rule { > + disabled: false, > + comment: None, > + kind: Kind::Match(RuleMatch { > + dir: Direction::Out, > + verdict: Verdict::Drop, > + ip: IpMatch::new( > + IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, > 0], 24).unwrap())), > + IpAddrMatch::Ip( > + IpList::new(vec![ > + IpEntry::Range([20, 0, 0, 0].into(), [20, > 255, 255, 255].into()), > + IpEntry::Cidr(Cidr::new_v4([192, 168, 0, 0], > 16).unwrap()), > + ]) > + .unwrap() > + ), > + ) > + .ok(), > + proto: > Some(Protocol::Icmp(Icmp::new_code(IcmpCode::Named( > + "port-unreachable" > + )))), > + log: Some(LogLevel::Nolog), > + ..Default::default() > + }), > + } > + ); > + > + rule = "IN BGP(ACCEPT) --log crit --iface eth0" > + .parse() > + .expect("valid rule"); > + > + assert_eq!( > + rule, > + Rule { > + disabled: false, > + comment: None, > + kind: Kind::Match(RuleMatch { > + dir: Direction::In, > + verdict: Verdict::Accept, > + log: Some(LogLevel::Critical), > + fw_macro: Some("BGP".to_string()), > + iface: Some("eth0".to_string()), > + ..Default::default() > + }), > + } > + ); > + > + rule = "IN ACCEPT --source dc/test --dest +dc/test" > + .parse() > + .expect("valid rule"); > + > + assert_eq!( > + rule, > + Rule { > + disabled: false, > + comment: None, > + kind: Kind::Match(RuleMatch { > + dir: Direction::In, > + verdict: Verdict::Accept, > + ip: Some( > + IpMatch::new( > + > IpAddrMatch::Alias(AliasName::new(AliasScope::Datacenter, "test")), > + > IpAddrMatch::Set(IpsetName::new(IpsetScope::Datacenter, "test"),), > + ) > + .unwrap() > + ), > + ..Default::default() > + }), > + } > + ); > + > + rule = "IN REJECT".parse().expect("valid rule"); > + > + assert_eq!( > + rule, > + Rule { > + disabled: false, > + comment: None, > + kind: Kind::Match(RuleMatch { > + dir: Direction::In, > + verdict: Verdict::Reject, > + ..Default::default() > + }), > + } > + ); > + > + "IN DROP ---log 
crit" > + .parse::<Rule>() > + .expect_err("too many dashes in option"); > + > + "IN DROP --log --iface eth0" > + .parse::<Rule>() > + .expect_err("no value for option"); > + > + "IN DROP --log crit --iface" > + .parse::<Rule>() > + .expect_err("no value for option"); > + } > +} > diff --git a/proxmox-ve-config/src/firewall/types/rule_match.rs > b/proxmox-ve-config/src/firewall/types/rule_match.rs > new file mode 100644 > index 0000000..ae5345c > --- /dev/null > +++ b/proxmox-ve-config/src/firewall/types/rule_match.rs > @@ -0,0 +1,953 @@ > +use std::collections::HashMap; > +use std::fmt; > +use std::str::FromStr; > + > +use serde::Deserialize; > + > +use anyhow::{bail, format_err, Error}; > +use serde::de::IntoDeserializer; > + > +use crate::firewall::parse::{match_name, match_non_whitespace, SomeStr}; > +use crate::firewall::types::address::{Family, IpList}; > +use crate::firewall::types::alias::AliasName; > +use crate::firewall::types::ipset::IpsetName; > +use crate::firewall::types::log::LogLevel; > +use crate::firewall::types::port::PortList; > +use crate::firewall::types::rule::{Direction, Verdict}; > + > +#[derive(Clone, Debug, Default, Deserialize)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +#[serde(deny_unknown_fields, rename_all = "kebab-case")] > +pub(crate) struct RuleOptions { > + #[serde(alias = "p")] > + pub(crate) proto: Option<String>, > + > + pub(crate) dport: Option<String>, > + pub(crate) sport: Option<String>, > + > + pub(crate) dest: Option<String>, > + pub(crate) source: Option<String>, > + > + #[serde(alias = "i")] > + pub(crate) iface: Option<String>, > + > + pub(crate) log: Option<LogLevel>, > + pub(crate) icmp_type: Option<String>, > +} > + > +impl FromStr for RuleOptions { > + type Err = Error; > + > + fn from_str(mut line: &str) -> Result<Self, Self::Err> { > + let mut options = HashMap::new(); > + > + loop { > + line = line.trim_start(); > + > + if line.is_empty() { > + break; > + } > + > + line = line > + .strip_prefix('-') > + 
.ok_or_else(|| format_err!("expected an option starting with > '-'"))?; > + > + // second dash is optional > + line = line.strip_prefix('-').unwrap_or(line); > + > + let param; > + (param, line) = match_name(line) > + .ok_or_else(|| format_err!("expected a parameter name after > '-'"))?; > + > + let value; > + (value, line) = match_non_whitespace(line.trim_start()) > + .ok_or_else(|| format_err!("expected a value for > {param:?}"))?; > + > + if options.insert(param, SomeStr(value)).is_some() { > + bail!("duplicate option in rule: {param}") > + } > + } > + > + Ok(RuleOptions::deserialize(IntoDeserializer::< > + '_, > + crate::firewall::parse::SerdeStringError, > + >::into_deserializer( > + options > + ))?) > + } > +} > + > +#[derive(Clone, Debug, Default)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct RuleMatch { > + pub(crate) dir: Direction, > + pub(crate) verdict: Verdict, > + pub(crate) fw_macro: Option<String>, > + > + pub(crate) iface: Option<String>, > + pub(crate) log: Option<LogLevel>, > + pub(crate) ip: Option<IpMatch>, > + pub(crate) proto: Option<Protocol>, > +} > + > +impl RuleMatch { > + pub(crate) fn from_options( > + dir: Direction, > + verdict: Verdict, > + fw_macro: impl Into<Option<String>>, > + options: RuleOptions, > + ) -> Result<Self, Error> { > + if options.dport.is_some() && options.icmp_type.is_some() { > + bail!("dport and icmp-type are mutually exclusive"); > + } > + > + let ip = IpMatch::from_options(&options)?; > + let proto = Protocol::from_options(&options)?; > + > + // todo: check protocol & IP Version compatibility > + > + Ok(Self { > + dir, > + verdict, > + fw_macro: fw_macro.into(), > + iface: options.iface, > + log: options.log, > + ip, > + proto, > + }) > + } > + > + pub fn direction(&self) -> Direction { > + self.dir > + } > + > + pub fn iface(&self) -> Option<&str> { > + self.iface.as_deref() > + } > + > + pub fn verdict(&self) -> Verdict { > + self.verdict > + } > + > + pub fn fw_macro(&self) -> Option<&str> { > + 
self.fw_macro.as_deref() > + } > + > + pub fn log(&self) -> Option<LogLevel> { > + self.log > + } > + > + pub fn ip(&self) -> Option<&IpMatch> { > + self.ip.as_ref() > + } > + > + pub fn proto(&self) -> Option<&Protocol> { > + self.proto.as_ref() > + } > +} > + > +/// Returns `(Macro name, Verdict, RestOfTheLine)`. > +fn parse_action(line: &str) -> Result<(Option<&str>, Verdict, &str), Error> {
Hmm, since this is only used below, IMO it's fine that this returns a tuple like that on `Ok` - but should functions like that be used in multiple places, it might be beneficial to use a type alias or even a tuple struct for readability's sake. > + let (verdict, line) = > + match_name(line).ok_or_else(|| format_err!("expected a verdict or > macro name"))?; > + > + Ok(if let Some(line) = line.strip_prefix('(') { > + // <macro>(<verdict>) > + > + let macro_name = verdict; > + let (verdict, line) = match_name(line).ok_or_else(|| > format_err!("expected a verdict"))?; > + let line = line > + .strip_prefix(')') > + .ok_or_else(|| format_err!("expected closing ')' after > verdict"))?; > + > + let verdict: Verdict = verdict.parse()?; > + > + (Some(macro_name), verdict, line.trim_start()) > + } else { > + (None, verdict.parse()?, line.trim_start()) > + }) > +} > + > +impl FromStr for RuleMatch { > + type Err = Error; > + > + fn from_str(line: &str) -> Result<Self, Self::Err> { > + let (dir, rest) = match_name(line).ok_or_else(|| > format_err!("expected a direction"))?; > + > + let direction: Direction = dir.parse()?; > + > + let (fw_macro, verdict, rest) = parse_action(rest.trim_start())?; > + > + let options: RuleOptions = rest.trim_start().parse()?; > + > + Self::from_options(direction, verdict, fw_macro.map(str::to_string), > options) > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct IpMatch { > + pub(crate) src: Option<IpAddrMatch>, > + pub(crate) dst: Option<IpAddrMatch>, > +} > + > +impl IpMatch { > + pub fn new( > + src: impl Into<Option<IpAddrMatch>>, > + dst: impl Into<Option<IpAddrMatch>>, > + ) -> Result<Self, Error> { > + let source = src.into(); > + let dest = dst.into(); > + > + if source.is_none() && dest.is_none() { > + bail!("either src or dst must be set") > + } > + > + if let (Some(src), Some(dst)) = (&source, &dest) { > + if src.family() != dst.family() { > + bail!("src and dst family must be equal") > + 
} > + } > + > + let ip_match = Self { > + src: source, > + dst: dest, > + }; > + > + Ok(ip_match) > + } > + > + fn from_options(options: &RuleOptions) -> Result<Option<Self>, Error> { > + let src = options > + .source > + .as_ref() > + .map(|elem| elem.parse::<IpAddrMatch>()) > + .transpose()?; > + > + let dst = options > + .dest > + .as_ref() > + .map(|elem| elem.parse::<IpAddrMatch>()) > + .transpose()?; > + > + Ok(IpMatch::new(src, dst).ok()) > + } > + > + pub fn src(&self) -> Option<&IpAddrMatch> { > + self.src.as_ref() > + } > + > + pub fn dst(&self) -> Option<&IpAddrMatch> { > + self.dst.as_ref() > + } > +} > + > +#[derive(Clone, Debug, Deserialize)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub enum IpAddrMatch { > + Ip(IpList), > + Set(IpsetName), > + Alias(AliasName), > +} > + > +impl IpAddrMatch { > + pub fn family(&self) -> Option<Family> { > + if let IpAddrMatch::Ip(list) = self { > + return Some(list.family()); > + } > + > + None > + } > +} > + > +impl FromStr for IpAddrMatch { > + type Err = Error; > + > + fn from_str(value: &str) -> Result<Self, Error> { > + if value.is_empty() { > + bail!("empty IP specification"); > + } > + > + if let Ok(ip_list) = value.parse() { > + return Ok(IpAddrMatch::Ip(ip_list)); > + } > + > + if let Ok(ipset) = value.parse() { > + return Ok(IpAddrMatch::Set(ipset)); > + } > + > + if let Ok(name) = value.parse() { > + return Ok(IpAddrMatch::Alias(name)); > + } > + > + bail!("invalid IP specification: {value}") > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub enum Protocol { > + Dccp(Ports), > + Sctp(Sctp), > + Tcp(Tcp), > + Udp(Udp), > + UdpLite(Ports), > + Icmp(Icmp), > + Icmpv6(Icmpv6), > + Named(String), > + Numeric(u8), > +} > + > +impl Protocol { > + pub(crate) fn from_options(options: &RuleOptions) -> > Result<Option<Self>, Error> { > + let proto = match options.proto.as_deref() { > + Some(p) => p, > + None => return Ok(None), > + }; > + > + Ok(Some(match proto { > + 
"dccp" | "33" => Protocol::Dccp(Ports::from_options(options)?), > + "sctp" | "132" => Protocol::Sctp(Sctp::from_options(options)?), > + "tcp" | "6" => Protocol::Tcp(Tcp::from_options(options)?), > + "udp" | "17" => Protocol::Udp(Udp::from_options(options)?), > + "udplite" | "136" => > Protocol::UdpLite(Ports::from_options(options)?), > + "icmp" | "1" => Protocol::Icmp(Icmp::from_options(options)?), > + "ipv6-icmp" | "icmpv6" | "58" => > Protocol::Icmpv6(Icmpv6::from_options(options)?), > + other => match other.parse::<u8>() { > + Ok(num) => Protocol::Numeric(num), > + Err(_) => Protocol::Named(other.to_string()), > + }, > + })) > + } > + > + pub fn family(&self) -> Option<Family> { > + match self { > + Self::Icmp(_) => Some(Family::V4), > + Self::Icmpv6(_) => Some(Family::V6), > + _ => None, > + } > + } > +} > + > +#[derive(Clone, Debug, Default)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct Udp { > + ports: Ports, > +} > + > +impl Udp { > + fn from_options(options: &RuleOptions) -> Result<Self, Error> { > + Ok(Self { > + ports: Ports::from_options(options)?, > + }) > + } > + > + pub fn new(ports: Ports) -> Self { > + Self { ports } > + } > + > + pub fn ports(&self) -> &Ports { > + &self.ports > + } > +} > + > +impl From<Udp> for Protocol { > + fn from(value: Udp) -> Self { > + Protocol::Udp(value) > + } > +} > + > +#[derive(Clone, Debug, Default)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct Ports { > + sport: Option<PortList>, > + dport: Option<PortList>, > +} > + > +impl Ports { > + pub fn new(sport: impl Into<Option<PortList>>, dport: impl > Into<Option<PortList>>) -> Self { > + Self { > + sport: sport.into(), > + dport: dport.into(), > + } > + } > + > + fn from_options(options: &RuleOptions) -> Result<Self, Error> { > + Ok(Self { > + sport: options.sport.as_deref().map(|s| s.parse()).transpose()?, > + dport: options.dport.as_deref().map(|s| s.parse()).transpose()?, > + }) > + } > + > + pub fn from_u16(sport: impl Into<Option<u16>>, 
dport: impl > Into<Option<u16>>) -> Self { > + Self::new( > + sport.into().map(PortList::from), > + dport.into().map(PortList::from), > + ) > + } > + > + pub fn sport(&self) -> Option<&PortList> { > + self.sport.as_ref() > + } > + > + pub fn dport(&self) -> Option<&PortList> { > + self.dport.as_ref() > + } > +} > + > +#[derive(Clone, Debug, Default)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct Tcp { > + ports: Ports, > +} > + > +impl Tcp { > + pub fn new(ports: Ports) -> Self { > + Self { ports } > + } > + > + fn from_options(options: &RuleOptions) -> Result<Self, Error> { > + Ok(Self { > + ports: Ports::from_options(options)?, > + }) > + } > + > + pub fn ports(&self) -> &Ports { > + &self.ports > + } > +} > + > +impl From<Tcp> for Protocol { > + fn from(value: Tcp) -> Self { > + Protocol::Tcp(value) > + } > +} > + > +#[derive(Clone, Debug, Default)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct Sctp { > + ports: Ports, > +} > + > +impl Sctp { > + fn from_options(options: &RuleOptions) -> Result<Self, Error> { > + Ok(Self { > + ports: Ports::from_options(options)?, > + }) > + } > + > + pub fn ports(&self) -> &Ports { > + &self.ports > + } > +} > + > +#[derive(Clone, Debug, Default)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct Icmp { > + ty: Option<IcmpType>, > + code: Option<IcmpCode>, > +} > + > +impl Icmp { > + pub fn new_ty(ty: IcmpType) -> Self { > + Self { > + ty: Some(ty), > + ..Default::default() > + } > + } > + > + pub fn new_code(code: IcmpCode) -> Self { > + Self { > + code: Some(code), > + ..Default::default() > + } > + } > + > + fn from_options(options: &RuleOptions) -> Result<Self, Error> { > + if let Some(ty) = &options.icmp_type { > + return ty.parse(); > + } > + > + Ok(Self::default()) > + } > + > + pub fn ty(&self) -> Option<&IcmpType> { > + self.ty.as_ref() > + } > + > + pub fn code(&self) -> Option<&IcmpCode> { > + self.code.as_ref() > + } > +} > + > +impl FromStr for Icmp { > + type Err = Error; > + > + fn 
from_str(s: &str) -> Result<Self, Self::Err> { > + let mut this = Self::default(); > + > + if let Ok(ty) = s.parse() { > + this.ty = Some(ty); > + return Ok(this); > + } > + > + if let Ok(code) = s.parse() { > + this.code = Some(code); > + return Ok(this); > + } > + > + bail!("supplied string is neither a valid icmp type nor code"); > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub enum IcmpType { > + Numeric(u8), > + Named(&'static str), > +} > + > +// MUST BE SORTED! Should maaaybe note *why* it must be sorted, not just that it must be: the lookup in `from_str` below uses `binary_search_by`, which silently fails to find entries in an unsorted slice. Something like `// MUST BE SORTED: looked up via binary search in from_str` would make that explicit. :P > +const ICMP_TYPES: &[(&str, u8)] = &[ > + ("address-mask-reply", 18), > + ("address-mask-request", 17), > + ("destination-unreachable", 3), > + ("echo-reply", 0), > + ("echo-request", 8), > + ("info-reply", 16), > + ("info-request", 15), > + ("parameter-problem", 12), > + ("redirect", 5), > + ("router-advertisement", 9), > + ("router-solicitation", 10), > + ("source-quench", 4), > + ("time-exceeded", 11), > + ("timestamp-reply", 14), > + ("timestamp-request", 13), > +]; > + > +impl std::str::FromStr for IcmpType { > + type Err = Error; > + > + fn from_str(s: &str) -> Result<Self, Error> { > + if let Ok(ty) = s.trim().parse::<u8>() { > + return Ok(Self::Numeric(ty)); > + } > + > + if let Ok(index) = ICMP_TYPES.binary_search_by(|v| v.0.cmp(s)) { > + return Ok(Self::Named(ICMP_TYPES[index].0)); > + } > + > + bail!("{s:?} is not a valid icmp type"); > + } > +} > + > +impl fmt::Display for IcmpType { > + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { > + match self { > + IcmpType::Numeric(ty) => write!(f, "{ty}"), > + IcmpType::Named(ty) => write!(f, "{ty}"), > + } > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub enum IcmpCode { > + Numeric(u8), > + Named(&'static str), > +} > + > +// MUST BE SORTED! Same here.
> +const ICMP_CODES: &[(&str, u8)] = &[ > + ("admin-prohibited", 13), > + ("host-prohibited", 10), > + ("host-unreachable", 1), > + ("net-prohibited", 9), > + ("net-unreachable", 0), > + ("port-unreachable", 3), > + ("prot-unreachable", 2), > +]; > + > +impl std::str::FromStr for IcmpCode { > + type Err = Error; > + > + fn from_str(s: &str) -> Result<Self, Error> { > + if let Ok(code) = s.trim().parse::<u8>() { > + return Ok(Self::Numeric(code)); > + } > + > + if let Ok(index) = ICMP_CODES.binary_search_by(|v| v.0.cmp(s)) { > + return Ok(Self::Named(ICMP_CODES[index].0)); > + } > + > + bail!("{s:?} is not a valid icmp code"); > + } > +} > + > +impl fmt::Display for IcmpCode { > + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { > + match self { > + IcmpCode::Numeric(code) => write!(f, "{code}"), > + IcmpCode::Named(code) => write!(f, "{code}"), > + } > + } > +} > + > +#[derive(Clone, Debug, Default)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub struct Icmpv6 { > + pub ty: Option<Icmpv6Type>, > + pub code: Option<Icmpv6Code>, > +} > + > +impl Icmpv6 { > + pub fn new_ty(ty: Icmpv6Type) -> Self { > + Self { > + ty: Some(ty), > + ..Default::default() > + } > + } > + > + pub fn new_code(code: Icmpv6Code) -> Self { > + Self { > + code: Some(code), > + ..Default::default() > + } > + } > + > + fn from_options(options: &RuleOptions) -> Result<Self, Error> { > + if let Some(ty) = &options.icmp_type { > + return ty.parse(); > + } > + > + Ok(Self::default()) > + } > + > + pub fn ty(&self) -> Option<&Icmpv6Type> { > + self.ty.as_ref() > + } > + > + pub fn code(&self) -> Option<&Icmpv6Code> { > + self.code.as_ref() > + } > +} > + > +impl FromStr for Icmpv6 { > + type Err = Error; > + > + fn from_str(s: &str) -> Result<Self, Self::Err> { > + let mut this = Self::default(); > + > + if let Ok(ty) = s.parse() { > + this.ty = Some(ty); > + return Ok(this); > + } > + > + if let Ok(code) = s.parse() { > + this.code = Some(code); > + return Ok(this); > + } > + > + 
bail!("supplied string is neither a valid icmpv6 type nor code"); > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub enum Icmpv6Type { > + Numeric(u8), > + Named(&'static str), > +} > + > +// MUST BE SORTED! And here too. > +const ICMPV6_TYPES: &[(&str, u8)] = &[ > + ("destination-unreachable", 1), > + ("echo-reply", 129), > + ("echo-request", 128), > + ("ind-neighbor-advert", 142), > + ("ind-neighbor-solicit", 141), > + ("mld-listener-done", 132), > + ("mld-listener-query", 130), > + ("mld-listener-reduction", 132), > + ("mld-listener-report", 131), > + ("mld2-listener-report", 143), > + ("nd-neighbor-advert", 136), > + ("nd-neighbor-solicit", 135), > + ("nd-redirect", 137), > + ("nd-router-advert", 134), > + ("nd-router-solicit", 133), > + ("packet-too-big", 2), > + ("parameter-problem", 4), > + ("router-renumbering", 138), > + ("time-exceeded", 3), > +]; > + > +impl std::str::FromStr for Icmpv6Type { > + type Err = Error; > + > + fn from_str(s: &str) -> Result<Self, Error> { > + if let Ok(ty) = s.trim().parse::<u8>() { > + return Ok(Self::Numeric(ty)); > + } > + > + if let Ok(index) = ICMPV6_TYPES.binary_search_by(|v| v.0.cmp(s)) { > + return Ok(Self::Named(ICMPV6_TYPES[index].0)); > + } > + > + bail!("{s:?} is not a valid icmpv6 type"); > + } > +} > + > +impl fmt::Display for Icmpv6Type { > + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { > + match self { > + Icmpv6Type::Numeric(ty) => write!(f, "{ty}"), > + Icmpv6Type::Named(ty) => write!(f, "{ty}"), > + } > + } > +} > + > +#[derive(Clone, Debug)] > +#[cfg_attr(test, derive(Eq, PartialEq))] > +pub enum Icmpv6Code { > + Numeric(u8), > + Named(&'static str), > +} > + > +// MUST BE SORTED! As well as here. 
> +const ICMPV6_CODES: &[(&str, u8)] = &[ > + ("addr-unreachable", 3), > + ("admin-prohibited", 1), > + ("no-route", 0), > + ("policy-fail", 5), > + ("port-unreachable", 4), > + ("reject-route", 6), > +]; > + > +impl std::str::FromStr for Icmpv6Code { > + type Err = Error; > + > + fn from_str(s: &str) -> Result<Self, Error> { > + if let Ok(code) = s.trim().parse::<u8>() { > + return Ok(Self::Numeric(code)); > + } > + > + if let Ok(index) = ICMPV6_CODES.binary_search_by(|v| v.0.cmp(s)) { > + return Ok(Self::Named(ICMPV6_CODES[index].0)); > + } > + > + bail!("{s:?} is not a valid icmpv6 code"); > + } > +} > + > +impl fmt::Display for Icmpv6Code { > + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { > + match self { > + Icmpv6Code::Numeric(code) => write!(f, "{code}"), > + Icmpv6Code::Named(code) => write!(f, "{code}"), > + } > + } > +} > + > +#[cfg(test)] > +mod tests { > + use crate::firewall::types::Cidr; > + > + use super::*; > + > + #[test] > + fn test_parse_action() { > + assert_eq!(parse_action("REJECT").unwrap(), (None, Verdict::Reject, > "")); > + > + assert_eq!( > + parse_action("SSH(ACCEPT) qweasd").unwrap(), > + (Some("SSH"), Verdict::Accept, "qweasd") > + ); > + } > + > + #[test] > + fn test_parse_ip_addr_match() { > + for input in [ > + "10.0.0.0/8", > + "10.0.0.0/8,192.168.0.0-192.168.255.255,172.16.0.1", > + "dc/test", > + "+guest/proxmox", > + ] { > + input.parse::<IpAddrMatch>().expect("valid ip match"); > + } > + > + for input in [ > + "10.0.0.0/", > + "10.0.0.0/8,192.168.256.0-192.168.255.255,172.16.0.1", > + "dcc/test", > + "+guest/", > + "", > + ] { > + input.parse::<IpAddrMatch>().expect_err("invalid ip match"); > + } > + } > + > + #[test] > + fn test_parse_options() { > + let mut options: RuleOptions = > + "-p udp --sport 123 --dport 234 -source 127.0.0.1 --dest > 127.0.0.1 -i ens1 --log crit" > + .parse() > + .expect("valid option string"); > + > + assert_eq!( > + options, > + RuleOptions { > + proto: Some("udp".to_string()), > + sport: 
Some("123".to_string()), > + dport: Some("234".to_string()), > + source: Some("127.0.0.1".to_string()), > + dest: Some("127.0.0.1".to_string()), > + iface: Some("ens1".to_string()), > + log: Some(LogLevel::Critical), > + icmp_type: None, > + } > + ); > + > + options = "".parse().expect("valid option string"); > + > + assert_eq!(options, RuleOptions::default(),); > + } > + > + #[test] > + fn test_construct_ip_match() { > + IpMatch::new( > + IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 0], > 8).unwrap())), > + IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 0], > 8).unwrap())), > + ) > + .expect("valid ip match"); > + > + IpMatch::new( > + IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 0], > 8).unwrap())), > + IpAddrMatch::Ip(IpList::from(Cidr::new_v6([0x0000; 8], > 8).unwrap())), > + ) > + .expect_err("cannot mix ip families"); > + > + IpMatch::new(None, None).expect_err("at least one ip must be set"); > + } > + > + #[test] > + fn test_from_options() { > + let mut options = RuleOptions { > + proto: Some("tcp".to_string()), > + sport: Some("123".to_string()), > + dport: Some("234".to_string()), > + source: Some("192.168.0.1".to_string()), > + dest: Some("10.0.0.1".to_string()), > + iface: Some("eth123".to_string()), > + log: Some(LogLevel::Error), > + ..Default::default() > + }; > + > + assert_eq!( > + Protocol::from_options(&options).unwrap().unwrap(), > + Protocol::Tcp(Tcp::new(Ports::from_u16(123, 234))), > + ); > + > + assert_eq!( > + IpMatch::from_options(&options).unwrap().unwrap(), > + IpMatch::new( > + IpAddrMatch::Ip(IpList::from(Cidr::new_v4([192, 168, 0, 1], > 32).unwrap()),), > + IpAddrMatch::Ip(IpList::from(Cidr::new_v4([10, 0, 0, 1], > 32).unwrap()),) > + ) > + .unwrap(), > + ); > + > + options = RuleOptions::default(); > + > + assert_eq!(Protocol::from_options(&options).unwrap(), None,); > + > + assert_eq!(IpMatch::from_options(&options).unwrap(), None,); > + > + options = RuleOptions { > + proto: Some("tcp".to_string()), > + sport: 
Some("qwe".to_string()), > + source: Some("qwe".to_string()), > + ..Default::default() > + }; > + > + Protocol::from_options(&options).expect_err("invalid source port"); > + > + IpMatch::from_options(&options).expect_err("invalid source address"); > + > + options = RuleOptions { > + icmp_type: Some("port-unreachable".to_string()), > + dport: Some("123".to_string()), > + ..Default::default() > + }; > + > + RuleMatch::from_options(Direction::In, Verdict::Drop, None, options) > + .expect_err("cannot mix dport and icmp-type"); > + } > + > + #[test] > + fn test_parse_icmp() { > + let mut icmp: Icmp = "info-request".parse().expect("valid icmp > type"); > + > + assert_eq!( > + icmp, > + Icmp { > + ty: Some(IcmpType::Named("info-request")), > + code: None > + } > + ); > + > + icmp = "12".parse().expect("valid icmp type"); > + > + assert_eq!( > + icmp, > + Icmp { > + ty: Some(IcmpType::Numeric(12)), > + code: None > + } > + ); > + > + icmp = "port-unreachable".parse().expect("valid icmp code"); > + > + assert_eq!( > + icmp, > + Icmp { > + ty: None, > + code: Some(IcmpCode::Named("port-unreachable")) > + } > + ); > + } > + > + #[test] > + fn test_parse_icmp6() { > + let mut icmp: Icmpv6 = "echo-reply".parse().expect("valid icmpv6 > type"); > + > + assert_eq!( > + icmp, > + Icmpv6 { > + ty: Some(Icmpv6Type::Named("echo-reply")), > + code: None > + } > + ); > + > + icmp = "12".parse().expect("valid icmpv6 type"); > + > + assert_eq!( > + icmp, > + Icmpv6 { > + ty: Some(Icmpv6Type::Numeric(12)), > + code: None > + } > + ); > + > + icmp = "admin-prohibited".parse().expect("valid icmpv6 code"); > + > + assert_eq!( > + icmp, > + Icmpv6 { > + ty: None, > + code: Some(Icmpv6Code::Named("admin-prohibited")) > + } > + ); > + } > +} _______________________________________________ pve-devel mailing list pve-devel@lists.proxmox.com https://lists.proxmox.com/cgi-bin/mailman/listinfo/pve-devel