001 /** 002 * Copyright (C) 2012 FuseSource, Inc. 003 * http://fusesource.com 004 * 005 * Licensed under the Apache License, Version 2.0 (the "License"); 006 * you may not use this file except in compliance with the License. 007 * You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018 package org.fusesource.hawtdispatch.transport; 019 020 import org.fusesource.hawtdispatch.Task; 021 022 import javax.net.ssl.*; 023 import java.io.EOFException; 024 import java.io.IOException; 025 import java.net.Socket; 026 import java.net.URI; 027 import java.nio.ByteBuffer; 028 import java.nio.channels.*; 029 import java.security.cert.Certificate; 030 import java.security.cert.X509Certificate; 031 import java.util.ArrayList; 032 033 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP; 034 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; 035 import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; 036 import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW; 037 038 /** 039 * An SSL Transport for secure communications. 040 * 041 * @author <a href="http://hiramchirino.com">Hiram Chirino</a> 042 */ 043 public class SslTransport extends TcpTransport implements SecuredSession { 044 045 046 /** 047 * Maps uri schemes to a protocol algorithm names. 048 * Valid algorithm names listed at: 049 * http://download.oracle.com/javase/6/docs/technotes/guides/security/StandardNames.html#SSLContext 050 */ 051 public static String protocol(String scheme) { 052 if( scheme.equals("tls") ) { 053 return "TLS"; 054 } else if( scheme.startsWith("tlsv") ) { 055 return "TLSv"+scheme.substring(4); 056 } else if( scheme.equals("ssl") ) { 057 return "SSL"; 058 } else if( scheme.startsWith("sslv") ) { 059 return "SSLv"+scheme.substring(4); 060 } 061 return null; 062 } 063 064 enum ClientAuth { 065 WANT, NEED, NONE 066 }; 067 068 private ClientAuth clientAuth = ClientAuth.WANT; 069 070 private SSLContext sslContext; 071 private SSLEngine engine; 072 073 private ByteBuffer readBuffer; 074 private boolean readUnderflow; 075 076 private ByteBuffer writeBuffer; 077 private boolean writeFlushing; 078 079 private ByteBuffer readOverflowBuffer; 080 private SSLChannel ssl_channel = new SSLChannel(); 081 082 083 public void setSSLContext(SSLContext ctx) { 084 this.sslContext = ctx; 085 } 086 087 /** 088 * Allows subclasses of TcpTransportFactory to create custom instances of 089 * TcpTransport. 090 */ 091 public static SslTransport createTransport(URI uri) throws Exception { 092 String protocol = protocol(uri.getScheme()); 093 if( protocol !=null ) { 094 SslTransport rc = new SslTransport(); 095 rc.setSSLContext(SSLContext.getInstance(protocol)); 096 return rc; 097 } 098 return null; 099 } 100 101 public class SSLChannel implements ScatteringByteChannel, GatheringByteChannel { 102 103 public int write(ByteBuffer plain) throws IOException { 104 return secure_write(plain); 105 } 106 107 public int read(ByteBuffer plain) throws IOException { 108 return secure_read(plain); 109 } 110 111 public boolean isOpen() { 112 return getSocketChannel().isOpen(); 113 } 114 115 public void close() throws IOException { 116 getSocketChannel().close(); 117 } 118 119 public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { 120 if(offset+length > srcs.length || length<0 || offset<0) { 121 throw new IndexOutOfBoundsException(); 122 } 123 long rc=0; 124 for (int i = 0; i < length; i++) { 125 ByteBuffer src = srcs[offset+i]; 126 if(src.hasRemaining()) { 127 rc += write(src); 128 } 129 if( src.hasRemaining() ) { 130 return rc; 131 } 132 } 133 return rc; 134 } 135 136 public long write(ByteBuffer[] srcs) throws IOException { 137 return write(srcs, 0, srcs.length); 138 } 139 140 public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { 141 if(offset+length > dsts.length || length<0 || offset<0) { 142 throw new IndexOutOfBoundsException(); 143 } 144 long rc=0; 145 for (int i = 0; i < length; i++) { 146 ByteBuffer dst = dsts[offset+i]; 147 if(dst.hasRemaining()) { 148 rc += read(dst); 149 } 150 if( dst.hasRemaining() ) { 151 return rc; 152 } 153 } 154 return rc; 155 } 156 157 public long read(ByteBuffer[] dsts) throws IOException { 158 return read(dsts, 0, dsts.length); 159 } 160 161 public Socket socket() { 162 SocketChannel c = channel; 163 if( c == null ) { 164 return null; 165 } 166 return c.socket(); 167 } 168 } 169 170 public SSLSession getSSLSession() { 171 return engine==null ? null : engine.getSession(); 172 } 173 174 public X509Certificate[] getPeerX509Certificates() { 175 if( engine==null ) { 176 return null; 177 } 178 try { 179 ArrayList<X509Certificate> rc = new ArrayList<X509Certificate>(); 180 for( Certificate c:engine.getSession().getPeerCertificates() ) { 181 if(c instanceof X509Certificate) { 182 rc.add((X509Certificate) c); 183 } 184 } 185 return rc.toArray(new X509Certificate[rc.size()]); 186 } catch (SSLPeerUnverifiedException e) { 187 return null; 188 } 189 } 190 191 @Override 192 public void connecting(URI remoteLocation, URI localLocation) throws Exception { 193 assert engine == null; 194 engine = sslContext.createSSLEngine(); 195 engine.setUseClientMode(true); 196 super.connecting(remoteLocation, localLocation); 197 } 198 199 @Override 200 public void connected(SocketChannel channel) throws Exception { 201 if (engine == null) { 202 engine = sslContext.createSSLEngine(); 203 engine.setUseClientMode(false); 204 switch (clientAuth) { 205 case WANT: engine.setWantClientAuth(true); break; 206 case NEED: engine.setNeedClientAuth(true); break; 207 case NONE: engine.setWantClientAuth(false); break; 208 } 209 210 } 211 super.connected(channel); 212 } 213 214 @Override 215 protected void initializeChannel() throws Exception { 216 super.initializeChannel(); 217 SSLSession session = engine.getSession(); 218 readBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 219 readBuffer.flip(); 220 writeBuffer = ByteBuffer.allocateDirect(session.getPacketBufferSize()); 221 } 222 223 @Override 224 protected void onConnected() throws IOException { 225 super.onConnected(); 226 engine.beginHandshake(); 227 handshake(); 228 } 229 230 @Override 231 public void flush() { 232 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 233 handshake(); 234 } else { 235 super.flush(); 236 } 237 } 238 239 @Override 240 public void drainInbound() { 241 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 242 handshake(); 243 } else { 244 super.drainInbound(); 245 } 246 } 247 248 /** 249 * @return true if fully flushed. 250 * @throws IOException 251 */ 252 protected boolean transportFlush() throws IOException { 253 while (true) { 254 if(writeFlushing) { 255 int count = super.getWriteChannel().write(writeBuffer); 256 if( !writeBuffer.hasRemaining() ) { 257 writeBuffer.clear(); 258 writeFlushing = false; 259 suspendWrite(); 260 return true; 261 } else { 262 return false; 263 } 264 } else { 265 if( writeBuffer.position()!=0 ) { 266 writeBuffer.flip(); 267 writeFlushing = true; 268 resumeWrite(); 269 } else { 270 return true; 271 } 272 } 273 } 274 } 275 276 private int secure_write(ByteBuffer plain) throws IOException { 277 if( !transportFlush() ) { 278 // can't write anymore until the write_secured_buffer gets fully flushed out.. 279 return 0; 280 } 281 int rc = 0; 282 while ( plain.hasRemaining() ^ engine.getHandshakeStatus()==NEED_WRAP ) { 283 SSLEngineResult result = engine.wrap(plain, writeBuffer); 284 assert result.getStatus()!= BUFFER_OVERFLOW; 285 rc += result.bytesConsumed(); 286 if( !transportFlush() ) { 287 break; 288 } 289 } 290 if( plain.remaining()==0 && engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 291 dispatchQueue.execute(new Task() { 292 public void run() { 293 handshake(); 294 } 295 }); 296 } 297 return rc; 298 } 299 300 private int secure_read(ByteBuffer plain) throws IOException { 301 int rc=0; 302 while ( plain.hasRemaining() ^ engine.getHandshakeStatus() == NEED_UNWRAP ) { 303 if( readOverflowBuffer !=null ) { 304 if( plain.hasRemaining() ) { 305 // lets drain the overflow buffer before trying to suck down anymore 306 // network bytes. 307 int size = Math.min(plain.remaining(), readOverflowBuffer.remaining()); 308 plain.put(readOverflowBuffer.array(), readOverflowBuffer.position(), size); 309 readOverflowBuffer.position(readOverflowBuffer.position()+size); 310 if( !readOverflowBuffer.hasRemaining() ) { 311 readOverflowBuffer = null; 312 } 313 rc += size; 314 } else { 315 return rc; 316 } 317 } else if( readUnderflow ) { 318 int count = super.getReadChannel().read(readBuffer); 319 if( count == -1 ) { // peer closed socket. 320 if (rc==0) { 321 return -1; 322 } else { 323 return rc; 324 } 325 } 326 if( count==0 ) { // no data available right now. 327 return rc; 328 } 329 // read in some more data, perhaps now we can unwrap. 330 readUnderflow = false; 331 readBuffer.flip(); 332 } else { 333 SSLEngineResult result = engine.unwrap(readBuffer, plain); 334 rc += result.bytesProduced(); 335 if( result.getStatus() == BUFFER_OVERFLOW ) { 336 readOverflowBuffer = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()); 337 result = engine.unwrap(readBuffer, readOverflowBuffer); 338 if( readOverflowBuffer.position()==0 ) { 339 readOverflowBuffer = null; 340 } else { 341 readOverflowBuffer.flip(); 342 } 343 } 344 switch( result.getStatus() ) { 345 case CLOSED: 346 if (rc==0) { 347 engine.closeInbound(); 348 return -1; 349 } else { 350 return rc; 351 } 352 case OK: 353 if ( engine.getHandshakeStatus()!=NOT_HANDSHAKING ) { 354 dispatchQueue.execute(new Task() { 355 public void run() { 356 handshake(); 357 } 358 }); 359 } 360 break; 361 case BUFFER_UNDERFLOW: 362 readBuffer.compact(); 363 readUnderflow = true; 364 break; 365 case BUFFER_OVERFLOW: 366 throw new AssertionError("Unexpected case."); 367 } 368 } 369 } 370 return rc; 371 } 372 373 public void handshake() { 374 try { 375 if( !transportFlush() ) { 376 return; 377 } 378 switch (engine.getHandshakeStatus()) { 379 case NEED_TASK: 380 final Runnable task = engine.getDelegatedTask(); 381 if( task!=null ) { 382 blockingExecutor.execute(new Task() { 383 public void run() { 384 task.run(); 385 dispatchQueue.execute(new Task() { 386 public void run() { 387 if (isConnected()) { 388 handshake(); 389 } 390 } 391 }); 392 } 393 }); 394 } 395 break; 396 397 case NEED_WRAP: 398 secure_write(ByteBuffer.allocate(0)); 399 break; 400 401 case NEED_UNWRAP: 402 if( secure_read(ByteBuffer.allocate(0)) == -1) { 403 throw new EOFException("Peer disconnected during ssl handshake"); 404 } 405 break; 406 407 case FINISHED: 408 case NOT_HANDSHAKING: 409 drainOutboundSource.merge(1); 410 drainInbound(); 411 break; 412 413 default: 414 System.err.println("Unexpected ssl engine handshake status: "+ engine.getHandshakeStatus()); 415 break; 416 } 417 } catch (IOException e ) { 418 onTransportFailure(e); 419 } 420 } 421 422 423 public ReadableByteChannel getReadChannel() { 424 return ssl_channel; 425 } 426 427 public WritableByteChannel getWriteChannel() { 428 return ssl_channel; 429 } 430 431 public String getClientAuth() { 432 return clientAuth.name(); 433 } 434 435 public void setClientAuth(String clientAuth) { 436 this.clientAuth = ClientAuth.valueOf(clientAuth.toUpperCase()); 437 } 438 } 439 440