001    /**
002     * Copyright (C) 2012 FuseSource, Inc.
003     * http://fusesource.com
004     *
005     * Licensed under the Apache License, Version 2.0 (the "License");
006     * you may not use this file except in compliance with the License.
007     * You may obtain a copy of the License at
008     *
009     *    http://www.apache.org/licenses/LICENSE-2.0
010     *
011     * Unless required by applicable law or agreed to in writing, software
012     * distributed under the License is distributed on an "AS IS" BASIS,
013     * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014     * See the License for the specific language governing permissions and
015     * limitations under the License.
016     */
017    
018    package org.fusesource.hawtdispatch.transport;
019    
020    import org.fusesource.hawtdispatch.*;
021    
022    import java.io.IOException;
023    import java.net.*;
024    import java.nio.ByteBuffer;
025    import java.nio.channels.ReadableByteChannel;
026    import java.nio.channels.SelectionKey;
027    import java.nio.channels.SocketChannel;
028    import java.nio.channels.WritableByteChannel;
029    import java.util.LinkedList;
030    import java.util.concurrent.Executor;
031    import java.util.concurrent.TimeUnit;
032    
033    /**
034     * An implementation of the {@link org.fusesource.hawtdispatch.transport.Transport} interface using raw tcp/ip
035     *
036     * @author <a href="http://hiramchirino.com">Hiram Chirino</a>
037     */
038    public class TcpTransport extends ServiceBase implements Transport {
039    
040        static InetAddress localhost;
041        synchronized static public InetAddress getLocalHost() throws UnknownHostException {
042            // cache it...
043            if( localhost==null ) {
044                // this can be slow on some systems and we use repeatedly.
045                localhost = InetAddress.getLocalHost();
046            }
047            return localhost;
048        }
049    
050        abstract static class SocketState {
051            void onStop(Task onCompleted) {
052            }
053            void onCanceled() {
054            }
055            boolean is(Class<? extends SocketState> clazz) {
056                return getClass()==clazz;
057            }
058        }
059    
060        static class DISCONNECTED extends SocketState{}
061    
062        class CONNECTING extends SocketState{
063            void onStop(Task onCompleted) {
064                trace("CONNECTING.onStop");
065                CANCELING state = new CANCELING();
066                socketState = state;
067                state.onStop(onCompleted);
068            }
069            void onCanceled() {
070                trace("CONNECTING.onCanceled");
071                CANCELING state = new CANCELING();
072                socketState = state;
073                state.onCanceled();
074            }
075        }
076    
077        class CONNECTED extends SocketState {
078    
079            public CONNECTED() {
080                localAddress = channel.socket().getLocalSocketAddress();
081                remoteAddress = channel.socket().getRemoteSocketAddress();
082            }
083    
084            void onStop(Task onCompleted) {
085                trace("CONNECTED.onStop");
086                CANCELING state = new CANCELING();
087                socketState = state;
088                state.add(createDisconnectTask());
089                state.onStop(onCompleted);
090            }
091            void onCanceled() {
092                trace("CONNECTED.onCanceled");
093                CANCELING state = new CANCELING();
094                socketState = state;
095                state.add(createDisconnectTask());
096                state.onCanceled();
097            }
098            Task createDisconnectTask() {
099                return new Task(){
100                    public void run() {
101                        listener.onTransportDisconnected();
102                    }
103                };
104            }
105        }
106    
107        class CANCELING extends SocketState {
108            private LinkedList<Task> runnables =  new LinkedList<Task>();
109            private int remaining;
110            private boolean dispose;
111    
112            public CANCELING() {
113                if( readSource!=null ) {
114                    remaining++;
115                    readSource.cancel();
116                }
117                if( writeSource!=null ) {
118                    remaining++;
119                    writeSource.cancel();
120                }
121            }
122            void onStop(Task onCompleted) {
123                trace("CANCELING.onCompleted");
124                add(onCompleted);
125                dispose = true;
126            }
127            void add(Task onCompleted) {
128                if( onCompleted!=null ) {
129                    runnables.add(onCompleted);
130                }
131            }
132            void onCanceled() {
133                trace("CANCELING.onCanceled");
134                remaining--;
135                if( remaining!=0 ) {
136                    return;
137                }
138                try {
139                    if( closeOnCancel ) {
140                        channel.close();
141                    }
142                } catch (IOException ignore) {
143                }
144                socketState = new CANCELED(dispose);
145                for (Task runnable : runnables) {
146                    runnable.run();
147                }
148                if (dispose) {
149                    dispose();
150                }
151            }
152        }
153    
154        class CANCELED extends SocketState {
155            private boolean disposed;
156    
157            public CANCELED(boolean disposed) {
158                this.disposed=disposed;
159            }
160    
161            void onStop(Task onCompleted) {
162                trace("CANCELED.onStop");
163                if( !disposed ) {
164                    disposed = true;
165                    dispose();
166                }
167                onCompleted.run();
168            }
169        }
170    
171        protected URI remoteLocation;
172        protected URI localLocation;
173        protected TransportListener listener;
174        protected ProtocolCodec codec;
175    
176        protected SocketChannel channel;
177    
178        protected SocketState socketState = new DISCONNECTED();
179    
180        protected DispatchQueue dispatchQueue;
181        private DispatchSource readSource;
182        private DispatchSource writeSource;
183        protected CustomDispatchSource<Integer, Integer> drainOutboundSource;
184        protected CustomDispatchSource<Integer, Integer> yieldSource;
185    
186        protected boolean useLocalHost = true;
187    
188        int maxReadRate;
189        int maxWriteRate;
190        int receiveBufferSize = 1024*64;
191        int sendBufferSize = 1024*64;
192        boolean closeOnCancel = true;
193    
194        boolean keepAlive = true;
195    
196        public static final int IPTOS_LOWCOST = 0x02;
197        public static final int IPTOS_RELIABILITY = 0x04;
198        public static final int IPTOS_THROUGHPUT = 0x08;
199        public static final int IPTOS_LOWDELAY = 0x10;
200    
201        int trafficClass = IPTOS_THROUGHPUT;
202    
203        protected RateLimitingChannel rateLimitingChannel;
204        SocketAddress localAddress;
205        SocketAddress remoteAddress;
206        protected Executor blockingExecutor;
207    
208        class RateLimitingChannel implements ReadableByteChannel, WritableByteChannel {
209    
210            int read_allowance = maxReadRate;
211            boolean read_suspended = false;
212            int read_resume_counter = 0;
213            int write_allowance = maxWriteRate;
214            boolean write_suspended = false;
215    
216            public void resetAllowance() {
217                if( read_allowance != maxReadRate || write_allowance != maxWriteRate) {
218                    read_allowance = maxReadRate;
219                    write_allowance = maxWriteRate;
220                    if( write_suspended ) {
221                        write_suspended = false;
222                        resumeWrite();
223                    }
224                    if( read_suspended ) {
225                        read_suspended = false;
226                        resumeRead();
227                        for( int i=0; i < read_resume_counter ; i++ ) {
228                            resumeRead();
229                        }
230                    }
231                }
232            }
233    
234            public int read(ByteBuffer dst) throws IOException {
235                if( maxReadRate ==0 ) {
236                    return channel.read(dst);
237                } else {
238                    int remaining = dst.remaining();
239                    if( read_allowance ==0 || remaining ==0 ) {
240                        return 0;
241                    }
242    
243                    int reduction = 0;
244                    if( remaining > read_allowance) {
245                        reduction = remaining - read_allowance;
246                        dst.limit(dst.limit() - reduction);
247                    }
248                    int rc=0;
249                    try {
250                        rc = channel.read(dst);
251                        read_allowance -= rc;
252                    } finally {
253                        if( reduction!=0 ) {
254                            if( dst.remaining() == 0 ) {
255                                // we need to suspend the read now until we get
256                                // a new allowance..
257                                readSource.suspend();
258                                read_suspended = true;
259                            }
260                            dst.limit(dst.limit() + reduction);
261                        }
262                    }
263                    return rc;
264                }
265            }
266    
267            public int write(ByteBuffer src) throws IOException {
268                if( maxWriteRate ==0 ) {
269                    return channel.write(src);
270                } else {
271                    int remaining = src.remaining();
272                    if( write_allowance ==0 || remaining ==0 ) {
273                        return 0;
274                    }
275    
276                    int reduction = 0;
277                    if( remaining > write_allowance) {
278                        reduction = remaining - write_allowance;
279                        src.limit(src.limit() - reduction);
280                    }
281                    int rc = 0;
282                    try {
283                        rc = channel.write(src);
284                        write_allowance -= rc;
285                    } finally {
286                        if( reduction!=0 ) {
287                            if( src.remaining() == 0 ) {
288                                // we need to suspend the read now until we get
289                                // a new allowance..
290                                write_suspended = true;
291                                suspendWrite();
292                            }
293                            src.limit(src.limit() + reduction);
294                        }
295                    }
296                    return rc;
297                }
298            }
299    
300            public boolean isOpen() {
301                return channel.isOpen();
302            }
303    
304            public void close() throws IOException {
305                channel.close();
306            }
307    
308            public void resumeRead() {
309                if( read_suspended ) {
310                    read_resume_counter += 1;
311                } else {
312                    _resumeRead();
313                }
314            }
315    
316        }
317    
318        private final Task CANCEL_HANDLER = new Task() {
319            public void run() {
320                socketState.onCanceled();
321            }
322        };
323    
324        static final class OneWay {
325            final Object command;
326            final Retained retained;
327    
328            public OneWay(Object command, Retained retained) {
329                this.command = command;
330                this.retained = retained;
331            }
332        }
333    
334        public void connected(SocketChannel channel) throws IOException, Exception {
335            this.channel = channel;
336            initializeChannel();
337            this.socketState = new CONNECTED();
338        }
339    
340        protected void initializeChannel() throws Exception {
341            this.channel.configureBlocking(false);
342            Socket socket = channel.socket();
343            try {
344                socket.setReuseAddress(true);
345            } catch (SocketException e) {
346            }
347            try {
348                socket.setSoLinger(true, 0);
349            } catch (SocketException e) {
350            }
351            try {
352                socket.setTrafficClass(trafficClass);
353            } catch (SocketException e) {
354            }
355            try {
356                socket.setKeepAlive(keepAlive);
357            } catch (SocketException e) {
358            }
359            try {
360                socket.setTcpNoDelay(true);
361            } catch (SocketException e) {
362            }
363            try {
364                socket.setReceiveBufferSize(receiveBufferSize);
365            } catch (SocketException e) {
366            }
367            try {
368                socket.setSendBufferSize(sendBufferSize);
369            } catch (SocketException e) {
370            }
371    
372            if( channel!=null && codec!=null ) {
373                initializeCodec();
374            }
375        }
376    
377        protected void initializeCodec() throws Exception {
378            codec.setTransport(this);
379        }
380    
381        public void connecting(final URI remoteLocation, final URI localLocation) throws Exception {
382            this.channel = SocketChannel.open();
383            initializeChannel();
384            this.remoteLocation = remoteLocation;
385            this.localLocation = localLocation;
386            socketState = new CONNECTING();
387        }
388    
389    
390        public DispatchQueue getDispatchQueue() {
391            return dispatchQueue;
392        }
393    
394        public void setDispatchQueue(DispatchQueue queue) {
395            this.dispatchQueue = queue;
396            if(readSource!=null) readSource.setTargetQueue(queue);
397            if(writeSource!=null) writeSource.setTargetQueue(queue);
398            if(drainOutboundSource!=null) drainOutboundSource.setTargetQueue(queue);
399            if(yieldSource!=null) yieldSource.setTargetQueue(queue);
400        }
401    
402        public void _start(Task onCompleted) {
403            try {
404                if (socketState.is(CONNECTING.class)) {
405    
406                    // Resolving host names might block.. so do it on the blocking executor.
407                    this.blockingExecutor.execute(new Runnable() {
408                        public void run() {
409                            try {
410    
411                                final InetSocketAddress localAddress = (localLocation != null) ?
412                                        new InetSocketAddress(InetAddress.getByName(localLocation.getHost()), localLocation.getPort())
413                                        : null;
414    
415                                String host = resolveHostName(remoteLocation.getHost());
416                                final InetSocketAddress remoteAddress = new InetSocketAddress(host, remoteLocation.getPort());
417    
418                                // Done resolving.. switch back to the dispatch queue.
419                                dispatchQueue.execute(new Task() {
420                                    @Override
421                                    public void run() {
422                                        // No need to complete if we have been canceled.
423                                        if( ! socketState.is(CONNECTING.class) ) {
424                                            return;
425                                        }
426                                        try {
427    
428                                            if (localAddress != null) {
429                                                channel.socket().bind(localAddress);
430                                            }
431                                            trace("connecting...");
432                                            channel.connect(remoteAddress);
433    
434                                            // this allows the connect to complete..
435                                            readSource = Dispatch.createSource(channel, SelectionKey.OP_CONNECT, dispatchQueue);
436                                            readSource.setEventHandler(new Task() {
437                                                public void run() {
438                                                    if (getServiceState() != STARTED) {
439                                                        return;
440                                                    }
441                                                    try {
442                                                        trace("connected.");
443                                                        channel.finishConnect();
444                                                        readSource.setCancelHandler(null);
445                                                        readSource.cancel();
446                                                        readSource = null;
447                                                        socketState = new CONNECTED();
448                                                        onConnected();
449                                                    } catch (IOException e) {
450                                                        onTransportFailure(e);
451                                                    }
452                                                }
453                                            });
454                                            readSource.setCancelHandler(CANCEL_HANDLER);
455                                            readSource.resume();
456    
457                                        } catch (IOException e) {
458                                            try {
459                                                channel.close();
460                                            } catch (IOException ignore) {
461                                            }
462                                            socketState = new CANCELED(true);
463                                            listener.onTransportFailure(e);
464                                        }
465                                    }
466                                });
467    
468                            } catch (IOException e) {
469                                try {
470                                    channel.close();
471                                } catch (IOException ignore) {
472                                }
473                                socketState = new CANCELED(true);
474                                listener.onTransportFailure(e);
475                            }
476                        }
477                    });
478                } else if (socketState.is(CONNECTED.class)) {
479                    dispatchQueue.execute(new Task() {
480                        public void run() {
481                            try {
482                                trace("was connected.");
483                                onConnected();
484                            } catch (IOException e) {
485                                onTransportFailure(e);
486                            }
487                        }
488                    });
489                } else {
490                    System.err.println("cannot be started.  socket state is: " + socketState);
491                }
492            } finally {
493                if (onCompleted != null) {
494                    onCompleted.run();
495                }
496            }
497        }
498    
499        public void _stop(final Task onCompleted) {
500            trace("stopping.. at state: "+socketState);
501            socketState.onStop(onCompleted);
502        }
503    
504        protected String resolveHostName(String host) throws UnknownHostException {
505            String localName = getLocalHost().getHostName();
506            if (localName != null && isUseLocalHost()) {
507                if (localName.equals(host)) {
508                    return "localhost";
509                }
510            }
511            return host;
512        }
513    
514        protected void onConnected() throws IOException {
515            yieldSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
516            yieldSource.setEventHandler(new Task() {
517                public void run() {
518                    drainInbound();
519                }
520            });
521            yieldSource.resume();
522            drainOutboundSource = Dispatch.createSource(EventAggregators.INTEGER_ADD, dispatchQueue);
523            drainOutboundSource.setEventHandler(new Task() {
524                public void run() {
525                    flush();
526                }
527            });
528            drainOutboundSource.resume();
529    
530            readSource = Dispatch.createSource(channel, SelectionKey.OP_READ, dispatchQueue);
531            writeSource = Dispatch.createSource(channel, SelectionKey.OP_WRITE, dispatchQueue);
532    
533            readSource.setCancelHandler(CANCEL_HANDLER);
534            writeSource.setCancelHandler(CANCEL_HANDLER);
535    
536            readSource.setEventHandler(new Task() {
537                public void run() {
538                    drainInbound();
539                }
540            });
541            writeSource.setEventHandler(new Task() {
542                public void run() {
543                    flush();
544                }
545            });
546    
547            if( maxReadRate !=0 || maxWriteRate !=0 ) {
548                rateLimitingChannel = new RateLimitingChannel();
549                schedualRateAllowanceReset();
550            }
551            listener.onTransportConnected();
552        }
553    
554        private void schedualRateAllowanceReset() {
555            dispatchQueue.executeAfter(1, TimeUnit.SECONDS, new Task(){
556                public void run() {
557                    if( !socketState.is(CONNECTED.class) ) {
558                        return;
559                    }
560                    rateLimitingChannel.resetAllowance();
561                    schedualRateAllowanceReset();
562                }
563            });
564        }
565    
566        private void dispose() {
567            if( readSource!=null ) {
568                readSource.cancel();
569                readSource=null;
570            }
571    
572            if( writeSource!=null ) {
573                writeSource.cancel();
574                writeSource=null;
575            }
576            this.codec = null;
577        }
578    
579        public void onTransportFailure(IOException error) {
580            listener.onTransportFailure(error);
581            socketState.onCanceled();
582        }
583    
584    
585        public boolean full() {
586            return codec==null ||
587                   codec.full() ||
588                   !socketState.is(CONNECTED.class) ||
589                   getServiceState() != STARTED;
590        }
591    
592        boolean rejectingOffers;
593    
594        public boolean offer(Object command) {
595            dispatchQueue.assertExecuting();
596            if( full() ) {
597                return false;
598            }
599            try {
600                ProtocolCodec.BufferState rc = codec.write(command);
601                rejectingOffers = codec.full();
602                switch (rc ) {
603                    case FULL:
604                        return false;
605                    default:
606                        drainOutboundSource.merge(1);
607                }
608            } catch (IOException e) {
609                onTransportFailure(e);
610            }
611            return true;
612        }
613    
614        boolean writeResumedForCodecFlush = false;
615    
616        /**
617         *
618         */
619        public void flush() {
620            dispatchQueue.assertExecuting();
621            if (getServiceState() != STARTED || !socketState.is(CONNECTED.class)) {
622                return;
623            }
624            try {
625                if( codec.flush() == ProtocolCodec.BufferState.EMPTY && transportFlush() ) {
626                    if( writeResumedForCodecFlush) {
627                        writeResumedForCodecFlush = false;
628                        suspendWrite();
629                    }
630                    rejectingOffers = false;
631                    listener.onRefill();
632    
633                } else {
634                    if(!writeResumedForCodecFlush) {
635                        writeResumedForCodecFlush = true;
636                        resumeWrite();
637                    }
638                }
639            } catch (IOException e) {
640                onTransportFailure(e);
641            }
642        }
643    
644        protected boolean transportFlush() throws IOException {
645            return true;
646        }
647    
648        public void drainInbound() {
649            if (!getServiceState().isStarted() || readSource.isSuspended()) {
650                return;
651            }
652            try {
653                long initial = codec.getReadCounter();
654                // Only process upto 2 x the read buffer worth of data at a time so we can give
655                // other connections a chance to process their requests.
656                while( codec.getReadCounter()-initial < codec.getReadBufferSize()<<2 ) {
657                    Object command = codec.read();
658                    if ( command!=null ) {
659                        try {
660                            listener.onTransportCommand(command);
661                        } catch (Throwable e) {
662                            e.printStackTrace();
663                            onTransportFailure(new IOException("Transport listener failure."));
664                        }
665    
666                        // the transport may be suspended after processing a command.
667                        if (getServiceState() == STOPPED || readSource.isSuspended()) {
668                            return;
669                        }
670                    } else {
671                        return;
672                    }
673                }
674                yieldSource.merge(1);
675            } catch (IOException e) {
676                onTransportFailure(e);
677            }
678        }
679    
680        public SocketAddress getLocalAddress() {
681            return localAddress;
682        }
683    
684        public SocketAddress getRemoteAddress() {
685            return remoteAddress;
686        }
687    
688        private boolean assertConnected() {
689            try {
690                if ( !isConnected() ) {
691                    throw new IOException("Not connected.");
692                }
693                return true;
694            } catch (IOException e) {
695                onTransportFailure(e);
696            }
697            return false;
698        }
699    
700        public void suspendRead() {
701            if( isConnected() && readSource!=null ) {
702                readSource.suspend();
703            }
704        }
705    
706    
707        public void resumeRead() {
708            if( isConnected() && readSource!=null ) {
709                if( rateLimitingChannel!=null ) {
710                    rateLimitingChannel.resumeRead();
711                } else {
712                    _resumeRead();
713                }
714            }
715        }
716    
717        private void _resumeRead() {
718            readSource.resume();
719            dispatchQueue.execute(new Task(){
720                public void run() {
721                    drainInbound();
722                }
723            });
724        }
725    
726        protected void suspendWrite() {
727            if( isConnected() && writeSource!=null ) {
728                writeSource.suspend();
729            }
730        }
731    
732        protected void resumeWrite() {
733            if( isConnected() && writeSource!=null ) {
734                writeSource.resume();
735            }
736        }
737    
738        public TransportListener getTransportListener() {
739            return listener;
740        }
741    
742        public void setTransportListener(TransportListener transportListener) {
743            this.listener = transportListener;
744        }
745    
746        public ProtocolCodec getProtocolCodec() {
747            return codec;
748        }
749    
750        public void setProtocolCodec(ProtocolCodec protocolCodec) throws Exception {
751            this.codec = protocolCodec;
752            if( channel!=null && codec!=null ) {
753                initializeCodec();
754            }
755        }
756    
757        public boolean isConnected() {
758            return socketState.is(CONNECTED.class);
759        }
760    
761        public boolean isClosed() {
762            return getServiceState() == STOPPED;
763        }
764    
765        public boolean isUseLocalHost() {
766            return useLocalHost;
767        }
768    
769        /**
770         * Sets whether 'localhost' or the actual local host name should be used to
771         * make local connections. On some operating systems such as Macs its not
772         * possible to connect as the local host name so localhost is better.
773         */
774        public void setUseLocalHost(boolean useLocalHost) {
775            this.useLocalHost = useLocalHost;
776        }
777    
778        private void trace(String message) {
779            // TODO:
780        }
781    
782        public SocketChannel getSocketChannel() {
783            return channel;
784        }
785    
786        public ReadableByteChannel getReadChannel() {
787            if(rateLimitingChannel!=null) {
788                return rateLimitingChannel;
789            } else {
790                return channel;
791            }
792        }
793    
794        public WritableByteChannel getWriteChannel() {
795            if(rateLimitingChannel!=null) {
796                return rateLimitingChannel;
797            } else {
798                return channel;
799            }
800        }
801    
802        public int getMaxReadRate() {
803            return maxReadRate;
804        }
805    
806        public void setMaxReadRate(int maxReadRate) {
807            this.maxReadRate = maxReadRate;
808        }
809    
810        public int getMaxWriteRate() {
811            return maxWriteRate;
812        }
813    
814        public void setMaxWriteRate(int maxWriteRate) {
815            this.maxWriteRate = maxWriteRate;
816        }
817    
818        public int getTrafficClass() {
819            return trafficClass;
820        }
821    
822        public void setTrafficClass(int trafficClass) {
823            this.trafficClass = trafficClass;
824        }
825    
826        public int getReceiveBufferSize() {
827            return receiveBufferSize;
828        }
829    
830        public void setReceiveBufferSize(int receiveBufferSize) {
831            this.receiveBufferSize = receiveBufferSize;
832            if( channel!=null ) {
833                try {
834                    channel.socket().setReceiveBufferSize(receiveBufferSize);
835                } catch (SocketException ignore) {
836                }
837            }
838        }
839    
840        public int getSendBufferSize() {
841            return sendBufferSize;
842        }
843    
844        public void setSendBufferSize(int sendBufferSize) {
845            this.sendBufferSize = sendBufferSize;
846            if( channel!=null ) {
847                try {
848                    channel.socket().setReceiveBufferSize(sendBufferSize);
849                } catch (SocketException ignore) {
850                }
851            }
852        }
853    
854        public boolean isKeepAlive() {
855            return keepAlive;
856        }
857    
858        public void setKeepAlive(boolean keepAlive) {
859            this.keepAlive = keepAlive;
860        }
861    
862        public Executor getBlockingExecutor() {
863            return blockingExecutor;
864        }
865    
866        public void setBlockingExecutor(Executor blockingExecutor) {
867            this.blockingExecutor = blockingExecutor;
868        }
869    
870        public boolean isCloseOnCancel() {
871            return closeOnCancel;
872        }
873    
874        public void setCloseOnCancel(boolean closeOnCancel) {
875            this.closeOnCancel = closeOnCancel;
876        }
877    }