001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import org.fusesource.hawtdispatch.Task;
021    
022    import javax.net.ssl.*;
023    import java.io.EOFException;
024    import java.io.IOException;
025    import java.net.Socket;
026    import java.net.URI;
027    import java.nio.ByteBuffer;
028    import java.nio.channels.*;
029    import java.security.cert.Certificate;
030    import java.security.cert.X509Certificate;
031    import java.util.ArrayList;
032    
033    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP;
034    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
035    import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
036    import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW;
037    
038    /**
039     * An SSL Transport for secure communications.
040     *
041     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
042     */
043    public class SslTransport extends TcpTransport implements SecuredSession {
044    
045    
046        /**
047         * Maps uri schemes to a protocol algorithm names.
048         * Valid algorithm names listed at:
049         * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext
050         */
051        public static String protocol(String scheme) {
052            if( scheme.equals("tls") ) {
053                return "TLS";
054            } else if( scheme.startsWith("tlsv") ) {
055                return "TLSv"+scheme.substring(4);
056            } else if( scheme.equals("ssl") ) {
057                return "SSL";
058            } else if( scheme.startsWith("sslv") ) {
059                return "SSLv"+scheme.substring(4);
060            }
061            return null;
062        }
063    
064        enum ClientAuth {
065            WANT, NEED, NONE
066        };
067    
068        private ClientAuth clientAuth = ClientAuth.WANT;
069    
070        private SSLContext sslContext;
071        private SSLEngine engine;
072    
073        private ByteBuffer readBuffer;
074        private boolean readUnderflow;
075    
076        private ByteBuffer writeBuffer;
077        private boolean writeFlushing;
078    
079        private ByteBuffer readOverflowBuffer;
080        private SSLChannel ssl_channel = new SSLChannel();
081    
082    
083        public void setSSLContext(SSLContext ctx) {
084            this.sslContext = ctx;
085        }
086    
087        /**
088         * Allows subclasses of TcpTransportFactory to create custom instances of
089         * TcpTransport.
090         */
091        public static SslTransport createTransport(URI uri) throws Exception {
092            String protocol = protocol(uri.getScheme());
093            if( protocol !=null ) {
094                SslTransport rc = new SslTransport();
095                rc.setSSLContext(SSLContext.getInstance(protocol));
096                return rc;
097            }
098            return null;
099        }
100    
101        public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel {
102    
103            public int write(ByteBuffer plain) throws IOException {
104                return secure_write(plain);
105            }
106    
107            public int read(ByteBuffer plain) throws IOException {
108                return secure_read(plain);
109            }
110    
111            public boolean isOpen() {
112                return getSocketChannel().isOpen();
113            }
114    
115            public void close() throws IOException {
116                getSocketChannel().close();
117            }
118    
119            public long write(ByteBuffer[] srcs, int offset, int length) throws IOException {
120                if(offset+length > srcs.length || length<0 || offset<0) {
121                    throw new IndexOutOfBoundsException();
122                }
123                long rc=0;
124                for (int i = 0; i < length; i++) {
125                    ByteBuffer src = srcs[offset+i];
126                    if(src.hasRemaining()) {
127                        rc += write(src);
128                    }
129                    if( src.hasRemaining() ) {
130                        return rc;
131                    }
132                }
133                return rc;
134            }
135    
136            public long write(ByteBuffer[] srcs) throws IOException {
137                return write(srcs, 0, srcs.length);
138            }
139    
140            public long read(ByteBuffer[] dsts, int offset, int length) throws IOException {
141                if(offset+length > dsts.length || length<0 || offset<0) {
142                    throw new IndexOutOfBoundsException();
143                }
144                long rc=0;
145                for (int i = 0; i < length; i++) {
146                    ByteBuffer dst = dsts[offset+i];
147                    if(dst.hasRemaining()) {
148                        rc += read(dst);
149                    }
150                    if( dst.hasRemaining() ) {
151                        return rc;
152                    }
153                }
154                return rc;
155            }
156    
157            public long read(ByteBuffer[] dsts) throws IOException {
158                return read(dsts, 0, dsts.length);
159            }
160            
161            public Socket socket() {
162                SocketChannel c = channel;
163                if( c == null ) {
164                    return null;
165                }
166                return c.socket();
167            }
168        }
169    
170        public SSLSession getSSLSession() {
171            return engine==null ? null : engine.getSession();
172        }
173    
174        public X509Certificate[] getPeerX509Certificates() {
175            if( engine==null ) {
176                return null;
177            }
178            try {
179                ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>();
180                for( Certificate c:engine.getSession().getPeerCertificates() ) {
181                    if(c instanceof X509Certificate) {
182                        rc.add((X509Certificate) c);
183                    }
184                }
185                return rc.toArray(new X509Certificate[rc.size()]);
186            } catch (SSLPeerUnverifiedException e) {
187                return null;
188            }
189        }
190    
191        @Override
192        public void connecting(URI remoteLocation, URI localLocation) throws Exception {
193            assert engine == null;
194            engine = sslContext.createSSLEngine();
195            engine.setUseClientMode(true);
196            super.connecting(remoteLocation, localLocation);
197        }
198    
199        @Override
200        public void connected(SocketChannel channel) throws Exception {
201            if (engine == null) {
202                engine = sslContext.createSSLEngine();
203                engine.setUseClientMode(false);
204                switch (clientAuth) {
205                    case WANT: engine.setWantClientAuth(true); break;
206                    case NEED: engine.setNeedClientAuth(true); break;
207                    case NONE: engine.setWantClientAuth(false); break;
208                }
209    
210            }
211            super.connected(channel);
212        }
213    
214        @Override
215        protected void initializeChannel() throws Exception {
216            super.initializeChannel();
217            SSLSession session = engine.getSession();
218            readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
219            readBuffer.flip();
220            writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize());
221        }
222    
223        @Override
224        protected void onConnected() throws IOException {
225            super.onConnected();
226            engine.beginHandshake();
227            handshake();
228        }
229    
230        @Override
231        public void flush() {
232            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
233                handshake();
234            } else {
235                super.flush();
236            }
237        }
238    
239        @Override
240        public void drainInbound() {
241            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
242                handshake();
243            } else {
244                super.drainInbound();
245            }
246        }
247    
248        /**
249         * @return true if fully flushed.
250         * @throws IOException
251         */
252        protected boolean transportFlush() throws IOException {
253            while (true) {
254                if(writeFlushing) {
255                    int count = super.getWriteChannel().write(writeBuffer);
256                    if( !writeBuffer.hasRemaining() ) {
257                        writeBuffer.clear();
258                        writeFlushing = false;
259                        suspendWrite();
260                        return true;
261                    } else {
262                        return false;
263                    }
264                } else {
265                    if( writeBuffer.position()!=0 ) {
266                        writeBuffer.flip();
267                        writeFlushing = true;
268                        resumeWrite();
269                    } else {
270                        return true;
271                    }
272                }
273            }
274        }
275    
276        private int secure_write(ByteBuffer plain) throws IOException {
277            if( !transportFlush() ) {
278                // can't write anymore until the write_secured_buffer gets fully flushed out..
279                return 0;
280            }
281            int rc = 0;
282            while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) {
283                SSLEngineResult result = engine.wrap(plain, writeBuffer);
284                assert result.getStatus()!= BUFFER_OVERFLOW;
285                rc += result.bytesConsumed();
286                if( !transportFlush() ) {
287                    break;
288                }
289            }
290            if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
291                dispatchQueue.execute(new Task() {
292                    public void run() {
293                        handshake();
294                    }
295                });
296            }
297            return rc;
298        }
299    
300        private int secure_read(ByteBuffer plain) throws IOException {
301            int rc=0;
302            while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) {
303                if( readOverflowBuffer !=null ) {
304                    if(  plain.hasRemaining() ) {
305                        // lets drain the overflow buffer before trying to suck down anymore
306                        // network bytes.
307                        int size = Math.min(plain.remaining(), readOverflowBuffer.remaining());
308                        plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size);
309                        readOverflowBuffer.position(readOverflowBuffer.position()+size);
310                        if( !readOverflowBuffer.hasRemaining() ) {
311                            readOverflowBuffer = null;
312                        }
313                        rc += size;
314                    } else {
315                        return rc;
316                    }
317                } else if( readUnderflow ) {
318                    int count = super.getReadChannel().read(readBuffer);
319                    if( count == -1 ) {  // peer closed socket.
320                        if (rc==0) {
321                            return -1;
322                        } else {
323                            return rc;
324                        }
325                    }
326                    if( count==0 ) {  // no data available right now.
327                        return rc;
328                    }
329                    // read in some more data, perhaps now we can unwrap.
330                    readUnderflow = false;
331                    readBuffer.flip();
332                } else {
333                    SSLEngineResult result = engine.unwrap(readBuffer, plain);
334                    rc += result.bytesProduced();
335                    if( result.getStatus() == BUFFER_OVERFLOW ) {
336                        readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
337                        result = engine.unwrap(readBuffer, readOverflowBuffer);
338                        if( readOverflowBuffer.position()==0 ) {
339                            readOverflowBuffer = null;
340                        } else {
341                            readOverflowBuffer.flip();
342                        }
343                    }
344                    switch( result.getStatus() ) {
345                        case CLOSED:
346                            if (rc==0) {
347                                engine.closeInbound();
348                                return -1;
349                            } else {
350                                return rc;
351                            }
352                        case OK:
353                            if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) {
354                                dispatchQueue.execute(new Task() {
355                                    public void run() {
356                                        handshake();
357                                    }
358                                });
359                            }
360                            break;
361                        case BUFFER_UNDERFLOW:
362                            readBuffer.compact();
363                            readUnderflow = true;
364                            break;
365                        case BUFFER_OVERFLOW:
366                            throw new AssertionError("Unexpected case.");
367                    }
368                }
369            }
370            return rc;
371        }
372    
373        public void handshake() {
374            try {
375                if( !transportFlush() ) {
376                    return;
377                }
378                switch (engine.getHandshakeStatus()) {
379                    case NEED_TASK:
380                        final Runnable task = engine.getDelegatedTask();
381                        if( task!=null ) {
382                            blockingExecutor.execute(new Task() {
383                                public void run() {
384                                    task.run();
385                                    dispatchQueue.execute(new Task() {
386                                        public void run() {
387                                            if (isConnected()) {
388                                                handshake();
389                                            }
390                                        }
391                                    });
392                                }
393                            });
394                        }
395                        break;
396    
397                    case NEED_WRAP:
398                        secure_write(ByteBuffer.allocate(0));
399                        break;
400    
401                    case NEED_UNWRAP:
402                        if( secure_read(ByteBuffer.allocate(0)) == -1) {
403                            throw new EOFException("Peer disconnected during ssl handshake");
404                        }
405                        break;
406    
407                    case FINISHED:
408                    case NOT_HANDSHAKING:
409                        drainOutboundSource.merge(1);
410                        drainInbound();
411                        break;
412    
413                    default:
414                        System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus());
415                        break;
416                }
417            } catch (IOException e ) {
418                onTransportFailure(e);
419            }
420        }
421    
422    
423        public ReadableByteChannel getReadChannel() {
424            return ssl_channel;
425        }
426    
427        public WritableByteChannel getWriteChannel() {
428            return ssl_channel;
429        }
430    
431        public String getClientAuth() {
432            return clientAuth.name();
433        }
434    
435        public void setClientAuth(String clientAuth) {
436            this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase());
437        }
438    }
439    
440