diff --git a/iptc/ip4tc.py b/iptc/ip4tc.py index 1efeabe..a953f41 100644 --- a/iptc/ip4tc.py +++ b/iptc/ip4tc.py @@ -460,8 +460,6 @@ class Match(IPTCModule): if self._module.next is not None: self._store_buffer(module) - self._check_alias(module[0], match) - self._match_buf = (ct.c_ubyte * self.size)() if match: ct.memmove(ct.byref(self._match_buf), ct.byref(match), self.size) @@ -503,7 +501,11 @@ class Match(IPTCModule): self._buffer.buffer = ct.cast(module, ct.POINTER(ct.c_ubyte)) def _final_check(self): - self._xt.final_check_match(self._module) + if self._alias is not None: + module = self._alias + else: + module = self._module + self._xt.final_check_match(module) def _parse(self, argv, inv, entry): if self._alias is not None: @@ -530,6 +532,7 @@ class Match(IPTCModule): self._ptrptr = ct.cast(ct.pointer(self._ptr), ct.POINTER(ct.POINTER(xt_entry_match))) self._module.m = self._ptr + self._check_alias(self._module, self._module.m) if self._alias is not None: self._alias.m = self._ptr self._update_name() @@ -613,8 +616,6 @@ class Target(IPTCModule): else: self._revision = self._module.revision - self._check_alias(module[0], target) - self._create_buffer(target) if self._is_standard_target(): @@ -673,7 +674,11 @@ class Target(IPTCModule): return False def _final_check(self): - self._xt.final_check_target(self._module) + if self._alias is not None: + module = self._alias + else: + module = self._module + self._xt.final_check_target(module) def _parse(self, argv, inv, entry): if self._alias is not None: @@ -715,6 +720,7 @@ class Target(IPTCModule): self._ptrptr = ct.cast(ct.pointer(self._ptr), ct.POINTER(ct.POINTER(xt_entry_target))) self._module.t = self._ptr + self._check_alias(self._module, self._module.t) if self._alias is not None: self._alias.t = self._ptr self._update_name() diff --git a/iptc/test/test_matches.py b/iptc/test/test_matches.py index 69b0b01..67c37ff 100755 --- a/iptc/test/test_matches.py +++ b/iptc/test/test_matches.py @@ -298,6 +298,40 @@ class TestXTStateMatch(unittest.TestCase): self.assertEquals(m.state, "RELATED,ESTABLISHED") +class TestXTConntrackMatch(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.src = "127.0.0.1" + self.rule.protocol = "tcp" + self.rule.target = iptc.Target(self.rule, "ACCEPT") + + self.match = iptc.Match(self.rule, "conntrack") + + self.chain = iptc.Chain(iptc.Table(iptc.Table.FILTER), + "iptc_test_conntrack") + self.table = iptc.Table(iptc.Table.FILTER) + try: + self.chain.flush() + self.chain.delete() + except: + pass + self.table.create_chain(self.chain) + + def tearDown(self): + self.chain.flush() + self.chain.delete() + pass + + def test_state(self): + self.match.ctstate = "NEW,RELATED" + self.rule.add_match(self.match) + self.chain.insert_rule(self.rule) + rule = self.chain.rules[0] + m = rule.matches[0] + self.assertTrue(m.name, ["conntrack"]) + self.assertEquals(m.ctstate, "NEW,RELATED") + + def suite(): suite_match = unittest.TestLoader().loadTestsFromTestCase(TestMatch) suite_udp = unittest.TestLoader().loadTestsFromTestCase(TestXTUdpMatch) @@ -308,9 +342,11 @@ def suite(): suite_iprange = unittest.TestLoader().loadTestsFromTestCase( TestIprangeMatch) suite_state = unittest.TestLoader().loadTestsFromTestCase(TestXTStateMatch) + suite_conntrack = unittest.TestLoader().loadTestsFromTestCase( + TestXTConntrackMatch) return unittest.TestSuite([suite_match, suite_udp, suite_mark, suite_limit, suite_comment, suite_iprange, - suite_state]) + suite_state, suite_conntrack]) def run_tests(): diff --git a/iptc/test/test_targets.py b/iptc/test/test_targets.py index 6d83f5f..32516bd 100755 --- a/iptc/test/test_targets.py +++ b/iptc/test/test_targets.py @@ -371,6 +371,37 @@ class TestXTNotrackTarget(unittest.TestCase): self.assertTrue(t.name in ["NOTRACK", "CT"]) +class TestXTCtTarget(unittest.TestCase): + def setUp(self): + self.rule = iptc.Rule() + self.rule.dst = "127.0.0.2" + self.rule.protocol = "tcp" + self.rule.out_interface = "eth0" + + self.target = iptc.Target(self.rule, "CT") + self.target.notrack = "true" + self.rule.target = self.target + + self.chain = iptc.Chain(iptc.Table(iptc.Table.RAW), + "iptc_test_ct") + try: + self.chain.flush() + self.chain.delete() + except: + pass + iptc.Table(iptc.Table.RAW).create_chain(self.chain) + + def tearDown(self): + self.chain.flush() + self.chain.delete() + + def test_ct(self): + self.chain.insert_rule(self.rule) + t = self.chain.rules[0].target + self.assertEquals(t.name, "CT") + self.assertTrue(t.notrack is not None) + + def suite(): suites = [] suite_target = unittest.TestLoader().loadTestsFromTestCase(TestTarget) @@ -383,14 +414,15 @@ def suite(): TestIPTMasqueradeTarget) suite_dnat = unittest.TestLoader().loadTestsFromTestCase( TestDnatTarget) - suite_conntrack = unittest.TestLoader().loadTestsFromTestCase( + suite_notrack = unittest.TestLoader().loadTestsFromTestCase( TestXTNotrackTarget) + suite_ct = unittest.TestLoader().loadTestsFromTestCase(TestXTCtTarget) suites.extend([suite_target, suite_cluster, suite_tos]) if is_table_available(iptc.Table.NAT): suites.extend([suite_target, suite_cluster, suite_redir, suite_tos, suite_masq, suite_dnat]) if is_table_available(iptc.Table.RAW): - suites.extend([suite_conntrack]) + suites.extend([suite_notrack, suite_ct]) return unittest.TestSuite(suites)