1 /*
2  * @(#) $Id: JSTKSocketUtil.java,v 1.4 2003/07/08 08:13:53 pankaj Exp $
3  *
4  * Copyright (c) 2002-03 by Pankaj Kumar (http://www.pankaj-k.net). 
5  * All rights reserved.
6  *
7  * The license governing the use of this file can be found in the 
8  * root directory of the containing software.
9  */
10package org.jstk.ssl;
11
12import java.net.*;
13import java.io.*;
14import java.util.Vector;
15import javax.net.SocketFactory;
16import javax.net.ServerSocketFactory;
17import javax.net.ssl.SSLSocketFactory;
18import javax.net.ssl.SSLServerSocketFactory;
19import javax.net.ssl.SSLServerSocket;
20import javax.net.ssl.SSLSocket;
21import javax.net.ssl.SSLSession;
22import javax.net.ssl.SSLPeerUnverifiedException;
23import java.security.cert.Certificate;
24import java.security.cert.X509Certificate;
25import org.jstk.*;
26import java.nio.channels.SocketChannel;
27import java.nio.channels.ServerSocketChannel;
28
29public class JSTKSocketUtil {
30
31    public static JSTKServerSocket createServerSocket(JSTKArgs args) throws JSTKException {
32        try {
33            String inport = args.get("inport");
34            String inetAddrVal = args.get("inetaddr");
35            boolean verbose = Boolean.valueOf(args.get("verbose")).booleanValue();
36            String inproto = args.get("inproto");
37            boolean nio = Boolean.valueOf(args.get("nio")).booleanValue();
38
39            int lport = Integer.parseInt(inport);
40            JSTKServerSocket jss = null;
41
42            if (nio && !inproto.equalsIgnoreCase("SSL")){
43                InetSocketAddress isa = new InetSocketAddress(InetAddress.getLocalHost(), lport);
44                ServerSocketChannel ssc = ServerSocketChannel.open();
45                ssc.socket().bind(isa);
46                jss = JSTKServerSocket.getInstance(ssc);
47            } else {
48                ServerSocketFactory ssf = null;
49                ServerSocket serverSocket = null;
50                if (inproto.equalsIgnoreCase("SSL")){
51                    ssf = SSLServerSocketFactory.getDefault();
52                } else {
53                    ssf = ServerSocketFactory.getDefault();
54                }
55                if (inetAddrVal == null){
56                    serverSocket = ssf.createServerSocket(lport);
57                } else {
58                    InetAddress ia = InetAddress.getByName(inetAddrVal);
59                    serverSocket = ssf.createServerSocket(lport, 50, ia);
60                }
61                if (serverSocket instanceof SSLServerSocket){
62                    String[] csarray = getCSFileCipherSuites(args);
63                    if (csarray != null){
64                        ((SSLServerSocket)serverSocket).setEnabledCipherSuites(csarray);
65                    }
66                }
67                jss = JSTKServerSocket.getInstance(serverSocket);
68            }
69
70            return jss;
71        } catch (Exception exc){
72            throw new JSTKException("Could not create Server Scoket: " + exc, exc);
73        }
74    }
75
76    public static JSTKSocket connect(JSTKArgs args) throws JSTKException {
77        try {
78            String host = args.get("host");
79            int port = Integer.parseInt(args.get("port"));
80            String inetAddrVal = args.get("inetaddr");
81            boolean verbose = Boolean.valueOf(args.get("verbose")).booleanValue();
82            String outproto = args.get("outproto");
83            Socket socket = null;
84
85            if (getIOLibrary(args, outproto).equalsIgnoreCase("NIO")){
86                InetSocketAddress isa = new InetSocketAddress(InetAddress.getByName(host), port);
87                java.nio.channels.SocketChannel sc = java.nio.channels.SocketChannel.open();
88                sc.connect(isa);
89                sc.socket().setTcpNoDelay(true);
90                return JSTKSocket.getInstance(sc);
91            } else {
92                SocketFactory sf = null;
93                if (outproto.equalsIgnoreCase("SSL")){
94                    sf = SSLSocketFactory.getDefault();
95                } else {
96                    sf = SocketFactory.getDefault();
97                }
98                if (inetAddrVal == null){
99                    socket = sf.createSocket(host, port);
00                } else {
01                    InetAddress ia = InetAddress.getByName(inetAddrVal);
02                    socket = sf.createSocket(host, port, ia, port + 1);
03                }
04                socket.setTcpNoDelay(true);
05                if (socket instanceof SSLSocket){
06                    String[] csarray = getCSFileCipherSuites(args);
07                    if (csarray != null){
08                        ((SSLSocket)socket).setEnabledCipherSuites(csarray);
09                    }
10                }
11
12                return JSTKSocket.getInstance(socket);
13            }
14        } catch (Exception exc){
15            throw new JSTKException("Could not create Scoket: " + exc, exc);
16        }
17    }
18
19    public static void print(JSTKSocket jsocket, String dir) {
20        try {
21            Socket socket = jsocket.getSocket();
22
23            InetSocketAddress localAddr = (InetSocketAddress)socket.getLocalSocketAddress();
24            InetSocketAddress remoteAddr = (InetSocketAddress)socket.getRemoteSocketAddress();
25            String localAddrId = localAddr.getHostName() + ":" + localAddr.getPort();
26            String remoteAddrId = remoteAddr.getHostName() + ":" + remoteAddr.getPort();
27
28            System.out.println("  Connection   : " + localAddrId + dir + remoteAddrId);
29            if (socket instanceof SSLSocket){
30                SSLSession sess = ((SSLSocket)socket).getSession();
31                System.out.println("  Protocol     : " + sess.getProtocol());
32                System.out.println("  Cipher Suite : " + sess.getCipherSuite());
33                Certificate[] localCerts = sess.getLocalCertificates();
34                if (localCerts != null && localCerts.length > 0)
35                    printCertDNs(localCerts, "  Local Certs : ");
36
37                Certificate[] remoteCerts = null;
38                try {
39                    remoteCerts = sess.getPeerCertificates();
40                    printCertDNs(remoteCerts, "  Remote Certs: ");
41                } catch (SSLPeerUnverifiedException exc){
42                    System.out.println("  Remote Certs: Unverified");
43                }
44            } else {
45                System.out.println("  Protocol     : TCP");
46            }
47        } catch (Exception exc){
48            System.err.println("Could not print Socket Information: " + exc);
49        }
50    }
51
52    private static void printCertDNs(Certificate[] certs, String label){
53        System.out.println(label + "[0]" + ((X509Certificate)certs[0]).getSubjectDN());
54        StringBuffer indent = new StringBuffer();
55        for (int i = label.length(); i > 0; i--)
56            indent.append(" ");
57        for (int i = 1; i < certs.length; i++){
58            System.out.println(indent.toString() + "[" + i + "]" +
59                ((X509Certificate)certs[i]).getSubjectDN());
60        }
61    }
62
63    public static String getIOLibrary(JSTKArgs args, String proto){
64        boolean nio = Boolean.valueOf(args.get("nio")).booleanValue();
65        if (nio && !proto.equalsIgnoreCase("SSL"))
66            return "NIO";
67        else
68            return "CLASSIC";
69    }
70    public static String[] getCSFileCipherSuites(JSTKArgs args){
71        String csfile = args.get("csfile");
72        try {
73            if (csfile != null){
74                BufferedReader br =
75                    new BufferedReader(new InputStreamReader(new FileInputStream(csfile)));
76                Vector v = new Vector();
77                String s;
78                while ((s = br.readLine()) != null){
79                    s = s.trim();
80                    if (s.length() > 0)
81                        v.add(s);
82                }
83                String[] csarray = new String[v.size()];
84                for (int i = 0; i < v.size(); i++){
85                    csarray[i] = (String)v.elementAt(i);
86                }
87            return csarray;
88            }
89        } catch (IOException ioe){
90            System.err.println("Error reading csfile: " + csfile + ", Exception: " + ioe);
91        }
92        return null;
93    }
94}
95