1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.util;
20
21 import java.net.InetAddress;
22 import java.util.AbstractSet;
23 import java.util.HashMap;
24 import java.util.Iterator;
25 import java.util.Map;
26 import java.util.Set;
27 import java.util.function.Predicate;
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50 public class InetAddressSet extends AbstractSet<String> implements Set<String>, Predicate<InetAddress>
51 {
52 private Map<String,InetPattern> _patterns = new HashMap<>();
53
54 @Override
55 public boolean add(String pattern)
56 {
57 return _patterns.put(pattern,newInetRange(pattern))==null;
58 }
59
60 protected InetPattern newInetRange(String pattern)
61 {
62 if (pattern==null)
63 return null;
64
65 int slash = pattern.lastIndexOf('/');
66 int dash = pattern.lastIndexOf('-');
67 try
68 {
69 if (slash>=0)
70 return new CidrInetRange(pattern,InetAddress.getByName(pattern.substring(0,slash).trim()),StringUtil.toInt(pattern,slash+1));
71
72 if (dash>=0)
73 return new MinMaxInetRange(pattern,InetAddress.getByName(pattern.substring(0,dash).trim()),InetAddress.getByName(pattern.substring(dash+1).trim()));
74
75 return new SingletonInetRange(pattern,InetAddress.getByName(pattern));
76 }
77 catch(Exception e)
78 {
79 try
80 {
81 if (slash<0 && dash>0)
82 return new LegacyInetRange(pattern);
83 }
84 catch(Exception e2)
85 {
86 e.addSuppressed(e2);
87 }
88 throw new IllegalArgumentException("Bad pattern: "+pattern,e);
89 }
90 }
91
92 @Override
93 public boolean remove(Object pattern)
94 {
95 return _patterns.remove(pattern)!=null;
96 }
97
98 @Override
99 public Iterator<String> iterator()
100 {
101 return _patterns.keySet().iterator();
102 }
103
104 @Override
105 public int size()
106 {
107 return _patterns.size();
108 }
109
110
111 @Override
112 public boolean test(InetAddress address)
113 {
114 if (address==null)
115 return false;
116 byte[] raw = address.getAddress();
117 for (InetPattern pattern : _patterns.values())
118 if (pattern.test(address,raw))
119 return true;
120 return false;
121 }
122
123 abstract static class InetPattern
124 {
125 final String _pattern;
126
127 InetPattern(String pattern)
128 {
129 _pattern=pattern;
130 }
131
132 abstract boolean test(InetAddress address, byte[] raw);
133
134 @Override
135 public String toString()
136 {
137 return _pattern;
138 }
139 }
140
141 static class SingletonInetRange extends InetPattern
142 {
143 final InetAddress _address;
144 public SingletonInetRange(String pattern, InetAddress address)
145 {
146 super(pattern);
147 _address=address;
148 }
149
150 public boolean test(InetAddress address, byte[] raw)
151 {
152 return _address.equals(address);
153 }
154 }
155
156
157 static class MinMaxInetRange extends InetPattern
158 {
159 final int[] _min;
160 final int[] _max;
161
162 public MinMaxInetRange(String pattern, InetAddress min, InetAddress max)
163 {
164 super(pattern);
165
166 byte[] raw_min = min.getAddress();
167 byte[] raw_max = max.getAddress();
168 if (raw_min.length!=raw_max.length)
169 throw new IllegalArgumentException("Cannot mix IPv4 and IPv6: "+pattern);
170
171 if (raw_min.length==4)
172 {
173
174 int count=0;
175 for (char c:pattern.toCharArray())
176 if (c=='.')
177 count++;
178 if (count!=6)
179 throw new IllegalArgumentException("Legacy pattern: "+pattern);
180 }
181
182 _min = new int[raw_min.length];
183 _max = new int[raw_min.length];
184
185 for (int i=0;i<_min.length;i++)
186 {
187 _min[i]=0xff&raw_min[i];
188 _max[i]=0xff&raw_max[i];
189 }
190
191 for (int i=0;i<_min.length;i++)
192 {
193 if (_min[i]>_max[i])
194 throw new IllegalArgumentException("min is greater than max: "+pattern);
195 if (_min[i]<_max[i])
196 break;
197 }
198 }
199
200 public boolean test(InetAddress item, byte[] raw)
201 {
202 if (raw.length!=_min.length)
203 return false;
204
205 boolean min_ok = false;
206 boolean max_ok = false;
207
208 for (int i=0;i<_min.length;i++)
209 {
210 int r = 0xff&raw[i];
211 if (!min_ok)
212 {
213 if (r<_min[i])
214 return false;
215 if (r>_min[i])
216 min_ok=true;
217 }
218 if (!max_ok)
219 {
220 if (r>_max[i])
221 return false;
222 if (r<_max[i])
223 max_ok=true;
224 }
225
226 if (min_ok && max_ok)
227 break;
228 }
229
230 return true;
231 }
232 }
233
234
235 static class CidrInetRange extends InetPattern
236 {
237 final byte[] _raw;
238 final int _octets;
239 final int _mask;
240 final int _masked;
241
242 public CidrInetRange(String pattern, InetAddress address, int cidr)
243 {
244 super(pattern);
245 _raw = address.getAddress();
246 _octets = cidr/8;
247 _mask = 0xff&(0xff<<(8-cidr%8));
248 _masked = _mask==0?0:_raw[_octets]&_mask;
249
250 if (cidr>(_raw.length*8))
251 throw new IllegalArgumentException("CIDR too large: "+pattern);
252
253 if (_mask!=0 && _raw[_octets]!=_masked)
254 throw new IllegalArgumentException("CIDR bits non zero: "+pattern);
255
256 for (int o=_octets+(_mask==0?0:1);o<_raw.length;o++)
257 if (_raw[o]!=0)
258 throw new IllegalArgumentException("CIDR bits non zero: "+pattern);
259 }
260
261 public boolean test(InetAddress item, byte[] raw)
262 {
263 if (raw.length!=_raw.length)
264 return false;
265
266 for (int o=0;o<_octets;o++)
267 if (_raw[o]!=raw[o])
268 return false;
269
270 if (_mask!=0 && (raw[_octets]&_mask)!=_masked)
271 return false;
272 return true;
273 }
274 }
275
276 static class LegacyInetRange extends InetPattern
277 {
278 int[] _min = new int[4];
279 int[] _max = new int[4];
280
281 public LegacyInetRange(String pattern)
282 {
283 super(pattern);
284
285 String[] parts = pattern.split("\\.");
286 if (parts.length!=4)
287 throw new IllegalArgumentException("Bad legacy pattern: "+pattern);
288
289 for (int i=0;i<4;i++)
290 {
291 String part=parts[i].trim();
292 int dash = part.indexOf('-');
293 if (dash<0)
294 _min[i]=_max[i]=Integer.parseInt(part);
295 else
296 {
297 _min[i] = (dash==0)?0:StringUtil.toInt(part,0);
298 _max[i] = (dash==part.length()-1)?255:StringUtil.toInt(part,dash+1);
299 }
300
301 if (_min[i]<0 || _min[i]>_max[i] || _max[i]>255)
302 throw new IllegalArgumentException("Bad legacy pattern: "+pattern);
303 }
304 }
305
306 public boolean test(InetAddress item, byte[] raw)
307 {
308 if (raw.length!=4)
309 return false;
310
311 for (int i=0;i<4;i++)
312 if ((0xff&raw[i])<_min[i] || (0xff&raw[i])>_max[i])
313 return false;
314
315 return true;
316 }
317 }
318 }