1 module socks5d.client;
2 
3 import std.socket;
4 import socks5d.packets;
5 import std.experimental.logger;
6 
/**
 * Serves one accepted SOCKS5 connection: method negotiation, optional
 * credential check, the CONNECT request, and then a bidirectional relay
 * between the client and the target.
 */
class Client
{
    protected:
        uint         id;               /// identifier used as the `[%d]` log prefix
        Socket       socket;           /// accepted client connection
        TcpSocket    targetSocket;     /// outbound connection, set by handshake()
        string       authString;       /// expected credentials, compared verbatim
        AuthMethod[] availableMethods = [ AuthMethod.NOAUTH ];

    public:
        /**
         * Params:
         *   clientSocket = the accepted client connection
         *   id           = identifier used to prefix this client's log lines
         */
        this(Socket clientSocket, uint id)
        {
            socket = clientSocket;
            this.id = id;
        }

        /**
         * Requires credentials from the client.
         *
         * Stores `authString` and replaces the offered method list with
         * `AuthMethod.AUTH`. Strings of length 0 or 1 are ignored, leaving
         * the current configuration unchanged.
         */
        void setAuthString(string authString)
        {
            if (authString.length > 1) {
                this.authString = authString;
                availableMethods = [ AuthMethod.AUTH ];
            }
        }

        /**
         * Runs the whole session for this client.
         *
         * Never lets SocksException or SocketException escape. Both are
         * logged, and every socket this client owns is closed on return.
         */
        final void run()
        {
            // Release the client socket and, once handshake() has opened one,
            // the target socket on every exit path, including exceptions.
            // targetToClientSession() closes them too; std.socket's close()
            // invalidates the handle, so the repeat close is harmless.
            scope (exit) {
                socket.close();
                if (targetSocket !is null) {
                    targetSocket.close();
                }
            }

            try {
                // Inside the try: remoteAddress() throws if the peer is already gone.
                warningf("[%d] New client accepted: %s", id, socket.remoteAddress().toString());

                if (!authenticate()) {
                    warningf("[%d] Client failed to authenticate.", id);
                    return;
                }
                infof("[%d] Client successfully authenticated.", id);

                if (handshake()) {
                    targetToClientSession(socket, targetSocket);
                }
            } catch (SocksException e) {
                errorf("[%d] Error: %s", id, e.msg);
            } catch (SocketException e) {
                // Network failures (connect, resolve, select, ...) end this
                // session only; they must not escape and kill the caller.
                errorf("[%d] Socket error: %s", id, e.msg);
            }
        }

    protected:
        /**
         * Negotiates the authentication method and, if AuthMethod.AUTH was
         * selected, checks the client's credentials against `authString`.
         *
         * Returns: true if the client may proceed to the request phase.
         */
        bool authenticate()
        {
            auto identificationPacket = new MethodIdentificationPacket;
            identificationPacket.receive(socket);
            tracef("[%d] -> %s", id, identificationPacket.printFields);

            auto selectionPacket = new MethodSelectionPacket;
            selectionPacket.method = identificationPacket.detectAuthMethod(availableMethods);

            tracef("[%d] <- %s", id, selectionPacket.printFields);
            selectionPacket.send(socket);

            if (selectionPacket.method == AuthMethod.NOTAVAILABLE) {
                return false;
            }

            if (selectionPacket.method != AuthMethod.AUTH) {
                // No credential exchange required for the selected method.
                return true;
            }

            auto authPacket = new AuthPacket;
            authPacket.receive(socket);
            tracef("[%d] -> %s", id, authPacket.printFields);

            const clientAuthString = authPacket.getAuthString();
            // NOTE(review): this writes the client's credentials to the trace
            // log; consider dropping it for production deployments.
            tracef("[%d] Client auth with credentials: %s", id, clientAuthString);

            const bool accepted = clientAuthString == authString;

            auto authStatus = new AuthStatusPacket;
            if (accepted) {
                authStatus.status = 0x00;
            } else {
                authStatus.status = 0x01;
            }
            // Trace the reply on both outcomes, not only on success.
            tracef("[%d] <- %s", id, authStatus.printFields);
            authStatus.send(socket);

            return accepted;
        }

        /**
         * Reads the client's request, connects to the requested destination
         * and sends the reply.
         *
         * On a malformed request or a failed connection, a failure reply is
         * sent and false is returned. On success `targetSocket` is set and
         * true is returned.
         */
        bool handshake()
        {
            auto requestPacket = new RequestPacket;
            auto responsePacket = new ResponsePacket;

            try {
                requestPacket.receive(socket);
            } catch (RequestException e) {
                errorf("[%d] Error: %s", id, e.msg);
                responsePacket.rep = e.replyCode;
                tracef("[%d] <- %s", id, responsePacket.printFields);
                responsePacket.send(socket);

                return false;
            }

            tracef("[%d] -> %s", id, requestPacket.printFields);

            try {
                targetSocket = connectToTarget(requestPacket.getDestinationAddress());
            } catch (SocketException e) {
                // Tell the client the request failed instead of dropping the
                // connection silently. 0x01 is RFC 1928's "general SOCKS
                // server failure" reply code.
                errorf("[%d] Failed to connect to target: %s", id, e.msg);
                responsePacket.rep = cast(typeof(responsePacket.rep)) 0x01;
                tracef("[%d] <- %s", id, responsePacket.printFields);
                responsePacket.send(socket);

                return false;
            }

            auto localAddress = targetSocket.localAddress;

            responsePacket.atyp = AddressType.IPV4;
            // connectToTarget() takes an InternetAddress, so the local end is
            // expected to be IPv4 as well.
            responsePacket.setBindAddress(cast(InternetAddress) localAddress);

            tracef("[%d] Local target address: %s", id, localAddress.toString());
            tracef("[%d] <- %s", id, responsePacket.printFields);
            responsePacket.send(socket);

            return true;
        }

        /**
         * Opens a TCP connection to `address`.
         *
         * Throws: SocketException if the connection cannot be established.
         * The partially created socket is closed before the exception
         * propagates.
         */
        TcpSocket connectToTarget(InternetAddress address)
        out (targetSock) {
            assert(targetSock.isAlive);
        } body {
            auto targetSock = new TcpSocket;
            scope (failure) targetSock.close();

            tracef("[%d] Connecting to target %s", id, address.toString());
            targetSock.connect(address);

            return targetSock;
        }

        /**
         * Relays data in both directions until either side closes or errors.
         *
         * Both sockets are closed when the relay ends, however it ends.
         */
        void targetToClientSession(Socket clientSocket, Socket targetSocket)
        {
            scope (exit) {
                clientSocket.close();
                targetSocket.close();
            }

            auto sset = new SocketSet(2);
            ubyte[1024*8] buffer;
            ptrdiff_t received;

            debug {
                int bytesToClient;
                enum bytesToClientLogThreshold = 1024*128;
                int bytesToTarget;
                enum bytesToTargetLogThreshold = 1024*8;
            }

            for (;; sset.reset()) {
                sset.add(clientSocket);
                sset.add(targetSocket);

                const ready = Socket.select(sset, null, null);
                if (ready < 0) {
                    // Socket.select reports an interrupted call (EINTR) as -1.
                    // That is not end-of-data, so re-arm the set and wait again.
                    continue;
                }
                if (ready == 0) {
                    infof("[%d] End of data transfer", id);
                    break;
                }

                if (sset.isSet(clientSocket)) {
                    received = clientSocket.receive(buffer);
                    if (received == Socket.ERROR) {
                        warningf("[%d] Connection error on clientSocket.", id);
                        break;
                    } else if (received == 0) {
                        infof("[%d] Client connection closed.", id);
                        break;
                    }

                    if (!sendAll(targetSocket, buffer[0 .. received])) {
                        warningf("[%d] Connection error on targetSocket.", id);
                        break;
                    }

                    debug {
                        bytesToTarget += received;
                        if (bytesToTarget >= bytesToTargetLogThreshold) {
                            tracef("[%d] <- %d bytes sent to target", id, bytesToTarget);
                            bytesToTarget -= bytesToTargetLogThreshold;
                        }
                    }
                }

                if (sset.isSet(targetSocket)) {
                    received = targetSocket.receive(buffer);
                    if (received == Socket.ERROR) {
                        warningf("[%d] Connection error on targetSocket.", id);
                        break;
                    } else if (received == 0) {
                        infof("[%d] Target connection closed.", id);
                        break;
                    }

                    if (!sendAll(clientSocket, buffer[0 .. received])) {
                        warningf("[%d] Connection error on clientSocket.", id);
                        break;
                    }

                    debug {
                        bytesToClient += received;
                        if (bytesToClient >= bytesToClientLogThreshold) {
                            tracef("[%d] <- %d bytes sent to client", id, bytesToClient);
                            bytesToClient -= bytesToClientLogThreshold;
                        }
                    }
                }
            }
        }

        /**
         * Writes all of `data` to `sock`, retrying after short writes.
         *
         * Returns: false if the socket reports an error or accepts no bytes.
         * A zero-byte send is treated as an error to avoid spinning forever.
         */
        private static bool sendAll(Socket sock, const(ubyte)[] data)
        {
            while (data.length > 0) {
                const sent = sock.send(data);
                if (sent == Socket.ERROR || sent == 0) {
                    return false;
                }
                data = data[sent .. $];
            }
            return true;
        }
}