001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import org.fusesource.hawtdispatch.Task;
021    
022    import javax.net.ssl.*;
023    import java.io.EOFException;
024    import java.io.IOException;
025    import java.nio.ByteBuffer;
026    import java.nio.channels.GatheringByteChannel;
027    import java.nio.channels.ReadableByteChannel;
028    import java.nio.channels.ScatteringByteChannel;
029    import java.nio.channels.WritableByteChannel;
030    import java.security.cert.Certificate;
031    import java.security.cert.X509Certificate;
032    import java.util.ArrayList;
033    
034    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.*;
035    import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
036    
037    /**
038     * Implements the SSL protocol as a WrappingProtocolCodec.  Useful for when
039     * you want to switch to the SSL protocol on a regular TCP Transport.
040     */
041    public class SslProtocolCodec implements WrappingProtocolCodec, SecuredSession {
042    
043        private ReadableByteChannel readChannel;
044        private WritableByteChannel writeChannel;
045    
046        public enum ClientAuth {
047            WANT, NEED, NONE
048        };
049    
050        private SSLContext sslContext;
051        private SSLEngine engine;
052    
053        private ByteBuffer readBuffer;
054        private boolean readUnderflow;
055    
056        private ByteBuffer writeBuffer;
057        private boolean writeFlushing;
058    
059        private ByteBuffer readOverflowBuffer;
060        Transport transport;
061    
062        int lastReadSize;
063        int lastWriteSize;
064        long readCounter;
065        long writeCounter;
066    
067        ProtocolCodec next;
068    
069    
070        public SslProtocolCodec() {
071        }
072    
073        public ProtocolCodec getNext() {
074            return next;
075        }
076        public void setNext(ProtocolCodec next) {
077            this.next = next;
078            initNext();
079        }
080    
081        private void initNext() {
082            if( next!=null ) {
083                this.next.setTransport(new TransportFilter(transport){
084                    public ReadableByteChannel getReadChannel() {
085                        return sslReadChannel;
086                    }
087                    public WritableByteChannel getWriteChannel() {
088                        return sslWriteChannel;
089                    }
090                });
091            }
092        }
093    
094        public void setSSLContext(SSLContext ctx) {
095            assert engine == null;
096            this.sslContext = ctx;
097        }
098    
099        public SslProtocolCodec client() throws Exception {
100            initializeEngine();
101            engine.setUseClientMode(true);
102            engine.beginHandshake();
103            return this;
104        }
105    
106        public SslProtocolCodec server(ClientAuth clientAuth) throws Exception {
107            initializeEngine();
108            engine.setUseClientMode(false);
109            switch (clientAuth) {
110                case WANT: engine.setWantClientAuth(true); break;
111                case NEED: engine.setNeedClientAuth(true); break;
112                case NONE: engine.setWantClientAuth(false); break;
113            }
114            engine.beginHandshake();
115            return this;
116        }
117    
118        protected void initializeEngine() throws Exception {
119            assert engine == null;
120            if( sslContext == null ) {
121                sslContext = SSLContext.getDefault();
122            }
123            engine = sslContext.createSSLEngine();
124            SSLSession session = engine.getSession();
125            readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
126            readBuffer.flip();
127            writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
128        }
129    
130    
131        public SSLSession getSSLSession() {
132            return engine==null ? null : engine.getSession();
133        }
134    
135        public X509Certificate[] getPeerX509Certificates() {
136            if( engine==null ) {
137                return null;
138            }
139            try {
140                ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
141                for( Certificate c:engine.getSession().getPeerCertificates() ) {
142                    if(c instanceof X509Certificate) {
143                        rc.add((X509Certificate) c);
144                    }
145                }
146                return rc.toArray(new X509Certificate[rc.size()]);
147            } catch (SSLPeerUnverifiedException e) {
148                return null;
149            }
150        }
151    
152        SSLReadChannel sslReadChannel = new SSLReadChannel();
153        SSLWriteChannel sslWriteChannel = new SSLWriteChannel();
154    
155        public void setTransport(Transport transport) {
156            this.transport = transport;
157            this.readChannel = transport.getReadChannel();
158            this.writeChannel = transport.getWriteChannel();
159            initNext();
160        }
161    
162        public void handshake() throws IOException {
163            if( !transportFlush() ) {
164                return;
165            }
166            switch (engine.getHandshakeStatus()) {
167                case NEED_TASK:
168                    final Runnable task = engine.getDelegatedTask();
169                    if( task!=null ) {
170                        transport.getBlockingExecutor().execute(new Task() {
171                            public void run() {
172                                task.run();
173                                transport.getDispatchQueue().execute(new Task() {
174                                    public void run() {
175                                        if (readChannel.isOpen() && writeChannel.isOpen()) {
176                                            try {
177                                                handshake();
178                                            } catch (IOException e) {
179                                                transport.getTransportListener().onTransportFailure(e);
180                                            }
181                                        }
182                                    }
183                                });
184                            }
185                        });
186                    }
187                    break;
188    
189                case NEED_WRAP:
190                    secure_write(ByteBuffer.allocate(0));
191                    break;
192    
193                case NEED_UNWRAP:
194                    if( secure_read(ByteBuffer.allocate(0)) == -1) {
195                        throw new EOFException("Peer disconnected during ssl handshake");
196                    }
197                    break;
198    
199                case FINISHED:
200                case NOT_HANDSHAKING:
201                    transport.drainInbound();
202                    transport.getTransportListener().onRefill();
203                    break;
204    
205                default:
206                    System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
207                    break;
208            }
209        }
210    
211        /**
212         * @return true if fully flushed.
213         * @throws IOException
214         */
215        protected boolean transportFlush() throws IOException {
216            while (true) {
217                if(writeFlushing) {
218                    lastWriteSize = writeChannel.write(writeBuffer);
219                    if( lastWriteSize > 0 ) {
220                        writeCounter += lastWriteSize;
221                    }
222                    if( !writeBuffer.hasRemaining() ) {
223                        writeBuffer.clear();
224                        writeFlushing = false;
225                        return true;
226                    } else {
227                        return false;
228                    }
229                } else {
230                    if( writeBuffer.position()!=0 ) {
231                        writeBuffer.flip();
232                        writeFlushing = true;
233                    } else {
234                        return true;
235                    }
236                }
237            }
238        }
239    
240        private int secure_read(ByteBuffer plain) throws IOException {
241            int rc=0;
242            while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
243                if( readOverflowBuffer !=null ) {
244                    if(  plain.hasRemaining() ) {
245                        // lets drain the overflow buffer before trying to suck down anymore
246                        // network bytes.
247                        int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
248                        plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
249                        readOverflowBuffer.position(readOverflowBuffer.position()+size);
250                        if( !readOverflowBuffer.hasRemaining() ) {
251                            readOverflowBuffer = null;
252                        }
253                        rc += size;
254                    } else {
255                        return rc;
256                    }
257                } else if( readUnderflow ) {
258                    lastReadSize = readChannel.read(readBuffer);
259                    if( lastReadSize == -1 ) {  // peer closed socket.
260                        if (rc==0) {
261                            return -1;
262                        } else {
263                            return rc;
264                        }
265                    }
266                    if( lastReadSize==0 ) {  // no data available right now.
267                        return rc;
268                    }
269                    readCounter += lastReadSize;
270                    // read in some more data, perhaps now we can unwrap.
271                    readUnderflow = false;
272                    readBuffer.flip();
273                } else {
274                    SSLEngineResult result = engine.unwrap(readBuffer, plain);
275                    rc += result.bytesProduced();
276                    if( result.getStatus() == BUFFER_OVERFLOW ) {
277                        readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
278                        result = engine.unwrap(readBuffer, readOverflowBuffer);
279                        if( readOverflowBuffer.position()==0 ) {
280                            readOverflowBuffer = null;
281                        } else {
282                            readOverflowBuffer.flip();
283                        }
284                    }
285                    switch( result.getStatus() ) {
286                        case CLOSED:
287                            if (rc==0) {
288                                engine.closeInbound();
289                                return -1;
290                            } else {
291                                return rc;
292                            }
293                        case OK:
294                            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
295                                handshake();
296                            }
297                            break;
298                        case BUFFER_UNDERFLOW:
299                            readBuffer.compact();
300                            readUnderflow = true;
301                            break;
302                        case BUFFER_OVERFLOW:
303                            throw new AssertionError("Unexpected case.");
304                    }
305                }
306            }
307            return rc;
308        }
309    
310        private int secure_write(ByteBuffer plain) throws IOException {
311            if( !transportFlush() ) {
312                // can't write anymore until the write_secured_buffer gets fully flushed out..
313                return 0;
314            }
315            int rc = 0;
316            while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
317                SSLEngineResult result = engine.wrap(plain, writeBuffer);
318                assert result.getStatus()!= BUFFER_OVERFLOW;
319                rc += result.bytesConsumed();
320                if( !transportFlush() ) {
321                    break;
322                }
323            }
324            if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
325                handshake();
326            }
327            return rc;
328        }
329    
330        public class SSLReadChannel implements ScatteringByteChannel {
331    
332            public int read(ByteBuffer plain) throws IOException {
333                if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
334                    handshake();
335                }
336                return secure_read(plain);
337            }
338    
339            public boolean isOpen() {
340                return readChannel.isOpen();
341            }
342    
343            public void close() throws IOException {
344                readChannel.close();
345            }
346    
347            public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
348                if(offset+length > dsts.length || length<0 || offset<0) {
349                    throw new IndexOutOfBoundsException();
350                }
351                long rc=0;
352                for (int i = 0; i < length; i++) {
353                    ByteBuffer dst = dsts[offset+i];
354                    if(dst.hasRemaining()) {
355                        rc += read(dst);
356                    }
357                    if( dst.hasRemaining() ) {
358                        return rc;
359                    }
360                }
361                return rc;
362            }
363    
364            public long read(ByteBuffer[] dsts) throws IOException {
365                return read(dsts, 0, dsts.length);
366            }
367        }
368    
369        public class SSLWriteChannel implements GatheringByteChannel {
370    
371            public int write(ByteBuffer plain) throws IOException {
372                if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
373                    handshake();
374                }
375                return secure_write(plain);
376            }
377    
378            public boolean isOpen() {
379                return writeChannel.isOpen();
380            }
381    
382            public void close() throws IOException {
383                writeChannel.close();
384            }
385    
386            public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
387                if(offset+length > srcs.length || length<0 || offset<0) {
388                    throw new IndexOutOfBoundsException();
389                }
390                long rc=0;
391                for (int i = 0; i < length; i++) {
392                    ByteBuffer src = srcs[offset+i];
393                    if(src.hasRemaining()) {
394                        rc += write(src);
395                    }
396                    if( src.hasRemaining() ) {
397                        return rc;
398                    }
399                }
400                return rc;
401            }
402    
403            public long write(ByteBuffer[] srcs) throws IOException {
404                return write(srcs, 0, srcs.length);
405            }
406        }
407    
408        public void unread(byte[] buffer) {
409            readBuffer.compact();
410            if( readBuffer.remaining() < buffer.length) {
411                throw new IllegalStateException("Cannot unread now");
412            }
413            readBuffer.put(buffer);
414            readBuffer.flip();
415        }
416    
417        public Object read() throws IOException {
418            return next.read();
419        }
420    
421        public ProtocolCodec.BufferState write(Object value) throws IOException {
422            return next.write(value);
423        }
424    
425        public ProtocolCodec.BufferState flush() throws IOException {
426            return next.flush();
427        }
428    
429        public boolean full() {
430            return next.full();
431        }
432    
433        public long getWriteCounter() {
434            return writeCounter;
435        }
436    
437        public long getLastWriteSize() {
438            return lastWriteSize;
439        }
440    
441        public long getReadCounter() {
442            return readCounter;
443        }
444    
445        public long getLastReadSize() {
446            return lastReadSize;
447        }
448    
449        public int getReadBufferSize() {
450            return readBuffer.capacity();
451        }
452    
453        public int getWriteBufferSize() {
454            return writeBuffer.capacity();
455        }
456    
457    
458    
459    }