View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2016 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.websocket.common.extensions.compress;
20  
21  import java.io.ByteArrayOutputStream;
22  import java.nio.ByteBuffer;
23  import java.util.Queue;
24  import java.util.concurrent.atomic.AtomicInteger;
25  import java.util.zip.DataFormatException;
26  import java.util.zip.Deflater;
27  import java.util.zip.Inflater;
28  import java.util.zip.ZipException;
29  
30  import org.eclipse.jetty.util.BufferUtil;
31  import org.eclipse.jetty.util.ConcurrentArrayQueue;
32  import org.eclipse.jetty.util.IteratingCallback;
33  import org.eclipse.jetty.util.log.Log;
34  import org.eclipse.jetty.util.log.Logger;
35  import org.eclipse.jetty.websocket.api.BatchMode;
36  import org.eclipse.jetty.websocket.api.WriteCallback;
37  import org.eclipse.jetty.websocket.api.extensions.Frame;
38  import org.eclipse.jetty.websocket.common.OpCode;
39  import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
40  import org.eclipse.jetty.websocket.common.frames.DataFrame;
41  
42  public abstract class CompressExtension extends AbstractExtension
43  {
44      protected static final byte[] TAIL_BYTES = new byte[] { 0x00, 0x00, (byte)0xFF, (byte)0xFF };
45      protected static final ByteBuffer TAIL_BYTES_BUF = ByteBuffer.wrap(TAIL_BYTES);
46      private static final Logger LOG = Log.getLogger(CompressExtension.class);
47  
48      /** Never drop tail bytes 0000FFFF, from any frame type */
49      protected static final int TAIL_DROP_NEVER = 0;
50      /** Always drop tail bytes 0000FFFF, from all frame types */
51      protected static final int TAIL_DROP_ALWAYS = 1;
52      /** Only drop tail bytes 0000FFFF, from fin==true frames */
53      protected static final int TAIL_DROP_FIN_ONLY = 2;
54  
55      /** Always set RSV flag, on all frame types */
56      protected static final int RSV_USE_ALWAYS = 0;
57      /**
58       * Only set RSV flag on first frame in multi-frame messages.
59       * <p>
60       * Note: this automatically means no-continuation frames have the RSV bit set
61       */
62      protected static final int RSV_USE_ONLY_FIRST = 1;
63  
64      /** Inflater / Decompressed Buffer Size */
65      protected static final int INFLATE_BUFFER_SIZE = 8 * 1024;
66  
67      /** Deflater / Inflater: Maximum Input Buffer Size */
68      protected static final int INPUT_MAX_BUFFER_SIZE = 8 * 1024;
69  
70      /** Inflater : Output Buffer Size */
71      private static final int DECOMPRESS_BUF_SIZE = 8 * 1024;
72      
73      private final static boolean NOWRAP = true;
74  
75      private final Queue<FrameEntry> entries = new ConcurrentArrayQueue<>();
76      private final IteratingCallback flusher = new Flusher();
77      private final Deflater deflater;
78      private final Inflater inflater;
79      protected AtomicInteger decompressCount = new AtomicInteger(0);
80      private int tailDrop = TAIL_DROP_NEVER;
81      private int rsvUse = RSV_USE_ALWAYS;
82  
83      protected CompressExtension()
84      {
85          deflater = new Deflater(Deflater.DEFAULT_COMPRESSION,NOWRAP);
86          inflater = new Inflater(NOWRAP);
87          tailDrop = getTailDropMode();
88          rsvUse = getRsvUseMode();
89      }
90  
91      public Deflater getDeflater()
92      {
93          return deflater;
94      }
95  
96      public Inflater getInflater()
97      {
98          return inflater;
99      }
100 
101     /**
102      * Indicates use of RSV1 flag for indicating deflation is in use.
103      */
104     @Override
105     public boolean isRsv1User()
106     {
107         return true;
108     }
109 
110     /**
111      * Return the mode of operation for dropping (or keeping) tail bytes in frames generated by compress (outgoing)
112      * 
113      * @return either {@link #TAIL_DROP_ALWAYS}, {@link #TAIL_DROP_FIN_ONLY}, or {@link #TAIL_DROP_NEVER}
114      */
115     abstract int getTailDropMode();
116 
117     /**
118      * Return the mode of operation for RSV flag use in frames generate by compress (outgoing)
119      * 
120      * @return either {@link #RSV_USE_ALWAYS} or {@link #RSV_USE_ONLY_FIRST}
121      */
122     abstract int getRsvUseMode();
123 
124     protected void forwardIncoming(Frame frame, ByteAccumulator accumulator)
125     {
126         DataFrame newFrame = new DataFrame(frame);
127         // Unset RSV1 since it's not compressed anymore.
128         newFrame.setRsv1(false);
129 
130         ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(),false);
131         try
132         {
133             BufferUtil.flipToFill(buffer);
134             accumulator.transferTo(buffer);
135             newFrame.setPayload(buffer);
136             nextIncomingFrame(newFrame);
137         }
138         finally
139         {
140             getBufferPool().release(buffer);
141         }
142     }
143 
144     protected ByteAccumulator newByteAccumulator()
145     {
146         int maxSize = Math.max(getPolicy().getMaxTextMessageSize(),getPolicy().getMaxBinaryMessageBufferSize());
147         return new ByteAccumulator(maxSize);
148     }
149 
150     protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws DataFormatException
151     {
152         if ((buf == null) || (!buf.hasRemaining()))
153         {
154             return;
155         }
156         byte[] output = new byte[DECOMPRESS_BUF_SIZE];
157         
158         while(buf.hasRemaining() && inflater.needsInput())
159         {
160             if (!supplyInput(inflater,buf))
161             {
162                 LOG.debug("Needed input, but no buffer could supply input");
163                 return;
164             }
165     
166             int read = 0;
167             while ((read = inflater.inflate(output)) >= 0)
168             {
169                 if (read == 0)
170                 {
171                     LOG.debug("Decompress: read 0 {}",toDetail(inflater));
172                     break;
173                 }
174                 else
175                 {
176                     // do something with output
177                     if (LOG.isDebugEnabled())
178                     {
179                         LOG.debug("Decompressed {} bytes: {}",read,toDetail(inflater));
180                     }
181     
182                     accumulator.copyChunk(output,0,read);
183                 }
184             }
185         }
186 
187         if (LOG.isDebugEnabled())
188         {
189             LOG.debug("Decompress: exiting {}",toDetail(inflater));
190         }
191     }
192 
193     @Override
194     public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode)
195     {
196         // We use a queue and an IteratingCallback to handle concurrency.
197         // We must compress and write atomically, otherwise the compression
198         // context on the other end gets confused.
199 
200         if (flusher.isFailed())
201         {
202             notifyCallbackFailure(callback,new ZipException());
203             return;
204         }
205 
206         FrameEntry entry = new FrameEntry(frame,callback,batchMode);
207         if (LOG.isDebugEnabled())
208             LOG.debug("Queuing {}",entry);
209         entries.offer(entry);
210         flusher.iterate();
211     }
212 
213     protected void notifyCallbackSuccess(WriteCallback callback)
214     {
215         try
216         {
217             if (callback != null)
218                 callback.writeSuccess();
219         }
220         catch (Throwable x)
221         {
222             if (LOG.isDebugEnabled())
223                 LOG.debug("Exception while notifying success of callback " + callback,x);
224         }
225     }
226 
227     protected void notifyCallbackFailure(WriteCallback callback, Throwable failure)
228     {
229         try
230         {
231             if (callback != null)
232                 callback.writeFailed(failure);
233         }
234         catch (Throwable x)
235         {
236             if (LOG.isDebugEnabled())
237                 LOG.debug("Exception while notifying failure of callback " + callback,x);
238         }
239     }
240 
241     private static boolean supplyInput(Inflater inflater, ByteBuffer buf)
242     {
243         if (buf.remaining() <= 0)
244         {
245             if (LOG.isDebugEnabled())
246             {
247                 LOG.debug("No data left left to supply to Inflater");
248             }
249             return false;
250         }
251 
252         byte input[];
253         int inputOffset = 0;
254         int len;
255 
256         if (buf.hasArray())
257         {
258             // no need to create a new byte buffer, just return this one.
259             len = buf.remaining();
260             input = buf.array();
261             inputOffset = buf.position() + buf.arrayOffset();
262             buf.position(buf.position() + len);
263         }
264         else
265         {
266             // Only create an return byte buffer that is reasonable in size
267             len = Math.min(INPUT_MAX_BUFFER_SIZE,buf.remaining());
268             input = new byte[len];
269             inputOffset = 0;
270             buf.get(input,0,len);
271         }
272 
273         inflater.setInput(input,inputOffset,len);
274         if (LOG.isDebugEnabled())
275         {
276             LOG.debug("Supplied {} input bytes: {}",input.length,toDetail(inflater));
277         }
278         return true;
279     }
280 
281     private static boolean supplyInput(Deflater deflater, ByteBuffer buf)
282     {
283         if (buf.remaining() <= 0)
284         {
285             if (LOG.isDebugEnabled())
286             {
287                 LOG.debug("No data left left to supply to Deflater");
288             }
289             return false;
290         }
291 
292         byte input[];
293         int inputOffset = 0;
294         int len;
295 
296         if (buf.hasArray())
297         {
298             // no need to create a new byte buffer, just return this one.
299             len = buf.remaining();
300             input = buf.array();
301             inputOffset = buf.position() + buf.arrayOffset();
302             buf.position(buf.position() + len);
303         }
304         else
305         {
306             // Only create an return byte buffer that is reasonable in size
307             len = Math.min(INPUT_MAX_BUFFER_SIZE,buf.remaining());
308             input = new byte[len];
309             inputOffset = 0;
310             buf.get(input,0,len);
311         }
312 
313         deflater.setInput(input,inputOffset,len);
314         if (LOG.isDebugEnabled())
315         {
316             LOG.debug("Supplied {} input bytes: {}",input.length,toDetail(deflater));
317         }
318         return true;
319     }
320 
321     private static String toDetail(Inflater inflater)
322     {
323         return String.format("Inflater[finished=%b,read=%d,written=%d,remaining=%d,in=%d,out=%d]",inflater.finished(),inflater.getBytesRead(),
324                 inflater.getBytesWritten(),inflater.getRemaining(),inflater.getTotalIn(),inflater.getTotalOut());
325     }
326 
327     private static String toDetail(Deflater deflater)
328     {
329         return String.format("Deflater[finished=%b,read=%d,written=%d,in=%d,out=%d]",deflater.finished(),deflater.getBytesRead(),deflater.getBytesWritten(),
330                 deflater.getTotalIn(),deflater.getTotalOut());
331     }
332 
333     public static boolean endsWithTail(ByteBuffer buf)
334     {
335         if ((buf == null) || (buf.remaining() < TAIL_BYTES.length))
336         {
337             return false;
338         }
339         int limit = buf.limit();
340         for (int i = TAIL_BYTES.length; i > 0; i--)
341         {
342             if (buf.get(limit - i) != TAIL_BYTES[TAIL_BYTES.length - i])
343             {
344                 return false;
345             }
346         }
347         return true;
348     }
349 
350     @Override
351     public String toString()
352     {
353         return getClass().getSimpleName();
354     }
355 
356     private static class FrameEntry
357     {
358         private final Frame frame;
359         private final WriteCallback callback;
360         private final BatchMode batchMode;
361 
362         private FrameEntry(Frame frame, WriteCallback callback, BatchMode batchMode)
363         {
364             this.frame = frame;
365             this.callback = callback;
366             this.batchMode = batchMode;
367         }
368 
369         @Override
370         public String toString()
371         {
372             return frame.toString();
373         }
374     }
375 
376     private class Flusher extends IteratingCallback implements WriteCallback
377     {
378         private FrameEntry current;
379         private boolean finished = true;
380         
381         @Override
382         public void failed(Throwable x)
383         {
384             LOG.warn(x);
385             super.failed(x);
386         }
387 
388         @Override
389         protected Action process() throws Exception
390         {
391             if (finished)
392             {
393                 current = entries.poll();
394                 LOG.debug("Processing {}",current);
395                 if (current == null)
396                     return Action.IDLE;
397                 deflate(current);
398             }
399             else
400             {
401                 compress(current,false);
402             }
403             return Action.SCHEDULED;
404         }
405 
406         private void deflate(FrameEntry entry)
407         {
408             Frame frame = entry.frame;
409             BatchMode batchMode = entry.batchMode;
410             if (OpCode.isControlFrame(frame.getOpCode()))
411             {
412                 // Do not deflate control frames
413                 nextOutgoingFrame(frame,this,batchMode);
414                 return;
415             }
416             
417             compress(entry,true);
418         }
419 
420         private void compress(FrameEntry entry, boolean first)
421         {
422             // Get a chunk of the payload to avoid to blow
423             // the heap if the payload is a huge mapped file.
424             Frame frame = entry.frame;
425             ByteBuffer data = frame.getPayload();
426             int remaining = data.remaining();
427             int outputLength = Math.max(256,data.remaining());
428             if (LOG.isDebugEnabled())
429                 LOG.debug("Compressing {}: {} bytes in {} bytes chunk",entry,remaining,outputLength);
430 
431             boolean needsCompress = true;
432 
433             if (deflater.needsInput() && !supplyInput(deflater,data))
434             {
435                 // no input supplied
436                 needsCompress = false;
437             }
438             
439             ByteArrayOutputStream out = new ByteArrayOutputStream();
440 
441             byte[] output = new byte[outputLength];
442 
443             boolean fin = frame.isFin();
444 
445             // Compress the data
446             while (needsCompress)
447             {
448                 int compressed = deflater.deflate(output,0,outputLength,Deflater.SYNC_FLUSH);
449 
450                 // Append the output for the eventual frame.
451                 if (LOG.isDebugEnabled())
452                     LOG.debug("Wrote {} bytes to output buffer",compressed);
453                 out.write(output,0,compressed);
454 
455                 if (compressed < outputLength)
456                 {
457                     needsCompress = false;
458                 }
459             }
460 
461             ByteBuffer payload = ByteBuffer.wrap(out.toByteArray());
462 
463             if (payload.remaining() > 0)
464             {
465                 // Handle tail bytes generated by SYNC_FLUSH.
466                 if (LOG.isDebugEnabled())
467                     LOG.debug("compressed bytes[] = {}",BufferUtil.toDetailString(payload));
468 
469                 if (tailDrop == TAIL_DROP_ALWAYS)
470                 {
471                     if (endsWithTail(payload))
472                     {
473                         payload.limit(payload.limit() - TAIL_BYTES.length);
474                     }
475                     if (LOG.isDebugEnabled())
476                         LOG.debug("payload (TAIL_DROP_ALWAYS) = {}",BufferUtil.toDetailString(payload));
477                 }
478                 else if (tailDrop == TAIL_DROP_FIN_ONLY)
479                 {
480                     if (frame.isFin() && endsWithTail(payload))
481                     {
482                         payload.limit(payload.limit() - TAIL_BYTES.length);
483                     }
484                     if (LOG.isDebugEnabled())
485                         LOG.debug("payload (TAIL_DROP_FIN_ONLY) = {}",BufferUtil.toDetailString(payload));
486                 }
487             }
488             else if (fin)
489             {
490                 // Special case: 7.2.3.6.  Generating an Empty Fragment Manually
491                 // https://tools.ietf.org/html/rfc7692#section-7.2.3.6
492                 payload = ByteBuffer.wrap(new byte[] { 0x00 });
493             }
494 
495             if (LOG.isDebugEnabled())
496             {
497                 LOG.debug("Compressed {}: input:{} -> payload:{}",entry,outputLength,payload.remaining());
498             }
499 
500             boolean continuation = frame.getType().isContinuation() || !first;
501             DataFrame chunk = new DataFrame(frame,continuation);
502             if (rsvUse == RSV_USE_ONLY_FIRST)
503             {
504                 chunk.setRsv1(!continuation);
505             }
506             else
507             {
508                 // always set
509                 chunk.setRsv1(true);
510             }
511             chunk.setPayload(payload);
512             chunk.setFin(fin);
513 
514             nextOutgoingFrame(chunk,this,entry.batchMode);
515         }
516 
517         @Override
518         protected void onCompleteSuccess()
519         {
520             // This IteratingCallback never completes.
521         }
522 
523         @Override
524         protected void onCompleteFailure(Throwable x)
525         {
526             // Fail all the frames in the queue.
527             FrameEntry entry;
528             while ((entry = entries.poll()) != null)
529                 notifyCallbackFailure(entry.callback,x);
530         }
531 
532         @Override
533         public void writeSuccess()
534         {
535             if (finished)
536                 notifyCallbackSuccess(current.callback);
537             succeeded();
538         }
539 
540         @Override
541         public void writeFailed(Throwable x)
542         {
543             notifyCallbackFailure(current.callback,x);
544             // If something went wrong, very likely the compression context
545             // will be invalid, so we need to fail this IteratingCallback.
546             failed(x);
547         }
548     }
549 }