View Javadoc

1   //
2   //  ========================================================================
3   //  Copyright (c) 1995-2016 Mort Bay Consulting Pty. Ltd.
4   //  ------------------------------------------------------------------------
5   //  All rights reserved. This program and the accompanying materials
6   //  are made available under the terms of the Eclipse Public License v1.0
7   //  and Apache License v2.0 which accompanies this distribution.
8   //
9   //      The Eclipse Public License is available at
10  //      http://www.eclipse.org/legal/epl-v10.html
11  //
12  //      The Apache License v2.0 is available at
13  //      http://www.opensource.org/licenses/apache2.0.php
14  //
15  //  You may elect to redistribute this code under either of these licenses.
16  //  ========================================================================
17  //
18  
19  package org.eclipse.jetty.websocket.common.extensions.compress;
20  
21  import java.io.ByteArrayOutputStream;
22  import java.nio.ByteBuffer;
23  import java.util.ArrayDeque;
24  import java.util.Queue;
25  import java.util.concurrent.atomic.AtomicInteger;
26  import java.util.zip.DataFormatException;
27  import java.util.zip.Deflater;
28  import java.util.zip.Inflater;
29  import java.util.zip.ZipException;
30  
31  import org.eclipse.jetty.util.BufferUtil;
32  import org.eclipse.jetty.util.IteratingCallback;
33  import org.eclipse.jetty.util.log.Log;
34  import org.eclipse.jetty.util.log.Logger;
35  import org.eclipse.jetty.websocket.api.BatchMode;
36  import org.eclipse.jetty.websocket.api.WriteCallback;
37  import org.eclipse.jetty.websocket.api.extensions.Frame;
38  import org.eclipse.jetty.websocket.common.OpCode;
39  import org.eclipse.jetty.websocket.common.extensions.AbstractExtension;
40  import org.eclipse.jetty.websocket.common.frames.DataFrame;
41  
42  public abstract class CompressExtension extends AbstractExtension
43  {
44      protected static final byte[] TAIL_BYTES = new byte[] { 0x00, 0x00, (byte)0xFF, (byte)0xFF };
45      protected static final ByteBuffer TAIL_BYTES_BUF = ByteBuffer.wrap(TAIL_BYTES);
46      private static final Logger LOG = Log.getLogger(CompressExtension.class);
47  
48      /** Never drop tail bytes 0000FFFF, from any frame type */
49      protected static final int TAIL_DROP_NEVER = 0;
50      /** Always drop tail bytes 0000FFFF, from all frame types */
51      protected static final int TAIL_DROP_ALWAYS = 1;
52      /** Only drop tail bytes 0000FFFF, from fin==true frames */
53      protected static final int TAIL_DROP_FIN_ONLY = 2;
54  
55      /** Always set RSV flag, on all frame types */
56      protected static final int RSV_USE_ALWAYS = 0;
57      /**
58       * Only set RSV flag on first frame in multi-frame messages.
59       * <p>
60       * Note: this automatically means no-continuation frames have the RSV bit set
61       */
62      protected static final int RSV_USE_ONLY_FIRST = 1;
63  
64      /** Inflater / Decompressed Buffer Size */
65      protected static final int INFLATE_BUFFER_SIZE = 8 * 1024;
66  
67      /** Deflater / Inflater: Maximum Input Buffer Size */
68      protected static final int INPUT_MAX_BUFFER_SIZE = 8 * 1024;
69  
70      /** Inflater : Output Buffer Size */
71      private static final int DECOMPRESS_BUF_SIZE = 8 * 1024;
72      
73      private final static boolean NOWRAP = true;
74  
75      private final Queue<FrameEntry> entries = new ArrayDeque<>();
76      private final IteratingCallback flusher = new Flusher();
77      private Deflater deflaterImpl;
78      private Inflater inflaterImpl;
79      protected AtomicInteger decompressCount = new AtomicInteger(0);
80      private int tailDrop = TAIL_DROP_NEVER;
81      private int rsvUse = RSV_USE_ALWAYS;
82  
83      protected CompressExtension()
84      {
85          tailDrop = getTailDropMode();
86          rsvUse = getRsvUseMode();
87      }
88  
89      public Deflater getDeflater()
90      {
91          if (deflaterImpl == null)
92          {
93              deflaterImpl = new Deflater(Deflater.DEFAULT_COMPRESSION,NOWRAP);
94          }
95          return deflaterImpl;
96      }
97  
98      public Inflater getInflater()
99      {
100         if (inflaterImpl == null)
101         {
102             inflaterImpl = new Inflater(NOWRAP);
103         }
104         return inflaterImpl;
105     }
106 
107     /**
108      * Indicates use of RSV1 flag for indicating deflation is in use.
109      */
110     @Override
111     public boolean isRsv1User()
112     {
113         return true;
114     }
115 
116     /**
117      * Return the mode of operation for dropping (or keeping) tail bytes in frames generated by compress (outgoing)
118      * 
119      * @return either {@link #TAIL_DROP_ALWAYS}, {@link #TAIL_DROP_FIN_ONLY}, or {@link #TAIL_DROP_NEVER}
120      */
121     abstract int getTailDropMode();
122 
123     /**
124      * Return the mode of operation for RSV flag use in frames generate by compress (outgoing)
125      * 
126      * @return either {@link #RSV_USE_ALWAYS} or {@link #RSV_USE_ONLY_FIRST}
127      */
128     abstract int getRsvUseMode();
129 
130     protected void forwardIncoming(Frame frame, ByteAccumulator accumulator)
131     {
132         DataFrame newFrame = new DataFrame(frame);
133         // Unset RSV1 since it's not compressed anymore.
134         newFrame.setRsv1(false);
135 
136         ByteBuffer buffer = getBufferPool().acquire(accumulator.getLength(),false);
137         try
138         {
139             BufferUtil.flipToFill(buffer);
140             accumulator.transferTo(buffer);
141             newFrame.setPayload(buffer);
142             nextIncomingFrame(newFrame);
143         }
144         finally
145         {
146             getBufferPool().release(buffer);
147         }
148     }
149 
150     protected ByteAccumulator newByteAccumulator()
151     {
152         int maxSize = Math.max(getPolicy().getMaxTextMessageSize(),getPolicy().getMaxBinaryMessageBufferSize());
153         return new ByteAccumulator(maxSize);
154     }
155 
156     protected void decompress(ByteAccumulator accumulator, ByteBuffer buf) throws DataFormatException
157     {
158         if ((buf == null) || (!buf.hasRemaining()))
159         {
160             return;
161         }
162         byte[] output = new byte[DECOMPRESS_BUF_SIZE];
163         
164         Inflater inflater = getInflater();
165         
166         while(buf.hasRemaining() && inflater.needsInput())
167         {
168             if (!supplyInput(inflater,buf))
169             {
170                 LOG.debug("Needed input, but no buffer could supply input");
171                 return;
172             }
173     
174             int read;
175             while ((read = inflater.inflate(output)) >= 0)
176             {
177                 if (read == 0)
178                 {
179                     LOG.debug("Decompress: read 0 {}",toDetail(inflater));
180                     break;
181                 }
182                 else
183                 {
184                     // do something with output
185                     if (LOG.isDebugEnabled())
186                     {
187                         LOG.debug("Decompressed {} bytes: {}",read,toDetail(inflater));
188                     }
189     
190                     accumulator.copyChunk(output,0,read);
191                 }
192             }
193         }
194 
195         if (LOG.isDebugEnabled())
196         {
197             LOG.debug("Decompress: exiting {}",toDetail(inflater));
198         }
199     }
200 
201     @Override
202     public void outgoingFrame(Frame frame, WriteCallback callback, BatchMode batchMode)
203     {
204         // We use a queue and an IteratingCallback to handle concurrency.
205         // We must compress and write atomically, otherwise the compression
206         // context on the other end gets confused.
207 
208         if (flusher.isFailed())
209         {
210             notifyCallbackFailure(callback,new ZipException());
211             return;
212         }
213 
214         FrameEntry entry = new FrameEntry(frame,callback,batchMode);
215         if (LOG.isDebugEnabled())
216             LOG.debug("Queuing {}",entry);
217         offerEntry(entry);
218         flusher.iterate();
219     }
220 
221     private void offerEntry(FrameEntry entry)
222     {
223         synchronized (this)
224         {
225             entries.offer(entry);
226         }
227     }
228 
229     private FrameEntry pollEntry()
230     {
231         synchronized (this)
232         {
233             return entries.poll();
234         }
235     }
236 
237     protected void notifyCallbackSuccess(WriteCallback callback)
238     {
239         try
240         {
241             if (callback != null)
242                 callback.writeSuccess();
243         }
244         catch (Throwable x)
245         {
246             if (LOG.isDebugEnabled())
247                 LOG.debug("Exception while notifying success of callback " + callback,x);
248         }
249     }
250 
251     protected void notifyCallbackFailure(WriteCallback callback, Throwable failure)
252     {
253         try
254         {
255             if (callback != null)
256                 callback.writeFailed(failure);
257         }
258         catch (Throwable x)
259         {
260             if (LOG.isDebugEnabled())
261                 LOG.debug("Exception while notifying failure of callback " + callback,x);
262         }
263     }
264 
265     private static boolean supplyInput(Inflater inflater, ByteBuffer buf)
266     {
267         if (buf.remaining() <= 0)
268         {
269             if (LOG.isDebugEnabled())
270             {
271                 LOG.debug("No data left left to supply to Inflater");
272             }
273             return false;
274         }
275 
276         byte input[];
277         int inputOffset;
278         int len;
279 
280         if (buf.hasArray())
281         {
282             // no need to create a new byte buffer, just return this one.
283             len = buf.remaining();
284             input = buf.array();
285             inputOffset = buf.position() + buf.arrayOffset();
286             buf.position(buf.position() + len);
287         }
288         else
289         {
290             // Only create an return byte buffer that is reasonable in size
291             len = Math.min(INPUT_MAX_BUFFER_SIZE,buf.remaining());
292             input = new byte[len];
293             inputOffset = 0;
294             buf.get(input,0,len);
295         }
296 
297         inflater.setInput(input,inputOffset,len);
298         if (LOG.isDebugEnabled())
299         {
300             LOG.debug("Supplied {} input bytes: {}",input.length,toDetail(inflater));
301         }
302         return true;
303     }
304 
305     private static boolean supplyInput(Deflater deflater, ByteBuffer buf)
306     {
307         if (buf.remaining() <= 0)
308         {
309             if (LOG.isDebugEnabled())
310             {
311                 LOG.debug("No data left left to supply to Deflater");
312             }
313             return false;
314         }
315 
316         byte input[];
317         int inputOffset;
318         int len;
319 
320         if (buf.hasArray())
321         {
322             // no need to create a new byte buffer, just return this one.
323             len = buf.remaining();
324             input = buf.array();
325             inputOffset = buf.position() + buf.arrayOffset();
326             buf.position(buf.position() + len);
327         }
328         else
329         {
330             // Only create an return byte buffer that is reasonable in size
331             len = Math.min(INPUT_MAX_BUFFER_SIZE,buf.remaining());
332             input = new byte[len];
333             inputOffset = 0;
334             buf.get(input,0,len);
335         }
336 
337         deflater.setInput(input,inputOffset,len);
338         if (LOG.isDebugEnabled())
339         {
340             LOG.debug("Supplied {} input bytes: {}",input.length,toDetail(deflater));
341         }
342         return true;
343     }
344 
345     private static String toDetail(Inflater inflater)
346     {
347         return String.format("Inflater[finished=%b,read=%d,written=%d,remaining=%d,in=%d,out=%d]",inflater.finished(),inflater.getBytesRead(),
348                 inflater.getBytesWritten(),inflater.getRemaining(),inflater.getTotalIn(),inflater.getTotalOut());
349     }
350 
351     private static String toDetail(Deflater deflater)
352     {
353         return String.format("Deflater[finished=%b,read=%d,written=%d,in=%d,out=%d]",deflater.finished(),deflater.getBytesRead(),deflater.getBytesWritten(),
354                 deflater.getTotalIn(),deflater.getTotalOut());
355     }
356 
357     public static boolean endsWithTail(ByteBuffer buf)
358     {
359         if ((buf == null) || (buf.remaining() < TAIL_BYTES.length))
360         {
361             return false;
362         }
363         int limit = buf.limit();
364         for (int i = TAIL_BYTES.length; i > 0; i--)
365         {
366             if (buf.get(limit - i) != TAIL_BYTES[TAIL_BYTES.length - i])
367             {
368                 return false;
369             }
370         }
371         return true;
372     }
373     
374     @Override
375     protected void doStop() throws Exception
376     {
377         if(deflaterImpl != null)
378             deflaterImpl.end();
379         if(inflaterImpl != null)
380             inflaterImpl.end();
381         super.doStop();
382     }
383 
384     @Override
385     public String toString()
386     {
387         return getClass().getSimpleName();
388     }
389 
390     private static class FrameEntry
391     {
392         private final Frame frame;
393         private final WriteCallback callback;
394         private final BatchMode batchMode;
395 
396         private FrameEntry(Frame frame, WriteCallback callback, BatchMode batchMode)
397         {
398             this.frame = frame;
399             this.callback = callback;
400             this.batchMode = batchMode;
401         }
402 
403         @Override
404         public String toString()
405         {
406             return frame.toString();
407         }
408     }
409 
410     private class Flusher extends IteratingCallback implements WriteCallback
411     {
412         private FrameEntry current;
413         private boolean finished = true;
414         
415         @Override
416         public void failed(Throwable x)
417         {
418             LOG.warn(x);
419             super.failed(x);
420         }
421 
422         @Override
423         protected Action process() throws Exception
424         {
425             if (finished)
426             {
427                 current = pollEntry();
428                 LOG.debug("Processing {}",current);
429                 if (current == null)
430                     return Action.IDLE;
431                 deflate(current);
432             }
433             else
434             {
435                 compress(current,false);
436             }
437             return Action.SCHEDULED;
438         }
439 
440         private void deflate(FrameEntry entry)
441         {
442             Frame frame = entry.frame;
443             BatchMode batchMode = entry.batchMode;
444             if (OpCode.isControlFrame(frame.getOpCode()))
445             {
446                 // Do not deflate control frames
447                 nextOutgoingFrame(frame,this,batchMode);
448                 return;
449             }
450             
451             compress(entry,true);
452         }
453 
454         private void compress(FrameEntry entry, boolean first)
455         {
456             // Get a chunk of the payload to avoid to blow
457             // the heap if the payload is a huge mapped file.
458             Frame frame = entry.frame;
459             ByteBuffer data = frame.getPayload();
460             int remaining = data.remaining();
461             int outputLength = Math.max(256,data.remaining());
462             if (LOG.isDebugEnabled())
463                 LOG.debug("Compressing {}: {} bytes in {} bytes chunk",entry,remaining,outputLength);
464 
465             boolean needsCompress = true;
466             
467             Deflater deflater = getDeflater();
468 
469             if (deflater.needsInput() && !supplyInput(deflater,data))
470             {
471                 // no input supplied
472                 needsCompress = false;
473             }
474             
475             ByteArrayOutputStream out = new ByteArrayOutputStream();
476 
477             byte[] output = new byte[outputLength];
478 
479             boolean fin = frame.isFin();
480 
481             // Compress the data
482             while (needsCompress)
483             {
484                 int compressed = deflater.deflate(output,0,outputLength,Deflater.SYNC_FLUSH);
485 
486                 // Append the output for the eventual frame.
487                 if (LOG.isDebugEnabled())
488                     LOG.debug("Wrote {} bytes to output buffer",compressed);
489                 out.write(output,0,compressed);
490 
491                 if (compressed < outputLength)
492                 {
493                     needsCompress = false;
494                 }
495             }
496 
497             ByteBuffer payload = ByteBuffer.wrap(out.toByteArray());
498 
499             if (payload.remaining() > 0)
500             {
501                 // Handle tail bytes generated by SYNC_FLUSH.
502                 if (LOG.isDebugEnabled())
503                     LOG.debug("compressed bytes[] = {}",BufferUtil.toDetailString(payload));
504 
505                 if (tailDrop == TAIL_DROP_ALWAYS)
506                 {
507                     if (endsWithTail(payload))
508                     {
509                         payload.limit(payload.limit() - TAIL_BYTES.length);
510                     }
511                     if (LOG.isDebugEnabled())
512                         LOG.debug("payload (TAIL_DROP_ALWAYS) = {}",BufferUtil.toDetailString(payload));
513                 }
514                 else if (tailDrop == TAIL_DROP_FIN_ONLY)
515                 {
516                     if (frame.isFin() && endsWithTail(payload))
517                     {
518                         payload.limit(payload.limit() - TAIL_BYTES.length);
519                     }
520                     if (LOG.isDebugEnabled())
521                         LOG.debug("payload (TAIL_DROP_FIN_ONLY) = {}",BufferUtil.toDetailString(payload));
522                 }
523             }
524             else if (fin)
525             {
526                 // Special case: 7.2.3.6.  Generating an Empty Fragment Manually
527                 // https://tools.ietf.org/html/rfc7692#section-7.2.3.6
528                 payload = ByteBuffer.wrap(new byte[] { 0x00 });
529             }
530 
531             if (LOG.isDebugEnabled())
532             {
533                 LOG.debug("Compressed {}: input:{} -> payload:{}",entry,outputLength,payload.remaining());
534             }
535 
536             boolean continuation = frame.getType().isContinuation() || !first;
537             DataFrame chunk = new DataFrame(frame,continuation);
538             if (rsvUse == RSV_USE_ONLY_FIRST)
539             {
540                 chunk.setRsv1(!continuation);
541             }
542             else
543             {
544                 // always set
545                 chunk.setRsv1(true);
546             }
547             chunk.setPayload(payload);
548             chunk.setFin(fin);
549 
550             nextOutgoingFrame(chunk,this,entry.batchMode);
551         }
552 
553         @Override
554         protected void onCompleteSuccess()
555         {
556             // This IteratingCallback never completes.
557         }
558 
559         @Override
560         protected void onCompleteFailure(Throwable x)
561         {
562             // Fail all the frames in the queue.
563             FrameEntry entry;
564             while ((entry = pollEntry()) != null)
565                 notifyCallbackFailure(entry.callback,x);
566         }
567 
568         @Override
569         public void writeSuccess()
570         {
571             if (finished)
572                 notifyCallbackSuccess(current.callback);
573             succeeded();
574         }
575 
576         @Override
577         public void writeFailed(Throwable x)
578         {
579             notifyCallbackFailure(current.callback,x);
580             // If something went wrong, very likely the compression context
581             // will be invalid, so we need to fail this IteratingCallback.
582             failed(x);
583         }
584     }
585 }