001    /*
002     *  Licensed to the Apache Software Foundation (ASF) under one
003     *  or more contributor license agreements.  See the NOTICE file
004     *  distributed with this work for additional information
005     *  regarding copyright ownership.  The ASF licenses this file
006     *  to you under the Apache License, Version 2.0 (the
007     *  "License"); you may not use this file except in compliance
008     *  with the License.  You may obtain a copy of the License at
009     *  
010     *    http://www.apache.org/licenses/LICENSE-2.0
011     *  
012     *  Unless required by applicable law or agreed to in writing,
013     *  software distributed under the License is distributed on an
014     *  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015     *  KIND, either express or implied.  See the License for the
016     *  specific language governing permissions and limitations
017     *  under the License. 
018     *  
019     */
020    package org.apache.directory.server.kerberos.shared.crypto.encryption;
021    
022    
/**
 * An implementation of the n-fold algorithm, as required by RFC 3961,
 * "Encryption and Checksum Specifications for Kerberos 5."
 * 
 * "To n-fold a number X, replicate the input value to a length that
 * is the least common multiple of n and the length of X.  Before
 * each repetition, the input is rotated to the right by 13 bit
 * positions.  The successive n-bit chunks are added together using
 * 1's-complement addition (that is, with end-around carry) to yield
 * a n-bit result."
 * 
 * Bit positions used by the helpers in this class are big-endian: position 0
 * is the most-significant bit of byte 0, and position (8 * length - 1) is the
 * least-significant bit of the last byte.  The class holds no state.
 * 
 * @author <a href="mailto:dev@directory.apache.org">Apache Directory Project</a>
 * @version $Rev$, $Date$
 */
public class NFold
{
    /** Number of bit positions each successive repetition is rotated right by (RFC 3961). */
    private static final int ROTATION_BITS = 13;


    /**
     * N-fold the given data into an output of <code>n</code> bits.
     * 
     * @param n The length, in bits, of the folded output; must be a positive multiple of 8.
     * @param data The data to n-fold; must not be empty.
     * @return The n-folded data, <code>n / 8</code> bytes long.
     * @throws IllegalArgumentException if <code>n</code> is not a positive multiple of 8,
     *         or if <code>data</code> is empty.
     */
    public static byte[] nFold( int n, byte[] data )
    {
        if ( n <= 0 || n % 8 != 0 )
        {
            throw new IllegalArgumentException( "n must be a positive multiple of 8, got " + n );
        }

        if ( data.length == 0 )
        {
            // Without this check getLcm would fail with a division by zero.
            throw new IllegalArgumentException( "Cannot n-fold empty data" );
        }

        int k = data.length * 8;
        int lcm = getLcm( n, k );
        int replicate = lcm / k;
        byte[] sumBytes = new byte[lcm / 8];

        // Concatenate 'replicate' copies of the input; copy i is rotated right by 13 * i bits.
        for ( int i = 0; i < replicate; i++ )
        {
            byte[] rotated = rotateRight( data, k, ROTATION_BITS * i );
            System.arraycopy( rotated, 0, sumBytes, i * data.length, data.length );
        }

        int nBytes = n / 8;
        byte[] chunk = new byte[nBytes];
        byte[] nfold = new byte[nBytes];

        // Add the successive n-bit chunks together with end-around-carry addition.
        for ( int m = 0; m < lcm / n; m++ )
        {
            System.arraycopy( sumBytes, m * nBytes, chunk, 0, nBytes );
            nfold = sum( nfold, chunk, n );
        }

        return nfold;
    }


    /**
     * For 2 numbers, return the least-common multiple.
     *
     * The greatest common divisor is found first with Euclid's algorithm.  The
     * division is done before the multiplication so that intermediate values
     * cannot overflow when the result itself fits in an int.
     *
     * @param n1 The first number.
     * @param n2 The second number.
     * @return The least-common multiple.
     * @throws ArithmeticException if both numbers are zero, or if the result does not fit in an int.
     */
    protected static int getLcm( int n1, int n2 )
    {
        int a = n1;
        int b = n2;

        while ( b != 0 )
        {
            int remainder = a % b;
            a = b;
            b = remainder;
        }

        // 'a' now holds gcd(n1, n2).
        long lcm = ( long ) ( n1 / a ) * n2;

        if ( lcm != ( int ) lcm )
        {
            throw new ArithmeticException( "Least-common multiple overflows an int" );
        }

        return ( int ) lcm;
    }


    /**
     * Right-rotate the first <code>len</code> bits of the given byte array.
     *
     * @param in The byte array to right-rotate.
     * @param len The number of bits (positions 0 to len - 1) to rotate.
     * @param step The number of bit positions to rotate by; expected to be non-negative.
     * @return A new byte array holding the right-rotated bits.
     */
    private static byte[] rotateRight( byte[] in, int len, int step )
    {
        int numOfBytes = ( len - 1 ) / 8 + 1;
        byte[] out = new byte[numOfBytes];

        for ( int i = 0; i < len; i++ )
        {
            setBit( out, ( i + step ) % len, getBit( in, i ) );
        }

        return out;
    }


    /**
     * Perform one's complement addition (addition with end-around carry).  Note
     * that for purposes of n-folding, we do not actually complement the
     * result of the addition.
     * 
     * @param n1 The first number; must hold at least <code>len</code> bits.
     * @param n2 The second number; must hold at least <code>len</code> bits.
     * @param len The number of bits (positions 0 to len - 1) to add.
     * @return A new array of <code>(len - 1) / 8 + 1</code> bytes holding the sum with end-around carry.
     */
    protected static byte[] sum( byte[] n1, byte[] n2, int len )
    {
        int numOfBytes = ( len - 1 ) / 8 + 1;
        byte[] out = new byte[numOfBytes];
        int carry = 0;

        // Position len - 1 is the least-significant bit, so walk from there towards position 0.
        for ( int i = len - 1; i >= 0; i-- )
        {
            int bitSum = getBit( n1, i ) + getBit( n2, i ) + carry;
            setBit( out, i, bitSum & 1 );
            carry = bitSum >> 1;
        }

        if ( carry == 1 )
        {
            // End-around carry: add it back in at the least-significant bit of the len-bit value.
            // This cannot carry out again, because (a + b - 2^len) + 1 < 2^len.
            byte[] carryArray = new byte[numOfBytes];
            setBit( carryArray, len - 1, 1 );
            out = sum( out, carryArray, len );
        }

        return out;
    }


    /**
     * Get a bit from a byte array at a given position.
     *
     * @param data The data to get the bit from.
     * @param pos The position of the bit; position 0 is the most-significant bit of byte 0.
     * @return The value of the bit, 0 or 1.
     */
    private static int getBit( byte[] data, int pos )
    {
        int posByte = pos / 8;
        int posBit = pos % 8;

        return ( data[posByte] >> ( 7 - posBit ) ) & 0x01;
    }


    /**
     * Set a bit in a byte array at a given position.
     *
     * @param data The data to set the bit in.
     * @param pos The position of the bit; position 0 is the most-significant bit of byte 0.
     * @param val The value to set the bit to: 0 clears the bit, any other value sets it.
     */
    private static void setBit( byte[] data, int pos, int val )
    {
        int posByte = pos / 8;
        int mask = 0x80 >>> ( pos % 8 );

        if ( val != 0 )
        {
            data[posByte] |= mask;
        }
        else
        {
            data[posByte] &= ~mask;
        }
    }
}