diff --git a/cns/fakes/iptablesfake.go b/cns/fakes/iptablesfake.go index f80fd075c4..1845a6e61d 100644 --- a/cns/fakes/iptablesfake.go +++ b/cns/fakes/iptablesfake.go @@ -2,6 +2,7 @@ package fakes import ( "errors" + "fmt" "strings" "github.com/Azure/azure-container-networking/iptables" @@ -11,10 +12,13 @@ var ( errChainExists = errors.New("chain already exists") errChainNotFound = errors.New("chain not found") errRuleExists = errors.New("rule already exists") + errRuleNotFound = errors.New("rule not found") + errIndexBounds = errors.New("index out of bounds") ) type IPTablesMock struct { - state map[string]map[string][]string + state map[string]map[string][]string + clearChainCallCount int } func NewIPTablesMock() *IPTablesMock { @@ -83,21 +87,110 @@ func (c *IPTablesMock) Exists(table, chain string, rulespec ...string) (bool, er func (c *IPTablesMock) Append(table, chain string, rulespec ...string) error { c.ensureTableExists(table) + chainRules := c.state[table][chain] + return c.Insert(table, chain, len(chainRules)+1, rulespec...) +} + +func (c *IPTablesMock) Insert(table, chain string, pos int, rulespec ...string) error { + c.ensureTableExists(table) + chainExists, _ := c.ChainExists(table, chain) if !chainExists { return errChainNotFound } - ruleExists, _ := c.Exists(table, chain, rulespec...) - if ruleExists { - return errRuleExists + targetRule := strings.Join(rulespec, " ") + chainRules := c.state[table][chain] + + // convert 1-based position to 0-based index + index := pos - 1 + if index < 0 { + index = 0 + } + + switch { + case index == len(chainRules): + c.state[table][chain] = append(chainRules, targetRule) + case index > len(chainRules): + return errIndexBounds + default: + c.state[table][chain] = append(chainRules[:index], append([]string{targetRule}, chainRules[index:]...)...) } - targetRule := strings.Join(rulespec, " ") - c.state[table][chain] = append(c.state[table][chain], targetRule) return nil } -func (c *IPTablesMock) Insert(table, chain string, _ int, rulespec ...string) error { - return c.Append(table, chain, rulespec...) +func (c *IPTablesMock) List(table, chain string) ([]string, error) { + c.ensureTableExists(table) + + chainExists, _ := c.ChainExists(table, chain) + if !chainExists { + return nil, errChainNotFound + } + + chainRules := c.state[table][chain] + // preallocate: 1 for chain header + number of rules + result := make([]string, 0, 1+len(chainRules)) + + // for built-in chains, start with policy -P, otherwise start with definition -N + builtins := []string{iptables.Input, iptables.Output, iptables.Prerouting, iptables.Postrouting, iptables.Forward} + isBuiltIn := false + for _, builtin := range builtins { + if chain == builtin { + isBuiltIn = true + break + } + } + + if isBuiltIn { + result = append(result, fmt.Sprintf("-P %s ACCEPT", chain)) + } else { + result = append(result, "-N "+chain) + } + + // iptables with -S always outputs the rules in -A format + for _, rule := range chainRules { + result = append(result, fmt.Sprintf("-A %s %s", chain, rule)) + } + + return result, nil +} + +func (c *IPTablesMock) ClearChain(table, chain string) error { + c.clearChainCallCount++ + c.ensureTableExists(table) + + chainExists, _ := c.ChainExists(table, chain) + if !chainExists { + return errChainNotFound + } + + c.state[table][chain] = []string{} + return nil +} + +func (c *IPTablesMock) Delete(table, chain string, rulespec ...string) error { + c.ensureTableExists(table) + + chainExists, _ := c.ChainExists(table, chain) + if !chainExists { + return errChainNotFound + } + + targetRule := strings.Join(rulespec, " ") + chainRules := c.state[table][chain] + + // delete first match + for i, rule := range chainRules { + if rule == targetRule { + c.state[table][chain] = append(chainRules[:i], chainRules[i+1:]...) + return nil + } + } + + return errRuleNotFound +} + +func (c *IPTablesMock) ClearChainCallCount() int { + return c.clearChainCallCount } diff --git a/cns/restserver/internalapi_linux.go b/cns/restserver/internalapi_linux.go index ef30dabf03..7b04c4f8e0 100644 --- a/cns/restserver/internalapi_linux.go +++ b/cns/restserver/internalapi_linux.go @@ -14,7 +14,7 @@ import ( "github.com/pkg/errors" ) -const SWIFT = "SWIFT-POSTROUTING" +const SWIFTPOSTROUTING = "SWIFT-POSTROUTING" type IPtablesProvider struct{} @@ -37,32 +37,62 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to create iptables interface : %v", err) } - chainExist, err := ipt.ChainExists(iptables.Nat, SWIFT) + chainExist, err := ipt.ChainExists(iptables.Nat, SWIFTPOSTROUTING) if err != nil { - return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT chain: %v", err) + return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING chain: %v", err) } if !chainExist { // create and append chain if it doesn't exist - logger.Printf("[Azure CNS] Creating SWIFT Chain ...") - err = ipt.NewChain(iptables.Nat, SWIFT) + logger.Printf("[Azure CNS] Creating SWIFT-POSTROUTING Chain ...") + err = ipt.NewChain(iptables.Nat, SWIFTPOSTROUTING) if err != nil { - return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT chain : " + err.Error() - } - logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...") - err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT) - if err != nil { - return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error() + return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT-POSTROUTING chain : " + err.Error() } } - postroutingToSwiftJumpexist, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFT) + // reconcile jump to SWIFT-POSTROUTING chain + rules, err := ipt.List(iptables.Nat, iptables.Postrouting) if err != nil { - return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of POSTROUTING to SWIFT chain jump: %v", err) + return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check rules in postrouting chain of nat table: %v", err) + } + swiftRuleIndex := len(rules) // append if neither jump rule from POSTROUTING is found + // one time migration from old SWIFT chain + // previously, CNI may have a jump to the SWIFT chain-- our jump to SWIFT-POSTROUTING needs to happen first + for index, rule := range rules { + if rule == "-A POSTROUTING -j SWIFT" { + // jump to SWIFT comes before jump to SWIFT-POSTROUTING, so potential reordering required + swiftRuleIndex = index + break + } + if rule == "-A POSTROUTING -j SWIFT-POSTROUTING" { + // jump to SWIFT-POSTROUTING comes before jump to SWIFT, which requires no further action + swiftRuleIndex = -1 + break + } } - if !postroutingToSwiftJumpexist { - logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...") - err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT) + if swiftRuleIndex != -1 { + // jump SWIFT rule exists, insert SWIFT-POSTROUTING rule at the same position so it ends up running first + // first, remove any existing SWIFT-POSTROUTING rules to avoid duplicates + // note: inserting at len(rules) and deleting a jump to SWIFT-POSTROUTING is mutually exclusive + swiftPostroutingExists, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFTPOSTROUTING) + if err != nil { + return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING rule: %v", err) + } + if swiftPostroutingExists { + err = ipt.Delete(iptables.Nat, iptables.Postrouting, "-j", SWIFTPOSTROUTING) + if err != nil { + return types.FailedToRunIPTableCmd, "[Azure CNS] failed to delete existing SWIFT-POSTROUTING rule : " + err.Error() + } + } + + // slice index is 0-based, iptables insert is 1-based, but list also gives us the -P POSTROUTING ACCEPT + // as the first rule so swiftRuleIndex gives us the correct 1-indexed iptables position. + // Example: + // -P POSTROUTING ACCEPT is at swiftRuleIndex 0 + // -A POSTROUTING -j SWIFT is at swiftRuleIndex 1, and iptables index 1 + logger.Printf("[Azure CNS] Inserting SWIFT-POSTROUTING Chain at iptables position %d", swiftRuleIndex) + err = ipt.Insert(iptables.Nat, iptables.Postrouting, swiftRuleIndex, "-j", SWIFTPOSTROUTING) if err != nil { - return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error() + return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert SWIFT-POSTROUTING chain : " + err.Error() } } @@ -71,39 +101,47 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer // put the ip address in standard cidr form (where we zero out the parts that are not relevant) _, podSubnet, _ := net.ParseCIDR(v.IPAddress + "/" + fmt.Sprintf("%d", req.IPConfiguration.IPSubnet.PrefixLength)) - snatUDPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()) - if err != nil { - return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT UDP rule : %v", err) + // define all rules we want in the chain + rules := [][]string{ + {"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()}, + {"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()}, + {"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP}, } - if !snatUDPRuleExists { - logger.Printf("[Azure CNS] Inserting pod SNAT UDP rule ...") - err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()) + + // check if all rules exist + allRulesExist := true + for _, rule := range rules { + exists, err := ipt.Exists(iptables.Nat, SWIFTPOSTROUTING, rule...) if err != nil { - return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT UDP rule : " + err.Error() + return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of rule: %v", err) + } + if !exists { + allRulesExist = false + break } } - snatPodTCPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()) + // get current rule count in SWIFT-POSTROUTING chain + currentRules, err := ipt.List(iptables.Nat, SWIFTPOSTROUTING) if err != nil { - return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT TCP rule : %v", err) + return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to list rules in SWIFT-POSTROUTING chain: %v", err) } - if !snatPodTCPRuleExists { - logger.Printf("[Azure CNS] Inserting pod SNAT TCP rule ...") - err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()) + + // if rule count doesn't match or not all rules exist, reconcile + // add one because there is always a singular starting rule in the chain, in addition to the ones we add + if len(currentRules) != len(rules)+1 || !allRulesExist { + logger.Printf("[Azure CNS] Reconciling SWIFT-POSTROUTING chain rules") + + err = ipt.ClearChain(iptables.Nat, SWIFTPOSTROUTING) if err != nil { - return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT TCP rule : " + err.Error() + return types.FailedToRunIPTableCmd, "[Azure CNS] failed to flush SWIFT-POSTROUTING chain : " + err.Error() } - } - snatIMDSRuleexist, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP) - if err != nil { - return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT IMDS rule : %v", err) - } - if !snatIMDSRuleexist { - logger.Printf("[Azure CNS] Inserting pod SNAT IMDS rule ...") - err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP) - if err != nil { - return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT IMDS rule : " + err.Error() + for _, rule := range rules { + err = ipt.Append(iptables.Nat, SWIFTPOSTROUTING, rule...) + if err != nil { + return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append rule to SWIFT-POSTROUTING chain : " + err.Error() + } } } diff --git a/cns/restserver/internalapi_linux_test.go b/cns/restserver/internalapi_linux_test.go index 731ca4d989..a292b30f98 100644 --- a/cns/restserver/internalapi_linux_test.go +++ b/cns/restserver/internalapi_linux_test.go @@ -27,16 +27,24 @@ func (c *FakeIPTablesProvider) GetIPTables() (iptablesClient, error) { } func TestAddSNATRules(t *testing.T) { - type expectedScenario struct { + type chainExpectation struct { + table string + chain string + expected []string + } + + type preExistingRule struct { table string chain string rule []string } tests := []struct { - name string - input *cns.CreateNetworkContainerRequest - expected []expectedScenario + name string + input *cns.CreateNetworkContainerRequest + preExistingRules []preExistingRule + expectedChains []chainExpectation + expectedClearChainCalls int }{ { // in pod subnet, the primary nic ip is in the same address space as the pod subnet @@ -56,93 +64,302 @@ func TestAddSNATRules(t *testing.T) { }, HostPrimaryIP: "10.0.0.4", }, - expected: []expectedScenario{ + expectedChains: []chainExpectation{ { table: iptables.Nat, - chain: SWIFT, - rule: []string{ - "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", - networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", "240.1.2.1", + chain: SWIFTPOSTROUTING, + expected: []string{ + "-N SWIFT-POSTROUTING", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p udp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 240.1.2.1", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p tcp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 240.1.2.1", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureIMDS + " -p tcp --dport " + strconv.Itoa(iptables.HTTPPort) + " -j SNAT --to 10.0.0.4", + }, + }, + { + table: iptables.Nat, + chain: iptables.Postrouting, + expected: []string{ + "-P POSTROUTING ACCEPT", + "-A POSTROUTING -j SWIFT-POSTROUTING", }, }, + }, + expectedClearChainCalls: 1, + }, + { + // test with pre-existing SWIFT rule that should be migrated + name: "migration from old SWIFT", + input: &cns.CreateNetworkContainerRequest{ + NetworkContainerid: ncID, + IPConfiguration: cns.IPConfiguration{ + IPSubnet: cns.IPSubnet{ + IPAddress: "240.1.2.1", + PrefixLength: 24, + }, + }, + SecondaryIPConfigs: map[string]cns.SecondaryIPConfig{ + "abc": { + IPAddress: "240.1.2.7", + }, + }, + HostPrimaryIP: "10.0.0.4", + }, + preExistingRules: []preExistingRule{ + { + table: iptables.Nat, + chain: iptables.Postrouting, + rule: []string{"-j", "SWIFT"}, + }, + { + // stale rule at lower priority should be cleaned up + table: iptables.Nat, + chain: iptables.Postrouting, + rule: []string{"-j", "SWIFT-POSTROUTING"}, + }, { + // should be cleaned up table: iptables.Nat, - chain: SWIFT, + chain: SWIFTPOSTROUTING, rule: []string{ - "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", - networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", "240.1.2.1", + "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", networkutils.AzureDNS, + "-p", "udp", "--dport", strconv.Itoa(iptables.DNSPort), "-j", "SNAT", "--to", "99.1.2.1", }, }, { table: iptables.Nat, - chain: SWIFT, + chain: "SWIFT", rule: []string{ - "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", - networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", "10.0.0.4", + "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", networkutils.AzureDNS, + "-p", "udp", "--dport", strconv.Itoa(iptables.DNSPort), "-j", "SNAT", "--to", "192.1.2.1", + }, + }, + }, + expectedChains: []chainExpectation{ + { + table: iptables.Nat, + chain: SWIFTPOSTROUTING, + expected: []string{ + "-N SWIFT-POSTROUTING", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p udp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 240.1.2.1", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p tcp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 240.1.2.1", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureIMDS + " -p tcp --dport " + strconv.Itoa(iptables.HTTPPort) + " -j SNAT --to 10.0.0.4", + }, + }, + { + table: iptables.Nat, + chain: iptables.Postrouting, + expected: []string{ + "-P POSTROUTING ACCEPT", + "-A POSTROUTING -j SWIFT-POSTROUTING", + "-A POSTROUTING -j SWIFT", + }, + }, + { + // stale old rule can remain + table: iptables.Nat, + chain: "SWIFT", + expected: []string{ + "-N SWIFT", + "-A SWIFT -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p udp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 192.1.2.1", }, }, }, + expectedClearChainCalls: 1, }, { - // in vnet scale, the primary nic ip becomes the node ip (diff address space from pod subnet) - name: "vnet scale", + // test after migration has already completed + name: "after migration from old SWIFT", input: &cns.CreateNetworkContainerRequest{ NetworkContainerid: ncID, IPConfiguration: cns.IPConfiguration{ IPSubnet: cns.IPSubnet{ - IPAddress: "10.0.0.4", - PrefixLength: 28, + IPAddress: "240.1.2.1", + PrefixLength: 24, }, }, SecondaryIPConfigs: map[string]cns.SecondaryIPConfig{ "abc": { - IPAddress: "240.1.2.15", + IPAddress: "240.1.2.7", }, }, HostPrimaryIP: "10.0.0.4", }, - expected: []expectedScenario{ + preExistingRules: []preExistingRule{ + { + // rule at higher priority means nothing happens + table: iptables.Nat, + chain: iptables.Postrouting, + rule: []string{"-j", "SWIFT-POSTROUTING"}, + }, + { + table: iptables.Nat, + chain: iptables.Postrouting, + rule: []string{"-j", "SWIFT"}, + }, { table: iptables.Nat, - chain: SWIFT, + chain: SWIFTPOSTROUTING, rule: []string{ - "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/28", "-d", - networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", "10.0.0.4", + "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", networkutils.AzureDNS, + "-p", "udp", "--dport", strconv.Itoa(iptables.DNSPort), "-j", "SNAT", "--to", "240.1.2.1", }, }, { table: iptables.Nat, - chain: SWIFT, + chain: SWIFTPOSTROUTING, rule: []string{ - "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/28", "-d", - networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", "10.0.0.4", + "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", networkutils.AzureDNS, + "-p", "tcp", "--dport", strconv.Itoa(iptables.DNSPort), "-j", "SNAT", "--to", "240.1.2.1", }, }, { table: iptables.Nat, - chain: SWIFT, + chain: SWIFTPOSTROUTING, rule: []string{ - "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/28", "-d", - networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", "10.0.0.4", + "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", networkutils.AzureIMDS, + "-p", "tcp", "--dport", strconv.Itoa(iptables.HTTPPort), "-j", "SNAT", "--to", "10.0.0.4", }, }, + { + table: iptables.Nat, + chain: "SWIFT", + rule: []string{ + "-m", "addrtype", "!", "--dst-type", "local", "-s", "240.1.2.0/24", "-d", networkutils.AzureDNS, + "-p", "udp", "--dport", strconv.Itoa(iptables.DNSPort), "-j", "SNAT", "--to", "192.1.2.1", + }, + }, + }, + expectedChains: []chainExpectation{ + { + table: iptables.Nat, + chain: SWIFTPOSTROUTING, + expected: []string{ + "-N SWIFT-POSTROUTING", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p udp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 240.1.2.1", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p tcp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 240.1.2.1", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureIMDS + " -p tcp --dport " + strconv.Itoa(iptables.HTTPPort) + " -j SNAT --to 10.0.0.4", + }, + }, + { + table: iptables.Nat, + chain: iptables.Postrouting, + expected: []string{ + "-P POSTROUTING ACCEPT", + "-A POSTROUTING -j SWIFT-POSTROUTING", + "-A POSTROUTING -j SWIFT", + }, + }, + { + // stale old rule can remain + table: iptables.Nat, + chain: "SWIFT", + expected: []string{ + "-N SWIFT", + "-A SWIFT -m addrtype ! --dst-type local -s 240.1.2.0/24 -d " + networkutils.AzureDNS + " -p udp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 192.1.2.1", + }, + }, + }, + expectedClearChainCalls: 0, + }, + { + // in vnet scale, the primary nic ip becomes the node ip (diff address space from pod subnet) + name: "vnet scale", + input: &cns.CreateNetworkContainerRequest{ + NetworkContainerid: ncID, + IPConfiguration: cns.IPConfiguration{ + IPSubnet: cns.IPSubnet{ + IPAddress: "10.0.0.4", + PrefixLength: 28, + }, + }, + SecondaryIPConfigs: map[string]cns.SecondaryIPConfig{ + "abc": { + IPAddress: "240.1.2.15", + }, + }, + HostPrimaryIP: "10.0.0.4", }, + expectedChains: []chainExpectation{ + { + table: iptables.Nat, + chain: SWIFTPOSTROUTING, + expected: []string{ + "-N SWIFT-POSTROUTING", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/28 -d " + networkutils.AzureDNS + " -p udp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 10.0.0.4", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/28 -d " + networkutils.AzureDNS + " -p tcp --dport " + strconv.Itoa(iptables.DNSPort) + " -j SNAT --to 10.0.0.4", + "-A SWIFT-POSTROUTING -m addrtype ! --dst-type local -s 240.1.2.0/28 -d " + networkutils.AzureIMDS + " -p tcp --dport " + strconv.Itoa(iptables.HTTPPort) + " -j SNAT --to 10.0.0.4", + }, + }, + { + table: iptables.Nat, + chain: iptables.Postrouting, + expected: []string{ + "-P POSTROUTING ACCEPT", + "-A POSTROUTING -j SWIFT-POSTROUTING", + }, + }, + }, + expectedClearChainCalls: 1, }, } for _, tt := range tests { - service := getTestService(cns.KubernetesCRD) - service.iptables = &FakeIPTablesProvider{} - resp, msg := service.programSNATRules(tt.input) - if resp != types.Success { - t.Fatal("failed to program snat rules", msg, " case: ", tt.name) - } - finalState, _ := service.iptables.GetIPTables() - for _, ex := range tt.expected { - exists, err := finalState.Exists(ex.table, ex.chain, ex.rule...) - if err != nil || !exists { - t.Fatal("rule not found", ex.rule, " case: ", tt.name) + t.Run(tt.name, func(t *testing.T) { + service := getTestService(cns.KubernetesCRD) + ipt := fakes.NewIPTablesMock() + service.iptables = &FakeIPTablesProvider{ + iptables: ipt, + } + + // setup pre-existing rules + if len(tt.preExistingRules) > 0 { + for _, preRule := range tt.preExistingRules { + chainExists, _ := ipt.ChainExists(preRule.table, preRule.chain) + + if !chainExists { + err := ipt.NewChain(preRule.table, preRule.chain) + if err != nil { + t.Fatal("failed to setup pre-existing rule chain:", err) + } + } + + err := ipt.Append(preRule.table, preRule.chain, preRule.rule...) + if err != nil { + t.Fatal("failed to setup pre-existing rule:", err) + } + } + } + + resp, msg := service.programSNATRules(tt.input) + if resp != types.Success { + t.Fatal("failed to program snat rules", msg) + } + + // verify chain contents using List + for _, chainExp := range tt.expectedChains { + actualRules, err := ipt.List(chainExp.table, chainExp.chain) + if err != nil { + t.Fatal("failed to list rules for chain", chainExp.chain, ":", err) + } + + if len(actualRules) != len(chainExp.expected) { + t.Fatalf("chain %s rule count mismatch: got %d, expected %d\nActual: %v\nExpected: %v", + chainExp.chain, len(actualRules), len(chainExp.expected), actualRules, chainExp.expected) + } + + for i, expectedRule := range chainExp.expected { + if actualRules[i] != expectedRule { + t.Fatalf("chain %s rule %d mismatch:\nActual: %s\nExpected: %s", + chainExp.chain, i, actualRules[i], expectedRule) + } + } + } + + // verify ClearChain was called the expected number of times + actualClearChainCalls := ipt.ClearChainCallCount() + if actualClearChainCalls != tt.expectedClearChainCalls { + t.Fatalf("ClearChain call count mismatch: got %d, expected %d", actualClearChainCalls, tt.expectedClearChainCalls) } - } + }) } } diff --git a/cns/restserver/restserver.go b/cns/restserver/restserver.go index c467ab04e2..56b6e15a48 100644 --- a/cns/restserver/restserver.go +++ b/cns/restserver/restserver.go @@ -60,6 +60,9 @@ type iptablesClient interface { Append(table string, chain string, rulespec ...string) error Exists(table string, chain string, rulespec ...string) (bool, error) Insert(table string, chain string, pos int, rulespec ...string) error + List(table string, chain string) ([]string, error) + ClearChain(table string, chain string) error + Delete(table, chain string, rulespec ...string) error } type iptablesGetter interface {