1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.eclipse.jetty.websocket.jsr356.server.deploy;
20
21 import java.util.HashSet;
22 import java.util.Set;
23
24 import javax.servlet.ServletContainerInitializer;
25 import javax.servlet.ServletContext;
26 import javax.servlet.ServletException;
27 import javax.servlet.annotation.HandlesTypes;
28 import javax.websocket.DeploymentException;
29 import javax.websocket.Endpoint;
30 import javax.websocket.server.ServerApplicationConfig;
31 import javax.websocket.server.ServerEndpoint;
32 import javax.websocket.server.ServerEndpointConfig;
33
34 import org.eclipse.jetty.server.handler.ContextHandler;
35 import org.eclipse.jetty.servlet.ServletContextHandler;
36 import org.eclipse.jetty.util.DecoratedObjectFactory;
37 import org.eclipse.jetty.util.TypeUtil;
38 import org.eclipse.jetty.util.log.Log;
39 import org.eclipse.jetty.util.log.Logger;
40 import org.eclipse.jetty.websocket.jsr356.server.ServerContainer;
41 import org.eclipse.jetty.websocket.server.WebSocketUpgradeFilter;
42
43 @HandlesTypes(
44 { ServerApplicationConfig.class, ServerEndpoint.class, Endpoint.class })
45 public class WebSocketServerContainerInitializer implements ServletContainerInitializer
46 {
47 public static final String ENABLE_KEY = "org.eclipse.jetty.websocket.jsr356";
48 private static final Logger LOG = Log.getLogger(WebSocketServerContainerInitializer.class);
49
50
51
52
53
54
55
56
57
58
59 public static ServerContainer configureContext(ServletContextHandler context) throws ServletException
60 {
61
62 WebSocketUpgradeFilter filter = WebSocketUpgradeFilter.configureContext(context);
63
64
65 ServerContainer jettyContainer = new ServerContainer(filter,filter.getFactory(),context.getServer().getThreadPool());
66 context.addBean(jettyContainer);
67
68
69 context.setAttribute(javax.websocket.server.ServerContainer.class.getName(),jettyContainer);
70
71 return jettyContainer;
72 }
73
74
75
76
77
78
79
80
81
82
83
84 public static ServerContainer configureContext(ServletContext context, ServletContextHandler jettyContext) throws ServletException
85 {
86
87 WebSocketUpgradeFilter filter = WebSocketUpgradeFilter.configureContext(context);
88
89
90 ServerContainer jettyContainer = new ServerContainer(filter,filter.getFactory(),jettyContext.getServer().getThreadPool());
91 jettyContext.addBean(jettyContainer);
92
93
94 context.setAttribute(javax.websocket.server.ServerContainer.class.getName(),jettyContainer);
95
96 return jettyContainer;
97 }
98
99 private boolean isEnabled(Set<Class<?>> c, ServletContext context)
100 {
101
102 String cp = context.getInitParameter(ENABLE_KEY);
103 if(TypeUtil.isTrue(cp))
104 {
105
106 return true;
107 }
108
109 if(TypeUtil.isFalse(cp))
110 {
111
112 LOG.warn("JSR-356 support disabled via parameter on context {} - {}",context.getContextPath(),context);
113 return false;
114 }
115
116
117 Object enable = context.getAttribute(ENABLE_KEY);
118
119 if(TypeUtil.isTrue(enable))
120 {
121
122 return true;
123 }
124
125 if (TypeUtil.isFalse(enable))
126 {
127
128 LOG.warn("JSR-356 support disabled via attribute on context {} - {}",context.getContextPath(),context);
129 return false;
130 }
131
132
133 if (c.isEmpty())
134 {
135 if (LOG.isDebugEnabled())
136 {
137 LOG.debug("No JSR-356 annotations or interfaces discovered. JSR-356 support disabled",context.getContextPath(),context);
138 }
139 return false;
140 }
141
142 return true;
143 }
144
145 @Override
146 public void onStartup(Set<Class<?>> c, ServletContext context) throws ServletException
147 {
148 if(!isEnabled(c,context))
149 {
150 return;
151 }
152
153 ContextHandler handler = ContextHandler.getContextHandler(context);
154
155 if (handler == null)
156 {
157 throw new ServletException("Not running on Jetty, JSR-356 support unavailable");
158 }
159
160 if (!(handler instanceof ServletContextHandler))
161 {
162 throw new ServletException("Not running in Jetty ServletContextHandler, JSR-356 support unavailable");
163 }
164
165 ServletContextHandler jettyContext = (ServletContextHandler)handler;
166
167 ClassLoader old = Thread.currentThread().getContextClassLoader();
168 try
169 {
170 Thread.currentThread().setContextClassLoader(context.getClassLoader());
171
172
173 ServerContainer jettyContainer = configureContext(context,jettyContext);
174
175
176 context.setAttribute(javax.websocket.server.ServerContainer.class.getName(),jettyContainer);
177
178
179
180 DecoratedObjectFactory instantiator = (DecoratedObjectFactory)context.getAttribute(DecoratedObjectFactory.ATTR);
181 if (instantiator == null)
182 {
183 LOG.info("Using WebSocket local DecoratedObjectFactory - none found in ServletContext");
184 instantiator = new DecoratedObjectFactory();
185 }
186
187 if (LOG.isDebugEnabled())
188 {
189 LOG.debug("Found {} classes",c.size());
190 }
191
192
193 Set<Class<? extends Endpoint>> discoveredExtendedEndpoints = new HashSet<>();
194 Set<Class<?>> discoveredAnnotatedEndpoints = new HashSet<>();
195 Set<Class<? extends ServerApplicationConfig>> serverAppConfigs = new HashSet<>();
196
197 filterClasses(c,discoveredExtendedEndpoints,discoveredAnnotatedEndpoints,serverAppConfigs);
198
199 if (LOG.isDebugEnabled())
200 {
201 LOG.debug("Discovered {} extends Endpoint classes",discoveredExtendedEndpoints.size());
202 LOG.debug("Discovered {} @ServerEndpoint classes",discoveredAnnotatedEndpoints.size());
203 LOG.debug("Discovered {} ServerApplicationConfig classes",serverAppConfigs.size());
204 }
205
206
207 boolean wasFiltered = false;
208 Set<ServerEndpointConfig> deployableExtendedEndpointConfigs = new HashSet<>();
209 Set<Class<?>> deployableAnnotatedEndpoints = new HashSet<>();
210
211 for (Class<? extends ServerApplicationConfig> clazz : serverAppConfigs)
212 {
213 if (LOG.isDebugEnabled())
214 {
215 LOG.debug("Found ServerApplicationConfig: {}",clazz);
216 }
217 try
218 {
219 ServerApplicationConfig config = clazz.newInstance();
220
221 Set<ServerEndpointConfig> seconfigs = config.getEndpointConfigs(discoveredExtendedEndpoints);
222 if (seconfigs != null)
223 {
224 wasFiltered = true;
225 deployableExtendedEndpointConfigs.addAll(seconfigs);
226 }
227
228 Set<Class<?>> annotatedClasses = config.getAnnotatedEndpointClasses(discoveredAnnotatedEndpoints);
229 if (annotatedClasses != null)
230 {
231 wasFiltered = true;
232 deployableAnnotatedEndpoints.addAll(annotatedClasses);
233 }
234 }
235 catch (InstantiationException | IllegalAccessException e)
236 {
237 throw new ServletException("Unable to instantiate: " + clazz.getName(),e);
238 }
239 }
240
241
242 if (!wasFiltered)
243 {
244 deployableAnnotatedEndpoints.addAll(discoveredAnnotatedEndpoints);
245
246 deployableExtendedEndpointConfigs = new HashSet<>();
247 }
248
249 if (LOG.isDebugEnabled())
250 {
251 LOG.debug("Deploying {} ServerEndpointConfig(s)",deployableExtendedEndpointConfigs.size());
252 }
253
254 for (ServerEndpointConfig config : deployableExtendedEndpointConfigs)
255 {
256 try
257 {
258 jettyContainer.addEndpoint(config);
259 }
260 catch (DeploymentException e)
261 {
262 throw new ServletException(e);
263 }
264 }
265
266 if (LOG.isDebugEnabled())
267 {
268 LOG.debug("Deploying {} @ServerEndpoint(s)",deployableAnnotatedEndpoints.size());
269 }
270 for (Class<?> annotatedClass : deployableAnnotatedEndpoints)
271 {
272 try
273 {
274 jettyContainer.addEndpoint(annotatedClass);
275 }
276 catch (DeploymentException e)
277 {
278 throw new ServletException(e);
279 }
280 }
281 } finally {
282 Thread.currentThread().setContextClassLoader(old);
283 }
284 }
285
286 @SuppressWarnings("unchecked")
287 private void filterClasses(Set<Class<?>> c, Set<Class<? extends Endpoint>> discoveredExtendedEndpoints, Set<Class<?>> discoveredAnnotatedEndpoints,
288 Set<Class<? extends ServerApplicationConfig>> serverAppConfigs)
289 {
290 for (Class<?> clazz : c)
291 {
292 if (ServerApplicationConfig.class.isAssignableFrom(clazz))
293 {
294 serverAppConfigs.add((Class<? extends ServerApplicationConfig>)clazz);
295 }
296
297 if (Endpoint.class.isAssignableFrom(clazz))
298 {
299 discoveredExtendedEndpoints.add((Class<? extends Endpoint>)clazz);
300 }
301
302 ServerEndpoint endpoint = clazz.getAnnotation(ServerEndpoint.class);
303
304 if (endpoint != null)
305 {
306 discoveredAnnotatedEndpoints.add(clazz);
307 }
308 }
309 }
310 }