You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

1412 lines
38 KiB

  1. /*
  2. * This file is part of PowerDNS or dnsdist.
  3. * Copyright -- PowerDNS.COM B.V. and its contributors
  4. *
  5. * This program is free software; you can redistribute it and/or modify
  6. * it under the terms of version 2 of the GNU General Public License as
  7. * published by the Free Software Foundation.
  8. *
  9. * In addition, for the avoidance of any doubt, permission is granted to
  10. * link this program with OpenSSL and to (re)distribute the binaries
  11. * produced as the result of such linking.
  12. *
  13. * This program is distributed in the hope that it will be useful,
  14. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  15. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  16. * GNU General Public License for more details.
  17. *
  18. * You should have received a copy of the GNU General Public License
  19. * along with this program; if not, write to the Free Software
  20. * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
  21. */
  22. #pragma once
  23. #include <string>
  #include <cstddef>
  24. #include <sys/socket.h>
  25. #include <netinet/in.h>
  26. #include <arpa/inet.h>
  27. #include <iostream>
  28. #include <stdio.h>
  29. #include <functional>
  30. #include <bitset>
  31. #include "pdnsexception.hh"
  32. #include "misc.hh"
  33. #include <sys/socket.h>
  34. #include <netdb.h>
  35. #include <sstream>
  36. #include <boost/tuple/tuple.hpp>
  37. #include <boost/tuple/tuple_comparison.hpp>
  38. #include "namespaces.hh"
  // Portability shims: on the platforms below, provide the htobe16/be16toh/...
  // (16-, 32- and 64-bit) byte-order conversion macro names used elsewhere,
  // mapped onto each platform's own spelling of those operations.
  // macOS: map onto the OSSwap* helpers from <libkern/OSByteOrder.h>.
  39. #ifdef __APPLE__
  40. #include <libkern/OSByteOrder.h>
  41. #define htobe16(x) OSSwapHostToBigInt16(x)
  42. #define htole16(x) OSSwapHostToLittleInt16(x)
  43. #define be16toh(x) OSSwapBigToHostInt16(x)
  44. #define le16toh(x) OSSwapLittleToHostInt16(x)
  45. #define htobe32(x) OSSwapHostToBigInt32(x)
  46. #define htole32(x) OSSwapHostToLittleInt32(x)
  47. #define be32toh(x) OSSwapBigToHostInt32(x)
  48. #define le32toh(x) OSSwapLittleToHostInt32(x)
  49. #define htobe64(x) OSSwapHostToBigInt64(x)
  50. #define htole64(x) OSSwapHostToLittleInt64(x)
  51. #define be64toh(x) OSSwapBigToHostInt64(x)
  52. #define le64toh(x) OSSwapLittleToHostInt64(x)
  53. #endif
  // Solaris/illumos: map onto the BE_*/LE_* byte-order macros.
  // NOTE(review): those are presumably declared in <sys/byteorder.h>, which is
  // not included here — confirm it arrives transitively.
  // The *toh variants expand to BE_IN*/LE_IN*(&(x)), i.e. they take the address
  // of their argument, so x must be an lvalue on this platform.
  54. #ifdef __sun
  55. #define htobe16(x) BE_16(x)
  56. #define htole16(x) LE_16(x)
  57. #define be16toh(x) BE_IN16(&(x))
  58. #define le16toh(x) LE_IN16(&(x))
  59. #define htobe32(x) BE_32(x)
  60. #define htole32(x) LE_32(x)
  61. #define be32toh(x) BE_IN32(&(x))
  62. #define le32toh(x) LE_IN32(&(x))
  63. #define htobe64(x) BE_64(x)
  64. #define htole64(x) LE_64(x)
  65. #define be64toh(x) BE_IN64(&(x))
  66. #define le64toh(x) LE_IN64(&(x))
  67. #endif
  // FreeBSD: presumably provides these macros natively via <sys/endian.h>.
  68. #ifdef __FreeBSD__
  69. #include <sys/endian.h>
  70. #endif
  71. #if defined(__NetBSD__) && defined(IP_PKTINFO) && !defined(IP_SENDSRCADDR)
  72. // The IP_PKTINFO option in NetBSD was incompatible with Linux until a
  73. // change that also introduced IP_SENDSRCADDR for FreeBSD compatibility.
  74. #undef IP_PKTINFO
  75. #endif
  76. union ComboAddress {
  77. struct sockaddr_in sin4;
  78. struct sockaddr_in6 sin6;
  79. bool operator==(const ComboAddress& rhs) const
  80. {
  81. if(boost::tie(sin4.sin_family, sin4.sin_port) != boost::tie(rhs.sin4.sin_family, rhs.sin4.sin_port))
  82. return false;
  83. if(sin4.sin_family == AF_INET)
  84. return sin4.sin_addr.s_addr == rhs.sin4.sin_addr.s_addr;
  85. else
  86. return memcmp(&sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(sin6.sin6_addr.s6_addr))==0;
  87. }
  88. bool operator!=(const ComboAddress& rhs) const
  89. {
  90. return(!operator==(rhs));
  91. }
  92. bool operator<(const ComboAddress& rhs) const
  93. {
  94. if(sin4.sin_family == 0) {
  95. return false;
  96. }
  97. if(boost::tie(sin4.sin_family, sin4.sin_port) < boost::tie(rhs.sin4.sin_family, rhs.sin4.sin_port))
  98. return true;
  99. if(boost::tie(sin4.sin_family, sin4.sin_port) > boost::tie(rhs.sin4.sin_family, rhs.sin4.sin_port))
  100. return false;
  101. if(sin4.sin_family == AF_INET)
  102. return sin4.sin_addr.s_addr < rhs.sin4.sin_addr.s_addr;
  103. else
  104. return memcmp(&sin6.sin6_addr.s6_addr, &rhs.sin6.sin6_addr.s6_addr, sizeof(sin6.sin6_addr.s6_addr)) < 0;
  105. }
  106. bool operator>(const ComboAddress& rhs) const
  107. {
  108. return rhs.operator<(*this);
  109. }
  110. struct addressOnlyHash
  111. {
  112. uint32_t operator()(const ComboAddress& ca) const
  113. {
  114. const unsigned char* start;
  115. int len;
  116. if(ca.sin4.sin_family == AF_INET) {
  117. start =(const unsigned char*)&ca.sin4.sin_addr.s_addr;
  118. len=4;
  119. }
  120. else {
  121. start =(const unsigned char*)&ca.sin6.sin6_addr.s6_addr;
  122. len=16;
  123. }
  124. return burtle(start, len, 0);
  125. }
  126. };
  127. struct addressOnlyLessThan: public std::binary_function<ComboAddress, ComboAddress, bool>
  128. {
  129. bool operator()(const ComboAddress& a, const ComboAddress& b) const
  130. {
  131. if(a.sin4.sin_family < b.sin4.sin_family)
  132. return true;
  133. if(a.sin4.sin_family > b.sin4.sin_family)
  134. return false;
  135. if(a.sin4.sin_family == AF_INET)
  136. return a.sin4.sin_addr.s_addr < b.sin4.sin_addr.s_addr;
  137. else
  138. return memcmp(&a.sin6.sin6_addr.s6_addr, &b.sin6.sin6_addr.s6_addr, sizeof(a.sin6.sin6_addr.s6_addr)) < 0;
  139. }
  140. };
  141. struct addressOnlyEqual: public std::binary_function<ComboAddress, ComboAddress, bool>
  142. {
  143. bool operator()(const ComboAddress& a, const ComboAddress& b) const
  144. {
  145. if(a.sin4.sin_family != b.sin4.sin_family)
  146. return false;
  147. if(a.sin4.sin_family == AF_INET)
  148. return a.sin4.sin_addr.s_addr == b.sin4.sin_addr.s_addr;
  149. else
  150. return !memcmp(&a.sin6.sin6_addr.s6_addr, &b.sin6.sin6_addr.s6_addr, sizeof(a.sin6.sin6_addr.s6_addr));
  151. }
  152. };
  153. socklen_t getSocklen() const
  154. {
  155. if(sin4.sin_family == AF_INET)
  156. return sizeof(sin4);
  157. else
  158. return sizeof(sin6);
  159. }
  160. ComboAddress()
  161. {
  162. sin4.sin_family=AF_INET;
  163. sin4.sin_addr.s_addr=0;
  164. sin4.sin_port=0;
  165. sin6.sin6_scope_id = 0;
  166. sin6.sin6_flowinfo = 0;
  167. }
  168. ComboAddress(const struct sockaddr *sa, socklen_t salen) {
  169. setSockaddr(sa, salen);
  170. };
  171. ComboAddress(const struct sockaddr_in6 *sa) {
  172. setSockaddr((const struct sockaddr*)sa, sizeof(struct sockaddr_in6));
  173. };
  174. ComboAddress(const struct sockaddr_in *sa) {
  175. setSockaddr((const struct sockaddr*)sa, sizeof(struct sockaddr_in));
  176. };
  177. void setSockaddr(const struct sockaddr *sa, socklen_t salen) {
  178. if (salen > sizeof(struct sockaddr_in6)) throw PDNSException("ComboAddress can't handle other than sockaddr_in or sockaddr_in6");
  179. memcpy(this, sa, salen);
  180. }
  181. // 'port' sets a default value in case 'str' does not set a port
  182. explicit ComboAddress(const string& str, uint16_t port=0)
  183. {
  184. memset(&sin6, 0, sizeof(sin6));
  185. sin4.sin_family = AF_INET;
  186. sin4.sin_port = 0;
  187. if(makeIPv4sockaddr(str, &sin4)) {
  188. sin6.sin6_family = AF_INET6;
  189. if(makeIPv6sockaddr(str, &sin6) < 0)
  190. throw PDNSException("Unable to convert presentation address '"+ str +"'");
  191. }
  192. if(!sin4.sin_port) // 'str' overrides port!
  193. sin4.sin_port=htons(port);
  194. }
  195. bool isIPv6() const
  196. {
  197. return sin4.sin_family == AF_INET6;
  198. }
  199. bool isIPv4() const
  200. {
  201. return sin4.sin_family == AF_INET;
  202. }
  203. bool isMappedIPv4() const
  204. {
  205. if(sin4.sin_family!=AF_INET6)
  206. return false;
  207. int n=0;
  208. const unsigned char*ptr = (unsigned char*) &sin6.sin6_addr.s6_addr;
  209. for(n=0; n < 10; ++n)
  210. if(ptr[n])
  211. return false;
  212. for(; n < 12; ++n)
  213. if(ptr[n]!=0xff)
  214. return false;
  215. return true;
  216. }
  217. ComboAddress mapToIPv4() const
  218. {
  219. if(!isMappedIPv4())
  220. throw PDNSException("ComboAddress can't map non-mapped IPv6 address back to IPv4");
  221. ComboAddress ret;
  222. ret.sin4.sin_family=AF_INET;
  223. ret.sin4.sin_port=sin4.sin_port;
  224. const unsigned char*ptr = (unsigned char*) &sin6.sin6_addr.s6_addr;
  225. ptr+=(sizeof(sin6.sin6_addr.s6_addr) - sizeof(ret.sin4.sin_addr.s_addr));
  226. memcpy(&ret.sin4.sin_addr.s_addr, ptr, sizeof(ret.sin4.sin_addr.s_addr));
  227. return ret;
  228. }
  229. string toString() const
  230. {
  231. char host[1024];
  232. int retval = 0;
  233. if(sin4.sin_family && !(retval = getnameinfo((struct sockaddr*) this, getSocklen(), host, sizeof(host),0, 0, NI_NUMERICHOST)))
  234. return string(host);
  235. else
  236. return "invalid "+string(gai_strerror(retval));
  237. }
  238. string toStringWithPort() const
  239. {
  240. if(sin4.sin_family==AF_INET)
  241. return toString() + ":" + std::to_string(ntohs(sin4.sin_port));
  242. else
  243. return "["+toString() + "]:" + std::to_string(ntohs(sin4.sin_port));
  244. }
  245. string toStringWithPortExcept(int port) const
  246. {
  247. if(ntohs(sin4.sin_port) == port)
  248. return toString();
  249. if(sin4.sin_family==AF_INET)
  250. return toString() + ":" + std::to_string(ntohs(sin4.sin_port));
  251. else
  252. return "["+toString() + "]:" + std::to_string(ntohs(sin4.sin_port));
  253. }
  254. string toLogString() const
  255. {
  256. return toStringWithPortExcept(53);
  257. }
  258. void truncate(unsigned int bits) noexcept;
  259. uint16_t getPort() const
  260. {
  261. return ntohs(sin4.sin_port);
  262. }
  263. void setPort(uint16_t port)
  264. {
  265. sin4.sin_port = htons(port);
  266. }
  267. void reset()
  268. {
  269. memset(&sin4, 0, sizeof(sin4));
  270. memset(&sin6, 0, sizeof(sin6));
  271. }
  272. //! Get the total number of address bits (either 32 or 128 depending on IP version)
  273. uint8_t getBits() const
  274. {
  275. if (isIPv4())
  276. return 32;
  277. if (isIPv6())
  278. return 128;
  279. return 0;
  280. }
  281. /** Get the value of the bit at the provided bit index. When the index >= 0,
  282. the index is relative to the LSB starting at index zero. When the index < 0,
  283. the index is relative to the MSB starting at index -1 and counting down.
  284. */
  285. bool getBit(int index) const
  286. {
  287. if(isIPv4()) {
  288. if (index >= 32)
  289. return false;
  290. if (index < 0) {
  291. if (index < -32)
  292. return false;
  293. index = 32 + index;
  294. }
  295. uint32_t ls_addr = ntohl(sin4.sin_addr.s_addr);
  296. return ((ls_addr & (1<<index)) != 0x00000000);
  297. }
  298. if(isIPv6()) {
  299. if (index >= 128)
  300. return false;
  301. if (index < 0) {
  302. if (index < -128)
  303. return false;
  304. index = 128 + index;
  305. }
  306. uint8_t *ls_addr = (uint8_t*)sin6.sin6_addr.s6_addr;
  307. uint8_t byte_idx = index / 8;
  308. uint8_t bit_idx = index % 8;
  309. return ((ls_addr[15-byte_idx] & (1 << bit_idx)) != 0x00);
  310. }
  311. return false;
  312. }
  313. };
  314. /** This exception is thrown by the Netmask class and by extension by the NetmaskGroup class */
  315. class NetmaskException: public PDNSException
  316. {
  317. public:
  318. NetmaskException(const string &a) : PDNSException(a) {}
  319. };
  320. inline ComboAddress makeComboAddress(const string& str)
  321. {
  322. ComboAddress address;
  323. address.sin4.sin_family=AF_INET;
  324. if(inet_pton(AF_INET, str.c_str(), &address.sin4.sin_addr) <= 0) {
  325. address.sin4.sin_family=AF_INET6;
  326. if(makeIPv6sockaddr(str, &address.sin6) < 0)
  327. throw NetmaskException("Unable to convert '"+str+"' to a netmask");
  328. }
  329. return address;
  330. }
  331. inline ComboAddress makeComboAddressFromRaw(uint8_t version, const char* raw, size_t len)
  332. {
  333. ComboAddress address;
  334. if (version == 4) {
  335. address.sin4.sin_family = AF_INET;
  336. if (len != sizeof(address.sin4.sin_addr)) throw NetmaskException("invalid raw address length");
  337. memcpy(&address.sin4.sin_addr, raw, sizeof(address.sin4.sin_addr));
  338. }
  339. else if (version == 6) {
  340. address.sin6.sin6_family = AF_INET6;
  341. if (len != sizeof(address.sin6.sin6_addr)) throw NetmaskException("invalid raw address length");
  342. memcpy(&address.sin6.sin6_addr, raw, sizeof(address.sin6.sin6_addr));
  343. }
  344. else throw NetmaskException("invalid address family");
  345. return address;
  346. }
  347. inline ComboAddress makeComboAddressFromRaw(uint8_t version, const string &str)
  348. {
  349. return makeComboAddressFromRaw(version, str.c_str(), str.size());
  350. }
  351. /** This class represents a netmask and can be queried to see if a certain
  352. IP address is matched by this mask */
  353. class Netmask
  354. {
  355. public:
  356. Netmask()
  357. {
  358. d_network.sin4.sin_family = 0; // disable this doing anything useful
  359. d_network.sin4.sin_port = 0; // this guarantees d_network compares identical
  360. d_mask = 0;
  361. d_bits = 0;
  362. }
  363. Netmask(const ComboAddress& network, uint8_t bits=0xff): d_network(network)
  364. {
  365. d_network.sin4.sin_port = 0;
  366. setBits(network.isIPv4() ? std::min(bits, static_cast<uint8_t>(32)) : std::min(bits, static_cast<uint8_t>(128)));
  367. }
  368. void setBits(uint8_t value)
  369. {
  370. d_bits = value;
  371. if (d_bits < 32) {
  372. d_mask = ~(0xFFFFFFFF >> d_bits);
  373. }
  374. else {
  375. // note that d_mask is unused for IPv6
  376. d_mask = 0xFFFFFFFF;
  377. }
  378. if (isIPv4()) {
  379. d_network.sin4.sin_addr.s_addr = htonl(ntohl(d_network.sin4.sin_addr.s_addr) & d_mask);
  380. }
  381. else if (isIPv6()) {
  382. uint8_t bytes = d_bits/8;
  383. uint8_t *us = (uint8_t*) &d_network.sin6.sin6_addr.s6_addr;
  384. uint8_t bits = d_bits % 8;
  385. uint8_t mask = (uint8_t) ~(0xFF>>bits);
  386. if (bytes < sizeof(d_network.sin6.sin6_addr.s6_addr)) {
  387. us[bytes] &= mask;
  388. }
  389. for(size_t idx = bytes + 1; idx < sizeof(d_network.sin6.sin6_addr.s6_addr); ++idx) {
  390. us[idx] = 0;
  391. }
  392. }
  393. }
  394. //! Constructor supplies the mask, which cannot be changed
  395. Netmask(const string &mask)
  396. {
  397. pair<string,string> split = splitField(mask,'/');
  398. d_network = makeComboAddress(split.first);
  399. if (!split.second.empty()) {
  400. setBits(static_cast<uint8_t>(pdns_stou(split.second)));
  401. }
  402. else if (d_network.sin4.sin_family == AF_INET) {
  403. setBits(32);
  404. }
  405. else {
  406. setBits(128);
  407. }
  408. }
  409. bool match(const ComboAddress& ip) const
  410. {
  411. return match(&ip);
  412. }
  413. //! If this IP address in socket address matches
  414. bool match(const ComboAddress *ip) const
  415. {
  416. if(d_network.sin4.sin_family != ip->sin4.sin_family) {
  417. return false;
  418. }
  419. if(d_network.sin4.sin_family == AF_INET) {
  420. return match4(htonl((unsigned int)ip->sin4.sin_addr.s_addr));
  421. }
  422. if(d_network.sin6.sin6_family == AF_INET6) {
  423. uint8_t bytes=d_bits/8, n;
  424. const uint8_t *us=(const uint8_t*) &d_network.sin6.sin6_addr.s6_addr;
  425. const uint8_t *them=(const uint8_t*) &ip->sin6.sin6_addr.s6_addr;
  426. for(n=0; n < bytes; ++n) {
  427. if(us[n]!=them[n]) {
  428. return false;
  429. }
  430. }
  431. // still here, now match remaining bits
  432. uint8_t bits= d_bits % 8;
  433. uint8_t mask= (uint8_t) ~(0xFF>>bits);
  434. return((us[n]) == (them[n] & mask));
  435. }
  436. return false;
  437. }
  438. //! If this ASCII IP address matches
  439. bool match(const string &ip) const
  440. {
  441. ComboAddress address=makeComboAddress(ip);
  442. return match(&address);
  443. }
  444. //! If this IP address in native format matches
  445. bool match4(uint32_t ip) const
  446. {
  447. return (ip & d_mask) == (ntohl(d_network.sin4.sin_addr.s_addr));
  448. }
  449. string toString() const
  450. {
  451. return d_network.toString()+"/"+std::to_string((unsigned int)d_bits);
  452. }
  453. string toStringNoMask() const
  454. {
  455. return d_network.toString();
  456. }
  457. const ComboAddress& getNetwork() const
  458. {
  459. return d_network;
  460. }
  461. const ComboAddress& getMaskedNetwork() const
  462. {
  463. return getNetwork();
  464. }
  465. uint8_t getBits() const
  466. {
  467. return d_bits;
  468. }
  469. bool isIPv6() const
  470. {
  471. return d_network.sin6.sin6_family == AF_INET6;
  472. }
  473. bool isIPv4() const
  474. {
  475. return d_network.sin4.sin_family == AF_INET;
  476. }
  477. bool operator<(const Netmask& rhs) const
  478. {
  479. if (empty() && !rhs.empty())
  480. return false;
  481. if (!empty() && rhs.empty())
  482. return true;
  483. if (d_bits > rhs.d_bits)
  484. return true;
  485. if (d_bits < rhs.d_bits)
  486. return false;
  487. return d_network < rhs.d_network;
  488. }
  489. bool operator>(const Netmask& rhs) const
  490. {
  491. return rhs.operator<(*this);
  492. }
  493. bool operator==(const Netmask& rhs) const
  494. {
  495. return tie(d_network, d_bits) == tie(rhs.d_network, rhs.d_bits);
  496. }
  497. bool empty() const
  498. {
  499. return d_network.sin4.sin_family==0;
  500. }
  501. //! Get normalized version of the netmask. This means that all address bits below the network bits are zero.
  502. Netmask getNormalized() const {
  503. return Netmask(getMaskedNetwork(), d_bits);
  504. }
  505. //! Get Netmask for super network of this one (i.e. with fewer network bits)
  506. Netmask getSuper(uint8_t bits) const {
  507. return Netmask(d_network, std::min(d_bits, bits));
  508. }
  509. //! Get the total number of address bits for this netmask (either 32 or 128 depending on IP version)
  510. uint8_t getAddressBits() const
  511. {
  512. return d_network.getBits();
  513. }
  514. /** Get the value of the bit at the provided bit index. When the index >= 0,
  515. the index is relative to the LSB starting at index zero. When the index < 0,
  516. the index is relative to the MSB starting at index -1 and counting down.
  517. When the index points outside the network bits, it always yields zero.
  518. */
  519. bool getBit(int bit) const
  520. {
  521. if (bit < -d_bits)
  522. return false;
  523. if (bit >= 0) {
  524. if(isIPv4()) {
  525. if (bit >= 32 || bit < (32 - d_bits))
  526. return false;
  527. }
  528. if(isIPv6()) {
  529. if (bit >= 128 || bit < (128 - d_bits))
  530. return false;
  531. }
  532. }
  533. return d_network.getBit(bit);
  534. }
  535. private:
  536. ComboAddress d_network;
  537. uint32_t d_mask;
  538. uint8_t d_bits;
  539. };
  540. /** Binary tree map implementation with <Netmask,T> pair.
  541. *
  542. * This is a binary tree implementation for storing attributes for IPv4 and IPv6 prefixes.
  543. * The simplest use case is a NetmaskTree<bool>, as used by NetmaskGroup, which only
  544. * wants to know whether a given IP address is matched by the stored prefixes.
  545. *
  546. * This element is useful for anything that needs to *STORE* prefixes, and *MATCH* IP addresses
  547. * to a *LIST* of *PREFIXES*. Not the other way round.
  548. *
  549. * You can store IPv4 and IPv6 addresses in the same tree; separate payload storage is kept per AFI.
  550. * Network prefixes (Netmasks) are always recorded in normalized fashion, meaning that only
  551. * the network bits are set. This is what is returned in the insert() and lookup() return
  552. * values.
  553. *
  554. * Use swap if you need to move the tree to another NetmaskTree instance, it is WAY faster
  555. * than using copy ctor or assignment operator, since it moves the nodes and tree root to
  556. * new home instead of actually recreating the tree.
  557. *
  558. * Please see NetmaskGroup for an example of a simple use case. Other use cases can be found
  559. * in GeoIPBackend and Sortlist, and in dnsdist.
  560. */
  561. template <typename T>
  562. class NetmaskTree {
  563. public:
  564. class Iterator;
  565. typedef Netmask key_type;
  566. typedef T value_type;
  567. typedef std::pair<const key_type,value_type> node_type;
  568. typedef size_t size_type;
  569. typedef class Iterator iterator;
  570. private:
  571. /** Single node in tree, internal use only.
  572. */
  // Each node stores a normalized (netmask, value) pair in `node`. `assigned`
  // marks nodes that hold an application value; unassigned nodes are e.g. the
  // root or branch points created by fork(). Children are owned through the
  // left/right unique_ptrs; `parent` is a non-owning back pointer (nullptr only
  // for the root). d_bits is the bit depth at which insert() stops comparing
  // against this node's key and descends to a child.
  573. class TreeNode : boost::noncopyable {
  574. public:
  // Root node: default (empty, family 0) key, no value, depth 0.
  575. explicit TreeNode() noexcept :
  576. parent(nullptr), node(), assigned(false), d_bits(0) {
  577. }
  // Leaf for `key`: stores the normalized key and a value-initialized value.
  // d_bits starts at the full address width (32/128) as a leaf has no children.
  // NOTE(review): noexcept — if value_type() throws, std::terminate is called.
  578. explicit TreeNode(const key_type& key) noexcept :
  579. parent(nullptr), node({key.getNormalized(), value_type()}),
  580. assigned(false), d_bits(key.getAddressBits()) {
  581. }
  // This node now branches at its own prefix length; an existing left child,
  // if any, is destroyed by the unique_ptr assignment.
  582. //<! Makes a left leaf node with specified key.
  583. TreeNode* make_left(const key_type& key) {
  584. d_bits = node.first.getBits();
  585. left = make_unique<TreeNode>(key);
  586. left->parent = this;
  587. return left.get();
  588. }
  // Mirror of make_left() for the right (bit == 1) side.
  589. //<! Makes a right leaf node with specified key.
  590. TreeNode* make_right(const key_type& key) {
  591. d_bits = node.first.getBits();
  592. right = make_unique<TreeNode>(key);
  593. right->parent = this;
  594. return right.get();
  595. }
  // Inserts a new node for `key` between this node and its parent. This node
  // becomes the new node's left or right child according to its own bit at
  // position `bits` (0-based from the MSB). Returns the new node.
  // Throws std::logic_error when called on the root or on a detached node.
  596. //<! Splits branch at indicated bit position by inserting key
  597. TreeNode* split(const key_type& key, int bits) {
  598. if (parent == nullptr) {
  599. // not to be called on the root node
  600. throw std::logic_error(
  601. "NetmaskTree::TreeNode::split(): must not be called on root node");
  602. }
  603. // determine reference from parent
  604. unique_ptr<TreeNode>& parent_ref =
  605. (parent->left.get() == this ? parent->left : parent->right);
  606. if (parent_ref.get() != this) {
  607. throw std::logic_error(
  608. "NetmaskTree::TreeNode::split(): parent node reference is invalid");
  609. }
  610. // create new tree node for the new key
  611. TreeNode* new_node = new TreeNode(key);
  612. new_node->d_bits = bits;
  613. // attach the new node under our former parent
  614. unique_ptr<TreeNode> new_child(new_node);
  615. std::swap(parent_ref, new_child); // hereafter new_child points to "this"
  616. new_node->parent = parent;
  617. // attach "this" node below the new node
  618. // (left or right depending on bit)
  619. new_child->parent = new_node;
  // getBit(-1-bits) is the first bit after the new node's `bits`-bit prefix
  620. if (new_child->node.first.getBit(-1-bits)) {
  621. std::swap(new_node->right, new_child);
  622. } else {
  623. std::swap(new_node->left, new_child);
  624. }
  625. return new_node;
  626. }
  // Replaces this node with an unassigned branch node for the common prefix
  // (this key truncated to `bits` bits). This node and a new leaf for `key`
  // become its children, ordered by their bit at position `bits`.
  // Returns the new leaf. Throws std::logic_error like split().
  627. //<! Forks branch for new key at indicated bit position
  628. TreeNode* fork(const key_type& key, int bits) {
  629. if (parent == nullptr) {
  630. // not to be called on the root node
  631. throw std::logic_error(
  632. "NetmaskTree::TreeNode::fork(): must not be called on root node");
  633. }
  634. // determine reference from parent
  635. unique_ptr<TreeNode>& parent_ref =
  636. (parent->left.get() == this ? parent->left : parent->right);
  637. if (parent_ref.get() != this) {
  638. throw std::logic_error(
  639. "NetmaskTree::TreeNode::fork(): parent node reference is invalid");
  640. }
  641. // create new tree node for the branch point
  642. TreeNode* branch_node = new TreeNode(node.first.getSuper(bits));
  643. branch_node->d_bits = bits;
  644. // attach the branch node under our former parent
  645. unique_ptr<TreeNode> new_child1(branch_node);
  646. std::swap(parent_ref, new_child1); // hereafter new_child1 points to "this"
  647. branch_node->parent = parent;
  648. // create second new leaf node for the new key
  649. TreeNode* new_node = new TreeNode(key);
  650. unique_ptr<TreeNode> new_child2(new_node);
  651. // attach the new child nodes below the branch node
  652. // (left or right depending on bit)
  653. new_child1->parent = branch_node;
  654. new_child2->parent = branch_node;
  // the two keys differ at this bit, so one child goes each way
  655. if (new_child1->node.first.getBit(-1-bits)) {
  656. std::swap(branch_node->right, new_child1);
  657. std::swap(branch_node->left, new_child2);
  658. } else {
  659. std::swap(branch_node->right, new_child2);
  660. std::swap(branch_node->left, new_child1);
  661. }
  662. return new_node;
  663. }
  // Returns the left-most descendant of this node (this node if it has no left child).
  664. //<! Traverse left branch depth-first
  665. TreeNode *traverse_l()
  666. {
  667. TreeNode *tnode = this;
  668. while (tnode->left)
  669. tnode = tnode->left.get();
  670. return tnode;
  671. }
  // Returns the in-order successor of this node, or nullptr when there is none.
  672. //<! Traverse tree depth-first and in-order (L-N-R)
  673. TreeNode *traverse_lnr()
  674. {
  675. TreeNode *tnode = this;
  676. // precondition: descended left as deep as possible
  677. if (tnode->right) {
  678. // descend right
  679. tnode = tnode->right.get();
  680. // descend left as deep as possible and return next node
  681. return tnode->traverse_l();
  682. }
  683. // ascend to parent
  684. while (tnode->parent != nullptr) {
  685. TreeNode *prev_child = tnode;
  686. tnode = tnode->parent;
  687. // return this node, but only when we come from the left child branch
  688. if (tnode->left && tnode->left.get() == prev_child)
  689. return tnode;
  690. }
  691. return nullptr;
  692. }
  // Returns the next in-order node with assigned == true, or nullptr.
  693. //<! Traverse only assigned nodes
  694. TreeNode *traverse_lnr_assigned()
  695. {
  696. TreeNode *tnode = traverse_lnr();
  697. while (tnode != nullptr && !tnode->assigned)
  698. tnode = tnode->traverse_lnr();
  699. return tnode;
  700. }
  701. unique_ptr<TreeNode> left;
  702. unique_ptr<TreeNode> right;
  703. TreeNode* parent;
  704. node_type node;
  705. bool assigned; //<! Whether this node is assigned-to by the application
  706. int d_bits; //<! How many bits have been used so far
  707. };
  708. void cleanup_tree(TreeNode* node)
  709. {
  710. // only cleanup this node if it has no children and node not assigned
  711. if (!(node->left || node->right || node->assigned)) {
  712. // get parent node ptr
  713. TreeNode* pparent = node->parent;
  714. // delete this node
  715. if (pparent) {
  716. if (pparent->left.get() == node)
  717. pparent->left.reset();
  718. else
  719. pparent->right.reset();
  720. // now recurse up to the parent
  721. cleanup_tree(pparent);
  722. }
  723. }
  724. }
  725. void copyTree(const NetmaskTree& rhs)
  726. {
  727. TreeNode *node;
  728. node = rhs.d_root.get();
  729. if (node != nullptr)
  730. node = node->traverse_l();
  731. while (node != nullptr) {
  732. if (node->assigned)
  733. insert(node->node.first).second = node->node.second;
  734. node = node->traverse_lnr();
  735. }
  736. }
  737. public:
  738. class Iterator {
  739. public:
  740. typedef node_type value_type;
  741. typedef node_type& reference;
  742. typedef node_type* pointer;
  743. typedef std::forward_iterator_tag iterator_category;
  744. typedef size_type difference_type;
  745. private:
  746. friend class NetmaskTree;
  747. const NetmaskTree* d_tree;
  748. TreeNode* d_node;
  749. Iterator(const NetmaskTree* tree, TreeNode* node): d_tree(tree), d_node(node) {
  750. }
  751. public:
  752. Iterator(): d_tree(nullptr), d_node(nullptr) {}
  753. Iterator& operator++() // prefix
  754. {
  755. if (d_node == nullptr) {
  756. throw std::logic_error(
  757. "NetmaskTree::Iterator::operator++: iterator is invalid");
  758. }
  759. d_node = d_node->traverse_lnr_assigned();
  760. return *this;
  761. }
  762. Iterator operator++(int) // postfix
  763. {
  764. Iterator tmp(*this);
  765. operator++();
  766. return tmp;
  767. }
  768. reference operator*()
  769. {
  770. if (d_node == nullptr) {
  771. throw std::logic_error(
  772. "NetmaskTree::Iterator::operator*: iterator is invalid");
  773. }
  774. return d_node->node;
  775. }
  776. pointer operator->()
  777. {
  778. if (d_node == nullptr) {
  779. throw std::logic_error(
  780. "NetmaskTree::Iterator::operator->: iterator is invalid");
  781. }
  782. return &d_node->node;
  783. }
  784. bool operator==(const Iterator& rhs)
  785. {
  786. return (d_tree == rhs.d_tree && d_node == rhs.d_node);
  787. }
  788. bool operator!=(const Iterator& rhs)
  789. {
  790. return !(*this == rhs);
  791. }
  792. };
  793. public:
  794. NetmaskTree() noexcept: d_root(new TreeNode()), d_left(nullptr), d_size(0) {
  795. }
  796. NetmaskTree(const NetmaskTree& rhs): d_root(new TreeNode()), d_left(nullptr), d_size(0) {
  797. copyTree(rhs);
  798. }
  799. NetmaskTree& operator=(const NetmaskTree& rhs) {
  800. clear();
  801. copyTree(rhs);
  802. return *this;
  803. }
  804. const iterator begin() const {
  805. return Iterator(this, d_left);
  806. }
  807. const iterator end() const {
  808. return Iterator(this, nullptr);
  809. }
  810. iterator begin() {
  811. return Iterator(this, d_left);
  812. }
  813. iterator end() {
  814. return Iterator(this, nullptr);
  815. }
  816. node_type& insert(const string &mask) {
  817. return insert(key_type(mask));
  818. }
  // Find or create the node for `key` and return its (netmask, value) pair.
  // A key not yet assigned becomes assigned and increments d_size; nodes created
  // here start with value_type(). For an already assigned key, the stored value
  // is returned untouched. Throws NetmaskException for a key that is neither
  // IPv4 nor IPv6, and std::logic_error if the cached left-most node (d_left)
  // turns out to be stale.
  819. //<! Creates new value-pair in tree and returns it.
  820. node_type& insert(const key_type& key) {
  821. TreeNode* node;
  // is_left stays true only while the target may still be the in-order first
  // assigned node of the whole tree; it decides whether d_left is updated.
  822. bool is_left = true;
  823. // we turn left on IPv4 and right on IPv6
  824. if (key.isIPv4()) {
  825. node = d_root->left.get();
  826. if (node == nullptr) {
  // first IPv4 entry: becomes the whole left subtree and the left-most node
  827. node = new TreeNode(key);
  828. node->assigned = true;
  829. node->parent = d_root.get();
  830. d_root->left = unique_ptr<TreeNode>(node);
  831. d_size++;
  832. d_left = node;
  833. return node->node;
  834. }
  835. } else if (key.isIPv6()) {
  836. node = d_root->right.get();
  837. if (node == nullptr) {
  // first IPv6 entry: left-most only if there is no IPv4 subtree
  838. node = new TreeNode(key);
  839. node->assigned = true;
  840. node->parent = d_root.get();
  841. d_root->right = unique_ptr<TreeNode>(node);
  842. d_size++;
  843. if (!d_root->left)
  844. d_left = node;
  845. return node->node;
  846. }
  // any IPv4 subtree precedes every IPv6 node in L-N-R order
  847. if (d_root->left)
  848. is_left = false;
  849. } else
  850. throw NetmaskException("invalid address family");
  851. // we turn left on 0 and right on 1
  // bits counts how many leading key bits have been matched along the path
  852. int bits = 0;
  853. for(; bits < key.getBits(); bits++) {
  // vall: the key's bit at this depth (0-based from the MSB)
  854. bool vall = key.getBit(-1-bits);
  855. if (bits >= node->d_bits) {
  856. // the end of the current node is reached; continue with the next
  857. if (vall) {
  // going right passes this node (if assigned) and its left subtree
  858. if (node->left || node->assigned)
  859. is_left = false;
  860. if (!node->right) {
  861. // the right branch doesn't exist yet; attach our key here
  862. node = node->make_right(key);
  863. break;
  864. }
  865. node = node->right.get();
  866. } else {
  867. if (!node->left) {
  868. // the left branch doesn't exist yet; attach our key here
  869. node = node->make_left(key);
  870. break;
  871. }
  872. node = node->left.get();
  873. }
  874. continue;
  875. }
  876. if (bits >= node->node.first.getBits()) {
  877. // the matching branch ends here, yet the key netmask has more bits; add a
  878. // child node below the existing branch leaf.
  879. if (vall) {
  880. if (node->assigned)
  881. is_left = false;
  882. node = node->make_right(key);
  883. } else {
  884. node = node->make_left(key);
  885. }
  886. break;
  887. }
  // valr: the existing node's bit at the same depth
  888. bool valr = node->node.first.getBit(-1-bits);
  889. if (vall != valr) {
  890. if (vall)
  891. is_left = false;
  892. // the branch matches just upto this point, yet continues in a different
  893. // direction; fork the branch.
  894. node = node->fork(key, bits);
  895. break;
  896. }
  897. }
  898. if (node->node.first.getBits() > key.getBits()) {
  899. // key is a super-network of the matching node; split the branch and
  900. // insert a node for the key above the matching node.
  901. node = node->split(key, key.getBits());
  902. }
  // a left child precedes this node in L-N-R order, so it cannot be left-most
  903. if (node->left)
  904. is_left = false;
  905. node_type& value = node->node;
  906. if (!node->assigned) {
  907. // only increment size if not assigned before
  908. d_size++;
  909. // update the pointer to the left-most tree node
  910. if (is_left)
  911. d_left = node;
  912. node->assigned = true;
  913. } else {
  914. // tree node exists for this value
  915. if (is_left && d_left != node) {
  916. throw std::logic_error(
  917. "NetmaskTree::insert(): lost track of left-most node in tree");
  918. }
  919. }
  920. return value;
  921. }
  922. //<! Creates or updates value
  923. void insert_or_assign(const key_type& mask, const value_type& value) {
  924. insert(mask).second = value;
  925. }
  926. void insert_or_assign(const string& mask, const value_type& value) {
  927. insert(key_type(mask)).second = value;
  928. }
  929. //<! check if given key is present in TreeMap
  930. bool has_key(const key_type& key) const {
  931. const node_type *ptr = lookup(key);
  932. return ptr && ptr->first == key;
  933. }
  934. //<! Returns "best match" for key_type, which might not be value
  935. const node_type* lookup(const key_type& value) const {
  936. return lookup(value.getNetwork(), value.getBits());
  937. }
  938. //<! Perform best match lookup for value, using at most max_bits
  939. const node_type* lookup(const ComboAddress& value, int max_bits = 128) const {
  940. TreeNode *node = nullptr;
  941. uint8_t addr_bits = value.getBits();
  942. if (max_bits < 0 || max_bits > addr_bits)
  943. max_bits = addr_bits;
  944. if (value.isIPv4())
  945. node = d_root->left.get();
  946. else if (value.isIPv6())
  947. node = d_root->right.get();
  948. else
  949. throw NetmaskException("invalid address family");
  950. if (node == nullptr) return nullptr;
  951. node_type *ret = nullptr;
  952. int bits = 0;
  953. for(; bits < max_bits; bits++) {
  954. bool vall = value.getBit(-1-bits);
  955. if (bits >= node->d_bits) {
  956. // the end of the current node is reached; continue with the next
  957. // (we keep track of last assigned node)
  958. if (node->assigned && bits == node->node.first.getBits())
  959. ret = &node->node;
  960. if (vall) {
  961. if (!node->right)
  962. break;
  963. node = node->right.get();
  964. } else {
  965. if (!node->left)
  966. break;
  967. node = node->left.get();
  968. }
  969. continue;
  970. }
  971. if (bits >= node->node.first.getBits()) {
  972. // the matching branch ends here
  973. break;
  974. }
  975. bool valr = node->node.first.getBit(-1-bits);
  976. if (vall != valr) {
  977. // the branch matches just upto this point, yet continues in a different
  978. // direction
  979. break;
  980. }
  981. }
  982. // needed if we did not find one in loop
  983. if (node->assigned && bits == node->node.first.getBits())
  984. ret = &node->node;
  985. // this can be nullptr.
  986. return ret;
  987. }
  988. //<! Removes key from TreeMap.
  989. void erase(const key_type& key) {
  990. TreeNode *node = nullptr;
  991. if (key.isIPv4())
  992. node = d_root->left.get();
  993. else if (key.isIPv6())
  994. node = d_root->right.get();
  995. else
  996. throw NetmaskException("invalid address family");
  997. // no tree, no value
  998. if (node == nullptr) return;
  999. int bits = 0;
  1000. for(; node && bits < key.getBits(); bits++) {
  1001. bool vall = key.getBit(-1-bits);
  1002. if (bits >= node->d_bits) {
  1003. // the end of the current node is reached; continue with the next
  1004. if (vall) {
  1005. node = node->right.get();
  1006. } else {
  1007. node = node->left.get();
  1008. }
  1009. continue;
  1010. }
  1011. if (bits >= node->node.first.getBits()) {
  1012. // the matching branch ends here
  1013. if (key.getBits() != node->node.first.getBits())
  1014. node = nullptr;
  1015. break;
  1016. }
  1017. bool valr = node->node.first.getBit(-1-bits);
  1018. if (vall != valr) {
  1019. // the branch matches just upto this point, yet continues in a different
  1020. // direction
  1021. node = nullptr;
  1022. break;
  1023. }
  1024. }
  1025. if (node) {
  1026. if (d_size == 0) {
  1027. throw std::logic_error(
  1028. "NetmaskTree::erase(): size of tree is zero before erase");
  1029. }
  1030. d_size--;
  1031. node->assigned = false;
  1032. node->node.second = value_type();
  1033. if (node == d_left)
  1034. d_left = d_left->traverse_lnr_assigned();
  1035. cleanup_tree(node);
  1036. }
  1037. }
  1038. void erase(const string& key) {
  1039. erase(key_type(key));
  1040. }
  1041. //<! checks whether the container is empty.
  1042. bool empty() const {
  1043. return (d_size == 0);
  1044. }
  1045. //<! returns the number of elements
  1046. size_type size() const {
  1047. return d_size;
  1048. }
  1049. //<! See if given ComboAddress matches any prefix
  1050. bool match(const ComboAddress& value) const {
  1051. return (lookup(value) != nullptr);
  1052. }
  1053. bool match(const std::string& value) const {
  1054. return match(ComboAddress(value));
  1055. }
  1056. //<! Clean out the tree
  1057. void clear() {
  1058. d_root.reset(new TreeNode());
  1059. d_left = nullptr;
  1060. d_size = 0;
  1061. }
  1062. //<! swaps the contents with another NetmaskTree
  1063. void swap(NetmaskTree& rhs) {
  1064. std::swap(d_root, rhs.d_root);
  1065. std::swap(d_left, rhs.d_left);
  1066. std::swap(d_size, rhs.d_size);
  1067. }
  1068. private:
  unique_ptr<TreeNode> d_root; //<! Root of our tree; its left child is the IPv4 subtree, its right child the IPv6 subtree
  //<! Left-most node holding a value; nullptr after clear(). Updated by
  //<! insert() and erase().
  TreeNode *d_left;
  //<! Number of nodes holding a value (the number of stored entries); nodes
  //<! without an assigned value are not counted.
  size_type d_size;
  1072. };
  1073. /** This class represents a group of supplemental Netmask classes. An IP address matchs
  1074. if it is matched by zero or more of the Netmask classes within.
  1075. */
  1076. class NetmaskGroup
  1077. {
  1078. public:
  1079. NetmaskGroup() noexcept {
  1080. }
  1081. //! If this IP address is matched by any of the classes within
  1082. bool match(const ComboAddress *ip) const
  1083. {
  1084. const auto &ret = tree.lookup(*ip);
  1085. if(ret) return ret->second;
  1086. return false;
  1087. }
  1088. bool match(const ComboAddress& ip) const
  1089. {
  1090. return match(&ip);
  1091. }
  1092. bool lookup(const ComboAddress* ip, Netmask* nmp) const
  1093. {
  1094. const auto &ret = tree.lookup(*ip);
  1095. if (ret) {
  1096. if (nmp != nullptr)
  1097. *nmp = ret->first;
  1098. return ret->second;
  1099. }
  1100. return false;
  1101. }
  1102. bool lookup(const ComboAddress& ip, Netmask* nmp) const
  1103. {
  1104. return lookup(&ip, nmp);
  1105. }
  1106. //! Add this string to the list of possible matches
  1107. void addMask(const string &ip, bool positive=true)
  1108. {
  1109. if(!ip.empty() && ip[0] == '!') {
  1110. addMask(Netmask(ip.substr(1)), false);
  1111. } else {
  1112. addMask(Netmask(ip), positive);
  1113. }
  1114. }
  1115. //! Add this Netmask to the list of possible matches
  1116. void addMask(const Netmask& nm, bool positive=true)
  1117. {
  1118. tree.insert(nm).second=positive;
  1119. }
  1120. //! Delete this Netmask from the list of possible matches
  1121. void deleteMask(const Netmask& nm)
  1122. {
  1123. tree.erase(nm);
  1124. }
  1125. void deleteMask(const std::string& ip)
  1126. {
  1127. if (!ip.empty())
  1128. deleteMask(Netmask(ip));
  1129. }
  1130. void clear()
  1131. {
  1132. tree.clear();
  1133. }
  1134. bool empty() const
  1135. {
  1136. return tree.empty();
  1137. }
  1138. size_t size() const
  1139. {
  1140. return tree.size();
  1141. }
  1142. string toString() const
  1143. {
  1144. ostringstream str;
  1145. for(auto iter = tree.begin(); iter != tree.end(); ++iter) {
  1146. if(iter != tree.begin())
  1147. str <<", ";
  1148. if(!(iter->second))
  1149. str<<"!";
  1150. str<<iter->first.toString();
  1151. }
  1152. return str.str();
  1153. }
  1154. void toStringVector(vector<string>* vec) const
  1155. {
  1156. for(auto iter = tree.begin(); iter != tree.end(); ++iter) {
  1157. vec->push_back((iter->second ? "" : "!") + iter->first.toString());
  1158. }
  1159. }
  1160. void toMasks(const string &ips)
  1161. {
  1162. vector<string> parts;
  1163. stringtok(parts, ips, ", \t");
  1164. for (vector<string>::const_iterator iter = parts.begin(); iter != parts.end(); ++iter)
  1165. addMask(*iter);
  1166. }
  1167. private:
  1168. NetmaskTree<bool> tree;
  1169. };
  1170. struct SComboAddress
  1171. {
  1172. SComboAddress(const ComboAddress& orig) : ca(orig) {}
  1173. ComboAddress ca;
  1174. bool operator<(const SComboAddress& rhs) const
  1175. {
  1176. return ComboAddress::addressOnlyLessThan()(ca, rhs.ca);
  1177. }
  1178. operator const ComboAddress&()
  1179. {
  1180. return ca;
  1181. }
  1182. };
  /** runtime_error subclass carrying a message; default-constructed instances
      report "Network Error". */
  class NetworkError : public runtime_error
  {
  public:
    // No default argument on this overload: with defaults on both
    // constructors, NetworkError() was ambiguous and did not compile.
    NetworkError(const string& why) : runtime_error(why)
    {}
    NetworkError(const char *why="Network Error") : runtime_error(why)
    {}
  };
  // Socket helper API. Only declarations are visible here; the notes below
  // describe what the signatures show, anything beyond that is marked as
  // unconfirmed.
  // NOTE(review): the S* helpers look like wrappers around the POSIX calls of
  // similar names (socket/connect/bind/accept/listen/setsockopt) — confirm
  // their error reporting (throwing vs. returning a value) in the
  // implementation.
  int SSocket(int family, int type, int flags);
  int SConnect(int sockfd, const ComboAddress& remote);
  /* tries to connect to remote for a maximum of timeout seconds.
     sockfd should be set to non-blocking beforehand.
     returns 0 on success (the socket is writable), throws a
     runtime_error otherwise */
  int SConnectWithTimeout(int sockfd, const ComboAddress& remote, int timeout);
  int SBind(int sockfd, const ComboAddress& local);
  // remote is an output parameter (non-const reference) — presumably filled
  // with the peer address of the accepted connection; TODO confirm.
  int SAccept(int sockfd, ComboAddress& remote);
  int SListen(int sockfd, int limit);
  int SSetsockopt(int sockfd, int level, int opname, int value);
  void setSocketIgnorePMTU(int sockfd);
  // NOTE(review): the bool result presumably reports whether the option could
  // be enabled on this platform — confirm in the implementation.
  bool setReusePort(int sockfd);
  // GEN_IP_PKTINFO: IP_PKTINFO when the platform defines it, otherwise
  // IP_RECVDSTADDR. When neither exists it stays undefined, so code using it
  // should be guarded with #ifdef GEN_IP_PKTINFO.
  #if defined(IP_PKTINFO)
  #define GEN_IP_PKTINFO IP_PKTINFO
  #elif defined(IP_RECVDSTADDR)
  #define GEN_IP_PKTINFO IP_RECVDSTADDR
  #endif
  bool IsAnyAddress(const ComboAddress& addr);
  // Output through the pointer parameters (destination / tv); the bool result
  // presumably signals whether the information was found in the control
  // messages of msgh — TODO confirm.
  bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destination);
  bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv);
  void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, cmsgbuf_aligned* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr);
  ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to);
  // dest and local are pointers, so presumably optional (nullptr allowed) —
  // TODO confirm in the implementation.
  size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int flags);
  /* requires a non-blocking, connected TCP socket */
  bool isTCPSocketUsable(int sock);
  // Explicit instantiation declaration: files including this header do not
  // instantiate NetmaskTree<bool> implicitly; the matching explicit
  // instantiation definition must exist in exactly one translation unit.
  extern template class NetmaskTree<bool>;