View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.websocket.common.extensions.compress;
20  
21  import java.nio.ByteBuffer;
22  import java.util.Queue;
23  import java.util.zip.DataFormatException;
24  import java.util.zip.Deflater;
25  import java.util.zip.Inflater;
26  import java.util.zip.ZipException;
27  
28  import org.eclipse.jetty.util.BufferUtil;
29  import org.eclipse.jetty.util.ConcurrentArrayQueue;
30  import org.eclipse.jetty.util.IteratingCallback;
31  import org.eclipse.jetty.util.log.Log;
32  import org.eclipse.jetty.util.log.Logger;
33  import org.eclipse.jetty.websocket.api.BadPayloadException;
34  import org.eclipse.jetty.websocket.api.BatchMode;
35  import org.eclipse.jetty.websocket.api.WriteCallback;
36  import org.eclipse.jetty.websocket.api.extensions.Frame;
37  import org.eclipse.jetty.websocket.common.OpCode;
38  import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
39  import org.eclipse.jetty.websocket.common.frames.DataFrame;
40  
41  public abstract class CompressExtension extends AbstractExtension
42  {
43      protected static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF};
44      private static final Logger LOG = Log.getLogger(CompressExtension.class);
45      
46      /** Never drop tail bytes 0000FFFF, from any frame type */
47      protected static final int TAIL_DROP_NEVER = 0;
48      /** Always drop tail bytes 0000FFFF, from all frame types */
49      protected static final int TAIL_DROP_ALWAYS = 1;
50      /** Only drop tail bytes 0000FFFF, from fin==true frames */
51      protected static final int TAIL_DROP_FIN_ONLY = 2;
52  
53      /** Always set RSV flag, on all frame types */
54      protected static final int RSV_USE_ALWAYS = 0;
55      /** 
56       * Only set RSV flag on first frame in multi-frame messages.
57       * <p>
58       * Note: this automatically means no-continuation frames have
59       * the RSV bit set 
60       */
61      protected static final int RSV_USE_ONLY_FIRST = 1;
62  
63      private final Queue<FrameEntry> entries = new ConcurrentArrayQueue<>();
64      private final IteratingCallback flusher = new Flusher();
65      private final Deflater compressor;
66      private final Inflater decompressor;
67      private int tailDrop = TAIL_DROP_NEVER;
68      private int rsvUse = RSV_USE_ALWAYS;
69  
70      protected CompressExtension()
71      {
72          compressor = new Deflater(Deflater.BEST_COMPRESSION, true);
73          decompressor = new Inflater(true);
74          tailDrop = getTailDropMode();
75          rsvUse = getRsvUseMode();
76      }
77      
78      public Deflater getDeflater()
79      {
80          return compressor;
81      }
82  
83      public Inflater getInflater()
84      {
85          return decompressor;
86      }
87  
88      /**
89       * Indicates use of RSV1 flag for indicating deflation is in use.
90       */
91      @Override
92      public boolean isRsv1User()
93      {
94          return true;
95      }
96      
97      /**
98       * Return the mode of operation for dropping (or keeping) tail bytes in frames generated by compress (outgoing)
99       * 
100      * @return either {@link #TAIL_DROP_ALWAYS}, {@link #TAIL_DROP_FIN_ONLY}, or {@link #TAIL_DROP_NEVER}
101      */
102     abstract int getTailDropMode();
103 
104     /**
105      * Return the mode of operation for RSV flag use in frames generate by compress (outgoing)
106      * 
107      * @return either {@link #RSV_USE_ALWAYS} or {@link #RSV_USE_ONLY_FIRST}
108      */
109     abstract int getRsvUseMode();
110 
111     protected void forwardIncoming(Frame frame, ByteAccumulator accumulator)
112     {
113         DataFrame newFrame = new DataFrame(frame);
114         // Unset RSV1 since it's not compressed anymore.
115         newFrame.setRsv1(false);
116 
117         ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(), false);
118         try
119         {
120             BufferUtil.flipToFill(buffer);
121             accumulator.transferTo(buffer);
122             newFrame.setPayload(buffer);
123             nextIncomingFrame(newFrame);
124         }
125         finally
126         {
127             getBufferPool().release(buffer);
128         }
129     }
130 
131     protected ByteAccumulator decompress(byte[] input)
132     {
133         // Since we don't track text vs binary vs continuation state, just grab whatever is the greater value.
134         int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageBufferSize());
135         ByteAccumulator accumulator = new ByteAccumulator(maxSize);
136 
137         decompressor.setInput(input, 0, input.length);
138 
139         if (LOG.isDebugEnabled())
140             LOG.debug("Decompressing {} bytes", input.length);
141 
142         try
143         {
144             // It is allowed to send DEFLATE blocks with BFINAL=1.
145             // For such blocks, getRemaining() will be > 0 but finished()
146             // will be true, so we need to check for both.
147             // When BFINAL=0, finished() will always be false and we only
148             // check the remaining bytes.
149             while (decompressor.getRemaining() > 0 && !decompressor.finished())
150             {
151                 byte[] output = new byte[Math.min(input.length * 2, 32 * 1024)];
152                 int decompressed = decompressor.inflate(output);
153                 if (decompressed == 0)
154                 {
155                     if (decompressor.needsInput())
156                     {
157                         throw new BadPayloadException("Unable to inflate frame, not enough input on frame");
158                     }
159                     if (decompressor.needsDictionary())
160                     {
161                         throw new BadPayloadException("Unable to inflate frame, frame erroneously says it needs a dictionary");
162                     }
163                 }
164                 else
165                 {
166                     accumulator.addChunk(output, 0, decompressed);
167                 }
168             }
169             if (LOG.isDebugEnabled())
170                 LOG.debug("Decompressed {}->{} bytes", input.length, accumulator.getLength());
171             return accumulator;
172         }
173         catch (DataFormatException x)
174         {
175             throw new BadPayloadException(x);
176         }
177     }
178 
179     @Override
180     public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode)
181     {
182         // We use a queue and an IteratingCallback to handle concurrency.
183         // We must compress and write atomically, otherwise the compression
184         // context on the other end gets confused.
185 
186         if (flusher.isFailed())
187         {
188             notifyCallbackFailure(callback, new ZipException());
189             return;
190         }
191 
192         FrameEntry entry = new FrameEntry(frame, callback, batchMode);
193         if (LOG.isDebugEnabled())
194             LOG.debug("Queuing {}", entry);
195         entries.offer(entry);
196         flusher.iterate();
197     }
198 
199     protected void notifyCallbackSuccess(WriteCallback callback)
200     {
201         try
202         {
203             if (callback != null)
204                 callback.writeSuccess();
205         }
206         catch (Throwable x)
207         {
208             if (LOG.isDebugEnabled())
209                 LOG.debug("Exception while notifying success of callback " + callback, x);
210         }
211     }
212 
213     protected void notifyCallbackFailure(WriteCallback callback, Throwable failure)
214     {
215         try
216         {
217             if (callback != null)
218                 callback.writeFailed(failure);
219         }
220         catch (Throwable x)
221         {
222             if (LOG.isDebugEnabled())
223                 LOG.debug("Exception while notifying failure of callback " + callback, x);
224         }
225     }
226 
227     @Override
228     public String toString()
229     {
230         return getClass().getSimpleName();
231     }
232 
233     private static class FrameEntry
234     {
235         private final Frame frame;
236         private final WriteCallback callback;
237         private final BatchMode batchMode;
238 
239         private FrameEntry(Frame frame, WriteCallback callback, BatchMode batchMode)
240         {
241             this.frame = frame;
242             this.callback = callback;
243             this.batchMode = batchMode;
244         }
245 
246         @Override
247         public String toString()
248         {
249             return frame.toString();
250         }
251     }
252 
253     private class Flusher extends IteratingCallback implements WriteCallback
254     {
255         private static final int INPUT_BUFSIZE = 32 * 1024;
256         private FrameEntry current;
257         private ByteBuffer payload;
258         private boolean finished = true;
259 
260         @Override
261         protected Action process() throws Exception
262         {
263             if (finished)
264             {
265                 current = entries.poll();
266                 LOG.debug("Processing {}", current);
267                 if (current == null)
268                     return Action.IDLE;
269                 deflate(current);
270             }
271             else
272             {
273                 compress(current, false);
274             }
275             return Action.SCHEDULED;
276         }
277 
278         private void deflate(FrameEntry entry)
279         {
280             Frame frame = entry.frame;
281             BatchMode batchMode = entry.batchMode;
282             if (OpCode.isControlFrame(frame.getOpCode()) || !frame.hasPayload())
283             {
284                 nextOutgoingFrame(frame, this, batchMode);
285                 return;
286             }
287 
288             compress(entry, true);
289         }
290 
291         private void compress(FrameEntry entry, boolean first)
292         {
293             // Get a chunk of the payload to avoid to blow
294             // the heap if the payload is a huge mapped file.
295             Frame frame = entry.frame;
296             ByteBuffer data = frame.getPayload();
297             int remaining = data.remaining();
298             int inputLength = Math.min(remaining, INPUT_BUFSIZE);
299             if (LOG.isDebugEnabled())
300                 LOG.debug("Compressing {}: {} bytes in {} bytes chunk", entry, remaining, inputLength);
301 
302             // Avoid to copy the bytes if the ByteBuffer
303             // is backed by an array.
304             int inputOffset;
305             byte[] input;
306             if (data.hasArray())
307             {
308                 input = data.array();
309                 int position = data.position();
310                 inputOffset = position + data.arrayOffset();
311                 data.position(position + inputLength);
312             }
313             else
314             {
315                 input = new byte[inputLength];
316                 inputOffset = 0;
317                 data.get(input, 0, inputLength);
318             }
319             finished = inputLength == remaining;
320 
321             compressor.setInput(input, inputOffset, inputLength);
322 
323             // Use an additional space in case the content is not compressible.
324             byte[] output = new byte[inputLength + 64];
325             int outputOffset = 0;
326             int outputLength = 0;
327             while (true)
328             {
329                 int space = output.length - outputOffset;
330                 int compressed = compressor.deflate(output, outputOffset, space, Deflater.SYNC_FLUSH);
331                 outputLength += compressed;
332                 if (compressed < space)
333                 {
334                     // Everything was compressed.
335                     break;
336                 }
337                 else
338                 {
339                     // The compressed output is bigger than the uncompressed input.
340                     byte[] newOutput = new byte[output.length * 2];
341                     System.arraycopy(output, 0, newOutput, 0, output.length);
342                     outputOffset += output.length;
343                     output = newOutput;
344                 }
345             }
346 
347             boolean fin = frame.isFin() && finished;
348 
349             // Handle tail bytes generated by SYNC_FLUSH.
350             if(tailDrop == TAIL_DROP_ALWAYS) {
351                 payload = ByteBuffer.wrap(output, 0, outputLength - TAIL_BYTES.length);
352             } else if(tailDrop == TAIL_DROP_FIN_ONLY) {
353                 payload = ByteBuffer.wrap(output, 0, outputLength - (fin?TAIL_BYTES.length:0));
354             } else {
355                 // always include
356                 payload = ByteBuffer.wrap(output, 0, outputLength);
357             }
358             if (LOG.isDebugEnabled())
359             {
360                 LOG.debug("Compressed {}: {}->{} chunk bytes",entry,inputLength,outputLength);
361             }
362 
363             boolean continuation = frame.getType().isContinuation() || !first;
364             DataFrame chunk = new DataFrame(frame, continuation);
365             if(rsvUse == RSV_USE_ONLY_FIRST) {
366                 chunk.setRsv1(!continuation);
367             } else {
368                 // always set
369                 chunk.setRsv1(true);
370             }
371             chunk.setPayload(payload);
372             chunk.setFin(fin);
373 
374             nextOutgoingFrame(chunk, this, entry.batchMode);
375         }
376 
377         @Override
378         protected void onCompleteSuccess()
379         {
380             // This IteratingCallback never completes.
381         }
382         
383         @Override
384         protected void onCompleteFailure(Throwable x)
385         {
386             // Fail all the frames in the queue.
387             FrameEntry entry;
388             while ((entry = entries.poll()) != null)
389                 notifyCallbackFailure(entry.callback, x);
390         }
391 
392         @Override
393         public void writeSuccess()
394         {
395             if (finished)
396                 notifyCallbackSuccess(current.callback);
397             succeeded();
398         }
399 
400         @Override
401         public void writeFailed(Throwable x)
402         {
403             notifyCallbackFailure(current.callback, x);
404             // If something went wrong, very likely the compression context
405             // will be invalid, so we need to fail this IteratingCallback.
406             failed(x);
407         }
408     }
409 }