Skip to content

feat: add cns iptables reconciliation #3885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 101 additions & 8 deletions cns/fakes/iptablesfake.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package fakes

import (
"errors"
"fmt"
"strings"

"github.com/Azure/azure-container-networking/iptables"
Expand All @@ -11,10 +12,13 @@ var (
errChainExists = errors.New("chain already exists")
errChainNotFound = errors.New("chain not found")
errRuleExists = errors.New("rule already exists")
errRuleNotFound = errors.New("rule not found")
errIndexBounds = errors.New("index out of bounds")
)

type IPTablesMock struct {
state map[string]map[string][]string
state map[string]map[string][]string
clearChainCallCount int
}

func NewIPTablesMock() *IPTablesMock {
Expand Down Expand Up @@ -83,21 +87,110 @@ func (c *IPTablesMock) Exists(table, chain string, rulespec ...string) (bool, er
func (c *IPTablesMock) Append(table, chain string, rulespec ...string) error {
c.ensureTableExists(table)

chainRules := c.state[table][chain]
return c.Insert(table, chain, len(chainRules)+1, rulespec...)
}

func (c *IPTablesMock) Insert(table, chain string, pos int, rulespec ...string) error {
c.ensureTableExists(table)

chainExists, _ := c.ChainExists(table, chain)
if !chainExists {
return errChainNotFound
}

ruleExists, _ := c.Exists(table, chain, rulespec...)
if ruleExists {
return errRuleExists
targetRule := strings.Join(rulespec, " ")
chainRules := c.state[table][chain]

// convert 1-based position to 0-based index
index := pos - 1
if index < 0 {
index = 0
}

switch {
case index == len(chainRules):
c.state[table][chain] = append(chainRules, targetRule)
case index > len(chainRules):
return errIndexBounds
default:
c.state[table][chain] = append(chainRules[:index], append([]string{targetRule}, chainRules[index:]...)...)
}

targetRule := strings.Join(rulespec, " ")
c.state[table][chain] = append(c.state[table][chain], targetRule)
return nil
}

func (c *IPTablesMock) Insert(table, chain string, _ int, rulespec ...string) error {
return c.Append(table, chain, rulespec...)
func (c *IPTablesMock) List(table, chain string) ([]string, error) {
c.ensureTableExists(table)

chainExists, _ := c.ChainExists(table, chain)
if !chainExists {
return nil, errChainNotFound
}

chainRules := c.state[table][chain]
// preallocate: 1 for chain header + number of rules
result := make([]string, 0, 1+len(chainRules))

// for built-in chains, start with policy -P, otherwise start with definition -N
builtins := []string{iptables.Input, iptables.Output, iptables.Prerouting, iptables.Postrouting, iptables.Forward}
isBuiltIn := false
for _, builtin := range builtins {
if chain == builtin {
isBuiltIn = true
break
}
}

if isBuiltIn {
result = append(result, fmt.Sprintf("-P %s ACCEPT", chain))
} else {
result = append(result, "-N "+chain)
}

// iptables with -S always outputs the rules in -A format
for _, rule := range chainRules {
result = append(result, fmt.Sprintf("-A %s %s", chain, rule))
}

return result, nil
}

func (c *IPTablesMock) ClearChain(table, chain string) error {
c.clearChainCallCount++
c.ensureTableExists(table)

chainExists, _ := c.ChainExists(table, chain)
if !chainExists {
return errChainNotFound
}

c.state[table][chain] = []string{}
return nil
}

func (c *IPTablesMock) Delete(table, chain string, rulespec ...string) error {
c.ensureTableExists(table)

chainExists, _ := c.ChainExists(table, chain)
if !chainExists {
return errChainNotFound
}

targetRule := strings.Join(rulespec, " ")
chainRules := c.state[table][chain]

// delete first match
for i, rule := range chainRules {
if rule == targetRule {
c.state[table][chain] = append(chainRules[:i], chainRules[i+1:]...)
return nil
}
}

return errRuleNotFound
}

func (c *IPTablesMock) ClearChainCallCount() int {
return c.clearChainCallCount
}
118 changes: 78 additions & 40 deletions cns/restserver/internalapi_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/pkg/errors"
)

const SWIFT = "SWIFT-POSTROUTING"
const SWIFTPOSTROUTING = "SWIFT-POSTROUTING"

type IPtablesProvider struct{}

Expand All @@ -37,32 +37,62 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to create iptables interface : %v", err)
}

chainExist, err := ipt.ChainExists(iptables.Nat, SWIFT)
chainExist, err := ipt.ChainExists(iptables.Nat, SWIFTPOSTROUTING)
if err != nil {
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT chain: %v", err)
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING chain: %v", err)
}
if !chainExist { // create and append chain if it doesn't exist
logger.Printf("[Azure CNS] Creating SWIFT Chain ...")
err = ipt.NewChain(iptables.Nat, SWIFT)
logger.Printf("[Azure CNS] Creating SWIFT-POSTROUTING Chain ...")
err = ipt.NewChain(iptables.Nat, SWIFTPOSTROUTING)
if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT chain : " + err.Error()
}
logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...")
err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error()
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to create SWIFT-POSTROUTING chain : " + err.Error()
}
}

postroutingToSwiftJumpexist, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
// reconcile jump to SWIFT-POSTROUTING chain
rules, err := ipt.List(iptables.Nat, iptables.Postrouting)
if err != nil {
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of POSTROUTING to SWIFT chain jump: %v", err)
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check rules in postrouting chain of nat table: %v", err)
}
swiftRuleIndex := len(rules) // append if neither jump rule from POSTROUTING is found
// one time migration from old SWIFT chain
// previously, CNI may have a jump to the SWIFT chain-- our jump to SWIFT-POSTROUTING needs to happen first
for index, rule := range rules {
if rule == "-A POSTROUTING -j SWIFT" {
// jump to SWIFT comes before jump to SWIFT-POSTROUTING, so potential reordering required
swiftRuleIndex = index
break
}
if rule == "-A POSTROUTING -j SWIFT-POSTROUTING" {
// jump to SWIFT-POSTROUTING comes before jump to SWIFT, which requires no further action
swiftRuleIndex = -1
break
}
}
if !postroutingToSwiftJumpexist {
logger.Printf("[Azure CNS] Append SWIFT Chain to POSTROUTING ...")
err = ipt.Append(iptables.Nat, iptables.Postrouting, "-j", SWIFT)
if swiftRuleIndex != -1 {
// jump SWIFT rule exists, insert SWIFT-POSTROUTING rule at the same position so it ends up running first
// first, remove any existing SWIFT-POSTROUTING rules to avoid duplicates
// note: inserting at len(rules) and deleting a jump to SWIFT-POSTROUTING is mutually exclusive
swiftPostroutingExists, err := ipt.Exists(iptables.Nat, iptables.Postrouting, "-j", SWIFTPOSTROUTING)
if err != nil {
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of SWIFT-POSTROUTING rule: %v", err)
}
if swiftPostroutingExists {
err = ipt.Delete(iptables.Nat, iptables.Postrouting, "-j", SWIFTPOSTROUTING)
if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to delete existing SWIFT-POSTROUTING rule : " + err.Error()
}
}

// slice index is 0-based, iptables insert is 1-based, but list also gives us the -P POSTROUTING ACCEPT
// as the first rule so swiftRuleIndex gives us the correct 1-indexed iptables position.
// Example:
// -P POSTROUTING ACCEPT is at swiftRuleIndex 0
// -A POSTROUTING -j SWIFT is at swiftRuleIndex 1, and iptables index 1
logger.Printf("[Azure CNS] Inserting SWIFT-POSTROUTING Chain at iptables position %d", swiftRuleIndex)
err = ipt.Insert(iptables.Nat, iptables.Postrouting, swiftRuleIndex, "-j", SWIFTPOSTROUTING)
Comment on lines +92 to +93
Copy link
Preview

Copilot AI Aug 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The insert position calculation may be incorrect. The comment indicates that swiftRuleIndex gives the correct 1-indexed iptables position, but if swiftRuleIndex is len(rules), this could exceed valid iptables positions. Consider adding bounds checking or using swiftRuleIndex+1 when swiftRuleIndex equals len(rules).

Suggested change
logger.Printf("[Azure CNS] Inserting SWIFT-POSTROUTING Chain at iptables position %d", swiftRuleIndex)
err = ipt.Insert(iptables.Nat, iptables.Postrouting, swiftRuleIndex, "-j", SWIFTPOSTROUTING)
// Ensure insert position does not exceed valid iptables positions
insertPos := swiftRuleIndex
if swiftRuleIndex >= len(rules) {
insertPos = len(rules)
}
logger.Printf("[Azure CNS] Inserting SWIFT-POSTROUTING Chain at iptables position %d", insertPos)
err = ipt.Insert(iptables.Nat, iptables.Postrouting, insertPos, "-j", SWIFTPOSTROUTING)

Copilot uses AI. Check for mistakes.

Copy link
Contributor Author

@QxBytes QxBytes Aug 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rules correspond to what is received from List(), which includes the built-in -P or -N rule. So, if you see one rule with iptables -nvL -t nat, List gives us the built in rule (-P or -N) and also the one rule we see, leading to a length of 2. If we insert using iptables index 2, it ends up at the end of the chain (after the one rule we see from iptables -nvL -t nat.

if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append SWIFT chain : " + err.Error()
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert SWIFT-POSTROUTING chain : " + err.Error()
}
}

Expand All @@ -71,39 +101,47 @@ func (service *HTTPRestService) programSNATRules(req *cns.CreateNetworkContainer
// put the ip address in standard cidr form (where we zero out the parts that are not relevant)
_, podSubnet, _ := net.ParseCIDR(v.IPAddress + "/" + fmt.Sprintf("%d", req.IPConfiguration.IPSubnet.PrefixLength))

snatUDPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
if err != nil {
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT UDP rule : %v", err)
// define all rules we want in the chain
rules := [][]string{
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()},
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String()},
{"-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP},
}
if !snatUDPRuleExists {
logger.Printf("[Azure CNS] Inserting pod SNAT UDP rule ...")
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.UDP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())

// check if all rules exist
allRulesExist := true
for _, rule := range rules {
exists, err := ipt.Exists(iptables.Nat, SWIFTPOSTROUTING, rule...)
if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT UDP rule : " + err.Error()
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of rule: %v", err)
}
if !exists {
allRulesExist = false
break
}
}

snatPodTCPRuleExists, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())
// get current rule count in SWIFT-POSTROUTING chain
currentRules, err := ipt.List(iptables.Nat, SWIFTPOSTROUTING)
if err != nil {
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT TCP rule : %v", err)
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to list rules in SWIFT-POSTROUTING chain: %v", err)
}
if !snatPodTCPRuleExists {
logger.Printf("[Azure CNS] Inserting pod SNAT TCP rule ...")
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureDNS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.DNSPort), "-j", iptables.Snat, "--to", ncPrimaryIP.String())

// if rule count doesn't match or not all rules exist, reconcile
// add one because there is always a singular starting rule in the chain, in addition to the ones we add
if len(currentRules) != len(rules)+1 || !allRulesExist {
logger.Printf("[Azure CNS] Reconciling SWIFT-POSTROUTING chain rules")

err = ipt.ClearChain(iptables.Nat, SWIFTPOSTROUTING)
if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT TCP rule : " + err.Error()
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to flush SWIFT-POSTROUTING chain : " + err.Error()
}
}

snatIMDSRuleexist, err := ipt.Exists(iptables.Nat, SWIFT, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP)
if err != nil {
return types.UnexpectedError, fmt.Sprintf("[Azure CNS] Error. Failed to check for existence of pod SNAT IMDS rule : %v", err)
}
if !snatIMDSRuleexist {
logger.Printf("[Azure CNS] Inserting pod SNAT IMDS rule ...")
err = ipt.Insert(iptables.Nat, SWIFT, 1, "-m", "addrtype", "!", "--dst-type", "local", "-s", podSubnet.String(), "-d", networkutils.AzureIMDS, "-p", iptables.TCP, "--dport", strconv.Itoa(iptables.HTTPPort), "-j", iptables.Snat, "--to", req.HostPrimaryIP)
if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to insert pod SNAT IMDS rule : " + err.Error()
for _, rule := range rules {
err = ipt.Append(iptables.Nat, SWIFTPOSTROUTING, rule...)
if err != nil {
return types.FailedToRunIPTableCmd, "[Azure CNS] failed to append rule to SWIFT-POSTROUTING chain : " + err.Error()
}
}
}

Expand Down
Loading
Loading