View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2014 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.websocket.common.extensions.compress;
20  
21  import java.nio.ByteBuffer;
22  import java.util.Queue;
23  import java.util.zip.DataFormatException;
24  import java.util.zip.Deflater;
25  import java.util.zip.Inflater;
26  import java.util.zip.ZipException;
27  
28  import org.eclipse.jetty.util.BufferUtil;
29  import org.eclipse.jetty.util.ConcurrentArrayQueue;
30  import org.eclipse.jetty.util.IteratingCallback;
31  import org.eclipse.jetty.util.log.Log;
32  import org.eclipse.jetty.util.log.Logger;
33  import org.eclipse.jetty.websocket.api.BadPayloadException;
34  import org.eclipse.jetty.websocket.api.BatchMode;
35  import org.eclipse.jetty.websocket.api.WriteCallback;
36  import org.eclipse.jetty.websocket.api.extensions.Frame;
37  import org.eclipse.jetty.websocket.common.OpCode;
38  import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
39  import org.eclipse.jetty.websocket.common.frames.DataFrame;
40  
41  public abstract class CompressExtension extends AbstractExtension
42  {
43      protected static final byte[] TAIL_BYTES = new byte[]{0x00, 0x00, (byte)0xFF, (byte)0xFF};
44      private static final Logger LOG = Log.getLogger(CompressExtension.class);
45      
46      /** Never drop tail bytes 0000FFFF, from any frame type */
47      protected static final int TAIL_DROP_NEVER = 0;
48      /** Always drop tail bytes 0000FFFF, from all frame types */
49      protected static final int TAIL_DROP_ALWAYS = 1;
50      /** Only drop tail bytes 0000FFFF, from fin==true frames */
51      protected static final int TAIL_DROP_FIN_ONLY = 2;
52  
53      /** Always set RSV flag, on all frame types */
54      protected static final int RSV_USE_ALWAYS = 0;
55      /** 
56       * Only set RSV flag on first frame in multi-frame messages.
57       * <p>
58       * Note: this automatically means no-continuation frames have
59       * the RSV bit set 
60       */
61      protected static final int RSV_USE_ONLY_FIRST = 1;
62  
63      private final Queue<FrameEntry> entries = new ConcurrentArrayQueue<>();
64      private final IteratingCallback flusher = new Flusher();
65      private final Deflater compressor;
66      private final Inflater decompressor;
67      private int tailDrop = TAIL_DROP_NEVER;
68      private int rsvUse = RSV_USE_ALWAYS;
69  
70      protected CompressExtension()
71      {
72          compressor = new Deflater(Deflater.BEST_COMPRESSION, true);
73          decompressor = new Inflater(true);
74          tailDrop = getTailDropMode();
75          rsvUse = getRsvUseMode();
76      }
77      
78      public Deflater getDeflater()
79      {
80          return compressor;
81      }
82  
83      public Inflater getInflater()
84      {
85          return decompressor;
86      }
87  
88      /**
89       * Indicates use of RSV1 flag for indicating deflation is in use.
90       */
91      @Override
92      public boolean isRsv1User()
93      {
94          return true;
95      }
96      
97      /**
98       * Return the mode of operation for dropping (or keeping) tail bytes in frames generated by compress (outgoing)
99       * 
100      * @return either {@link #TAIL_DROP_ALWAYS}, {@link #TAIL_DROP_FIN_ONLY}, or {@link #TAIL_DROP_NEVER}
101      */
102     abstract int getTailDropMode();
103 
104     /**
105      * Return the mode of operation for RSV flag use in frames generate by compress (outgoing)
106      * 
107      * @return either {@link #RSV_USE_ALWAYS} or {@link #RSV_USE_ONLY_FIRST}
108      */
109     abstract int getRsvUseMode();
110 
111     protected void forwardIncoming(Frame frame, ByteAccumulator accumulator)
112     {
113         DataFrame newFrame = new DataFrame(frame);
114         // Unset RSV1 since it's not compressed anymore.
115         newFrame.setRsv1(false);
116 
117         ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(), false);
118         try
119         {
120             BufferUtil.flipToFill(buffer);
121             accumulator.transferTo(buffer);
122             newFrame.setPayload(buffer);
123             nextIncomingFrame(newFrame);
124         }
125         finally
126         {
127             getBufferPool().release(buffer);
128         }
129     }
130 
131     protected ByteAccumulator decompress(byte[] input)
132     {
133         // Since we don't track text vs binary vs continuation state, just grab whatever is the greater value.
134         int maxSize = Math.max(getPolicy().getMaxTextMessageSize(), getPolicy().getMaxBinaryMessageBufferSize());
135         ByteAccumulator accumulator = new ByteAccumulator(maxSize);
136 
137         decompressor.setInput(input, 0, input.length);
138         LOG.debug("Decompressing {} bytes", input.length);
139 
140         try
141         {
142             // It is allowed to send DEFLATE blocks with BFINAL=1.
143             // For such blocks, getRemaining() will be > 0 but finished()
144             // will be true, so we need to check for both.
145             // When BFINAL=0, finished() will always be false and we only
146             // check the remaining bytes.
147             while (decompressor.getRemaining() > 0 && !decompressor.finished())
148             {
149                 byte[] output = new byte[Math.min(input.length * 2, 32 * 1024)];
150                 int decompressed = decompressor.inflate(output);
151                 if (decompressed == 0)
152                 {
153                     if (decompressor.needsInput())
154                     {
155                         throw new BadPayloadException("Unable to inflate frame, not enough input on frame");
156                     }
157                     if (decompressor.needsDictionary())
158                     {
159                         throw new BadPayloadException("Unable to inflate frame, frame erroneously says it needs a dictionary");
160                     }
161                 }
162                 else
163                 {
164                     accumulator.addChunk(output, 0, decompressed);
165                 }
166             }
167             LOG.debug("Decompressed {}->{} bytes", input.length, accumulator.getLength());
168             return accumulator;
169         }
170         catch (DataFormatException x)
171         {
172             throw new BadPayloadException(x);
173         }
174     }
175 
176     @Override
177     public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode)
178     {
179         // We use a queue and an IteratingCallback to handle concurrency.
180         // We must compress and write atomically, otherwise the compression
181         // context on the other end gets confused.
182 
183         if (flusher.isFailed())
184         {
185             notifyCallbackFailure(callback, new ZipException());
186             return;
187         }
188 
189         FrameEntry entry = new FrameEntry(frame, callback, batchMode);
190         LOG.debug("Queuing {}", entry);
191         entries.offer(entry);
192         flusher.iterate();
193     }
194 
195     protected void notifyCallbackSuccess(WriteCallback callback)
196     {
197         try
198         {
199             if (callback != null)
200                 callback.writeSuccess();
201         }
202         catch (Throwable x)
203         {
204             LOG.debug("Exception while notifying success of callback " + callback, x);
205         }
206     }
207 
208     protected void notifyCallbackFailure(WriteCallback callback, Throwable failure)
209     {
210         try
211         {
212             if (callback != null)
213                 callback.writeFailed(failure);
214         }
215         catch (Throwable x)
216         {
217             LOG.debug("Exception while notifying failure of callback " + callback, x);
218         }
219     }
220 
221     @Override
222     public String toString()
223     {
224         return getClass().getSimpleName();
225     }
226 
227     private static class FrameEntry
228     {
229         private final Frame frame;
230         private final WriteCallback callback;
231         private final BatchMode batchMode;
232 
233         private FrameEntry(Frame frame, WriteCallback callback, BatchMode batchMode)
234         {
235             this.frame = frame;
236             this.callback = callback;
237             this.batchMode = batchMode;
238         }
239 
240         @Override
241         public String toString()
242         {
243             return frame.toString();
244         }
245     }
246 
247     private class Flusher extends IteratingCallback implements WriteCallback
248     {
249         private FrameEntry current;
250         private ByteBuffer payload;
251         private boolean finished = true;
252 
253         @Override
254         protected Action process() throws Exception
255         {
256             if (finished)
257             {
258                 current = entries.poll();
259                 LOG.debug("Processing {}", current);
260                 if (current == null)
261                     return Action.IDLE;
262                 deflate(current);
263             }
264             else
265             {
266                 compress(current, false);
267             }
268             return Action.SCHEDULED;
269         }
270 
271         private void deflate(FrameEntry entry)
272         {
273             Frame frame = entry.frame;
274             BatchMode batchMode = entry.batchMode;
275             if (OpCode.isControlFrame(frame.getOpCode()) || !frame.hasPayload())
276             {
277                 nextOutgoingFrame(frame, this, batchMode);
278                 return;
279             }
280 
281             compress(entry, true);
282         }
283 
284         private void compress(FrameEntry entry, boolean first)
285         {
286             // Get a chunk of the payload to avoid to blow
287             // the heap if the payload is a huge mapped file.
288             Frame frame = entry.frame;
289             ByteBuffer data = frame.getPayload();
290             int remaining = data.remaining();
291             int inputLength = Math.min(remaining, 32 * 1024);
292             LOG.debug("Compressing {}: {} bytes in {} bytes chunk", entry, remaining, inputLength);
293 
294             // Avoid to copy the bytes if the ByteBuffer
295             // is backed by an array.
296             int inputOffset;
297             byte[] input;
298             if (data.hasArray())
299             {
300                 input = data.array();
301                 int position = data.position();
302                 inputOffset = position + data.arrayOffset();
303                 data.position(position + inputLength);
304             }
305             else
306             {
307                 input = new byte[inputLength];
308                 inputOffset = 0;
309                 data.get(input, 0, inputLength);
310             }
311             finished = inputLength == remaining;
312 
313             compressor.setInput(input, inputOffset, inputLength);
314 
315             // Use an additional space in case the content is not compressible.
316             byte[] output = new byte[inputLength + 64];
317             int outputOffset = 0;
318             int outputLength = 0;
319             while (true)
320             {
321                 int space = output.length - outputOffset;
322                 int compressed = compressor.deflate(output, outputOffset, space, Deflater.SYNC_FLUSH);
323                 outputLength += compressed;
324                 if (compressed < space)
325                 {
326                     // Everything was compressed.
327                     break;
328                 }
329                 else
330                 {
331                     // The compressed output is bigger than the uncompressed input.
332                     byte[] newOutput = new byte[output.length * 2];
333                     System.arraycopy(output, 0, newOutput, 0, output.length);
334                     outputOffset += output.length;
335                     output = newOutput;
336                 }
337             }
338 
339             boolean fin = frame.isFin() && finished;
340 
341             // Handle tail bytes generated by SYNC_FLUSH.
342             if(tailDrop == TAIL_DROP_ALWAYS) {
343                 payload = ByteBuffer.wrap(output, 0, outputLength - TAIL_BYTES.length);
344             } else if(tailDrop == TAIL_DROP_FIN_ONLY) {
345                 payload = ByteBuffer.wrap(output, 0, outputLength - (fin?TAIL_BYTES.length:0));
346             } else {
347                 // always include
348                 payload = ByteBuffer.wrap(output, 0, outputLength);
349             }
350             if (LOG.isDebugEnabled())
351             {
352                 LOG.debug("Compressed {}: {}->{} chunk bytes",entry,inputLength,outputLength);
353             }
354 
355             boolean continuation = frame.getType().isContinuation() || !first;
356             DataFrame chunk = new DataFrame(frame, continuation);
357             if(rsvUse == RSV_USE_ONLY_FIRST) {
358                 chunk.setRsv1(!continuation);
359             } else {
360                 // always set
361                 chunk.setRsv1(true);
362             }
363             chunk.setPayload(payload);
364             chunk.setFin(fin);
365 
366             nextOutgoingFrame(chunk, this, entry.batchMode);
367         }
368 
369         @Override
370         protected void completed()
371         {
372             // This IteratingCallback never completes.
373         }
374 
375         @Override
376         public void writeSuccess()
377         {
378             if (finished)
379                 notifyCallbackSuccess(current.callback);
380             succeeded();
381         }
382 
383         @Override
384         public void writeFailed(Throwable x)
385         {
386             notifyCallbackFailure(current.callback, x);
387             // If something went wrong, very likely the compression context
388             // will be invalid, so we need to fail this IteratingCallback.
389             failed(x);
390             // Now no more frames can be queued, fail those in the queue.
391             FrameEntry entry;
392             while ((entry = entries.poll()) != null)
393                 notifyCallbackFailure(entry.callback, x);
394         }
395     }
396 }