1 module socks5d.packets;
2 
3 import std.socket;
4 import std.bitmanip;
5 import std.conv;
6 
/// Authentication method codes used during method negotiation.
enum AuthMethod : ubyte
{
    NOAUTH       = 0x00, /// no authentication
    AUTH         = 0x02, /// authenticated access (presumably username/password, cf. AuthPacket)
    NOTAVAILABLE = 0xFF, /// no offered method is acceptable (detectAuthMethod's fallback)
}
12 
/// Command byte of a client request. RequestPacket only accepts CONNECT.
enum RequestCmd : ubyte
{
    CONNECT      = 0x01, ///
    BIND         = 0x02, ///
    UDPASSOCIATE = 0x03, ///
}
18 
/// Address type byte (ATYP) of requests and replies.
/// RequestPacket parses IPV4 and DOMAIN and rejects IPV6.
enum AddressType : ubyte
{
    IPV4   = 0x01, /// 4-byte address follows
    DOMAIN = 0x03, /// length-prefixed host name follows
    IPV6   = 0x04, ///
}
24 
/// Reply codes carried by ResponsePacket.rep and RequestException.replyCode.
enum ReplyCode : ubyte
{
    SUCCEEDED           = 0x00, ///
    FAILURE             = 0x01, ///
    NOTALLOWED          = 0x02, ///
    NETWORK_UNREACHABLE = 0x03, ///
    HOST_UNREACHABLE    = 0x04, ///
    CONNECTION_REFUSED  = 0x05, ///
    TTL_EXPIRED         = 0x06, ///
    CMD_NOTSUPPORTED    = 0x07, ///
    ADDR_NOTSUPPORTED   = 0x08, ///
}
36 
/**
 * Renders the fields of `args` as a one-line debug string of the form
 * `"<TypeName>: field1=value1 field2=value2 "` (each pair is followed by a
 * single space, including the last one).
 *
 * For class types only the fields declared directly in `T` are listed
 * (`.tupleof` does not include base-class fields). A null class reference
 * renders as `"<TypeName>: null"` instead of crashing.
 */
string printFields(T)(T args)
{
    import std.format : format;

    string result = typeid(T).toString() ~ ": ";

    static if (is(T == class)) {
        // args.tupleof dereferences the object, so guard against null first.
        if (args is null) {
            return result ~ "null";
        }
    }

    foreach (index, value; args.tupleof) {
        // __traits(identifier) is the documented way to get a field's name;
        // .stringof output is not guaranteed to be just the identifier.
        result ~= format("%s=%s ", __traits(identifier, T.tupleof[index]), value);
    }

    return result;
}
57 
/// Base exception for the SOCKS protocol errors thrown by the packet classes.
class SocksException : Exception
{
    /// Forwards every argument unchanged to `Exception`.
    this(string msg, string file = __FILE__, size_t line = __LINE__,
         Throwable next = null) @safe pure nothrow
    {
        super(msg, file, line, next);
    }
}
68 
/// A SocksException that additionally carries a ReplyCode describing the
/// failure (the same enum used for ResponsePacket.rep).
class RequestException : SocksException
{
    /// Reply code attached to this failure.
    ReplyCode replyCode;

    /// Params:
    ///     replyCode = code stored in `this.replyCode`
    ///     msg, file, line, next = forwarded unchanged to SocksException
    this(ReplyCode replyCode, string msg, string file = __FILE__,
         size_t line = __LINE__, Throwable next = null) @safe pure nothrow
    {
        super(msg, file, line, next);
        this.replyCode = replyCode;
    }
}
81 
/// Common root of all packet types: holds the leading protocol-version byte.
abstract class Socks5Packet
{
    // One-element array so it can be handed straight to Socket.receive/send.
    // 0x05 by default; the auth sub-negotiation uses 0x01 instead.
    ubyte[1] ver = [0x05];

    /// Returns: the version byte currently stored in the packet.
    ubyte getVersion()
    {
        const ubyte current = ver[0];
        return current;
    }
}
91 
/// Base class for packets read from a client socket.
abstract class IncomingPacket : Socks5Packet
{
    /**
     * Reads the one-byte protocol version into `ver` and validates it.
     *
     * Throws: SocksException if the connection is closed or fails before the
     * byte arrives, or if the byte differs from `requiredVersion`.
     */
    void receiveVersion(Socket socket, ubyte requiredVersion = 0x05)
    {
        // An unchecked receive would leave the default 0x05 in `ver` when the
        // peer has already closed, making the check below pass silently.
        receiveExact(socket, ver[]);
        if (ver[0] != requiredVersion) {
            throw new SocksException("Incorrect protocol version: " ~ ver[0].to!string);
        }
    }

    /**
     * Reads a one-byte length into `len`, then exactly that many bytes into a
     * newly allocated `buf`.
     *
     * Throws: SocksException if the connection closes or fails mid-read.
     */
    void receiveBuffer(Socket s, ref ubyte[1] len, ref ubyte[] buf)
    {
        receiveExact(s, len[]);
        buf = new ubyte[len[0]];
        receiveExact(s, buf);
    }

    /// Reads this packet's fields from `s`; each subclass implements it.
    abstract void receive(Socket s);

    /**
     * Fills `buf` completely from `s`, looping over short reads (a single
     * Socket.receive may return fewer bytes than requested).
     *
     * Throws: SocksException when the peer closes the connection (receive
     * returns 0) or the socket reports an error.
     */
    protected final void receiveExact(Socket s, void[] buf)
    {
        size_t received = 0;
        while (received < buf.length) {
            const n = s.receive(buf[received .. $]);
            if (n == 0) {
                throw new SocksException("Connection closed by peer");
            }
            if (n == Socket.ERROR) {
                throw new SocksException("Socket receive failed: " ~ lastSocketError());
            }
            received += cast(size_t) n;
        }
    }
}
111 
/// Base class for packets the server writes to a client socket.
abstract class OutgoingPacket : Socks5Packet
{
    /// Serializes this packet onto `s`; each subclass implements it.
    abstract void send(Socket s);
}
116 
/// Client greeting: version byte, method count, then the offered method codes.
class MethodIdentificationPacket : IncomingPacket
{
    ubyte[1] nmethods;  // count byte as received
    ubyte[]  methods;   // offered method codes (nmethods[0] entries)

    /// Reads the greeting; the version byte must be 0x05.
    override void receive(Socket socket)
    {
        receiveVersion(socket);
        receiveBuffer(socket, nmethods, methods);
    }

    /**
     * Picks the first entry of `availableMethods` (in the caller's order) that
     * the client also offered.
     *
     * Returns: that method, or `AuthMethod.NOTAVAILABLE` when none match.
     */
    AuthMethod detectAuthMethod(AuthMethod[] availableMethods)
    {
        import std.algorithm.searching : canFind, find;

        // find() yields the suffix starting at the first offered candidate.
        auto tail = availableMethods.find!(candidate => methods.canFind(candidate));
        return tail.length ? tail[0] : AuthMethod.NOTAVAILABLE;
    }

    /// Returns: the method-count byte sent by the client.
    ubyte getNMethods()
    {
        const ubyte count = nmethods[0];
        return count;
    }

    unittest
    {
        auto pair = socketPair();
        immutable ubyte[] greeting = [0x05, 0x01, AuthMethod.NOAUTH];
        pair[0].send(greeting);

        auto packet = new MethodIdentificationPacket;
        packet.receive(pair[1]);

        assert(packet.getVersion() == 5);
        assert(packet.getNMethods() == 1);
        assert(packet.detectAuthMethod([AuthMethod.NOAUTH]) == AuthMethod.NOAUTH);
        assert(packet.detectAuthMethod([AuthMethod.AUTH]) == AuthMethod.NOTAVAILABLE);
    }
}
165 
/// Server reply to the greeting: version byte followed by the chosen method.
class MethodSelectionPacket : OutgoingPacket
{
    ubyte method;  // selected method code (e.g. an AuthMethod value)

    /**
     * Writes VER and METHOD as one two-byte frame.
     *
     * Throws: SocksException if the socket does not accept the whole frame.
     */
    override void send(Socket s)
    {
        // A single write: two 1-byte writes can be split into separate TCP
        // segments and stall behind Nagle / delayed ACK on the client side.
        ubyte[2] frame = [ver[0], method];
        const sent = s.send(frame[]);
        if (sent != cast(ptrdiff_t) frame.length) {
            throw new SocksException("Failed to send method selection packet");
        }
    }

    unittest
    {
        auto packet = new MethodSelectionPacket;
        packet.method = AuthMethod.AUTH;
        auto sp = socketPair();
        immutable ubyte[] expected = [0x05, 0x02];

        packet.send(sp[0]);
        ubyte[2] buf;
        sp[1].receive(buf[]);

        assert(buf[] == expected);
    }
}
176 
/// Credentials packet: version 0x01, length-prefixed username, then
/// length-prefixed password.
class AuthPacket : IncomingPacket
{
    ubyte[1]  ulen;    // username length byte
    ubyte[]   uname;   // raw username bytes
    ubyte[1]  plen;    // password length byte
    ubyte[]   passwd;  // raw password bytes

    /// Reads the packet; its version byte must be 0x01 rather than 0x05.
    override void receive(Socket socket)
    {
        enum ubyte authVersion = 0x01;

        receiveVersion(socket, authVersion);
        receiveBuffer(socket, ulen, uname);
        receiveBuffer(socket, plen, passwd);
    }

    /// Returns: `"<username>:<password>"` built from the raw received bytes.
    string getAuthString()
    {
        import std.format : format;

        auto user = cast(char[]) uname;
        auto pass = cast(char[]) passwd;
        return format("%s:%s", user, pass);
    }

    unittest
    {
        auto pair = socketPair();
        immutable ubyte[] credentials = [
            0x01,
            5, 't', 'u', 's', 'e', 'r',
            7, 't', 'p', 'a', 's', 's', 'w', 'd'
        ];
        pair[0].send(credentials);

        auto packet = new AuthPacket;
        packet.receive(pair[1]);

        assert(packet.getVersion() == 1);
        assert(packet.getAuthString() == "tuser:tpasswd");
    }
}
217 
/// Reply to an AuthPacket: version byte followed by a status byte
/// (0x00 by default).
class AuthStatusPacket : OutgoingPacket
{
    ubyte status = 0x00;

    /// The auth exchange uses version 0x01 (AuthPacket requires 0x01 on
    /// input), so the reply carries 0x01 instead of the inherited 0x05.
    this()
    {
        ver[0] = 0x01;
    }

    /**
     * Writes VER and STATUS as one two-byte frame.
     *
     * Throws: SocksException if the socket does not accept the whole frame.
     */
    override void send(Socket s)
    {
        // A single write avoids splitting the 2-byte reply across segments.
        ubyte[2] frame = [ver[0], status];
        const sent = s.send(frame[]);
        if (sent != cast(ptrdiff_t) frame.length) {
            throw new SocksException("Failed to send auth status packet");
        }
    }

    unittest
    {
        auto packet = new AuthStatusPacket;
        auto sp = socketPair();
        immutable ubyte[] expected = [0x01, 0x00];

        packet.send(sp[0]);
        ubyte[2] buf;
        sp[1].receive(buf[]);

        assert(buf[] == expected);
    }
}
228 
/// Client request: version, command, reserved byte, address type,
/// destination address and destination port. Only CONNECT with an IPv4 or
/// domain-name destination is accepted.
class RequestPacket : IncomingPacket
{
    RequestCmd[1]  cmd;
    ubyte[1]       rsv;
    AddressType[1] atyp;
    ubyte[]        dstaddr;  // raw address bytes as received (4 bytes for IPV4, name for DOMAIN)
    ubyte[2]       dstport;  // port in big-endian (network) byte order

    private InternetAddress destinationAddress;

    /**
     * Fills the packet from `socket` and builds the destination address.
     *
     * Throws: RequestException (with the reply code to report) for an
     * unsupported command/address type or a non-zero reserved byte;
     * SocksException if the connection closes or fails mid-packet.
     * NOTE(review): for DOMAIN, InternetAddress's name resolution may throw
     * its own (non-Socks) exception — confirm callers handle that.
     */
    override void receive(Socket socket)
    {
        receiveVersion(socket);
        readRequestCommand(socket);
        readExactly(socket, rsv[]);
        if (rsv[0] != 0x00) {
            throw new RequestException(ReplyCode.FAILURE, "Received incorrect rsv byte");
        }

        destinationAddress = readAddressAndPort(socket);
    }

    /// Returns: the destination built by the last successful `receive`.
    InternetAddress getDestinationAddress()
    {
        return destinationAddress;
    }

    /// Reads the command byte and rejects anything other than CONNECT.
    private void readRequestCommand(Socket socket)
    {
        readExactly(socket, cmd[]);
        if (cmd[0] != RequestCmd.CONNECT) {
            throw new RequestException(ReplyCode.CMD_NOTSUPPORTED,
                "Only CONNECT method is supported, given " ~ cmd[0].to!string);
        }
    }

    /// Reads ATYP, the address and the port; returns the resulting address.
    private InternetAddress readAddressAndPort(Socket socket)
    {
        readExactly(socket, atyp[]);

        switch (atyp[0]) {
            case AddressType.IPV4: {
                dstaddr = new ubyte[4];
                readExactly(socket, dstaddr);
                readExactly(socket, dstport[]);

                // peek (not read): read!uint would consume the dstaddr field
                // and leave it empty after parsing.
                return new InternetAddress(dstaddr.peek!uint, dstport.bigEndianToNative!ushort);
            }

            case AddressType.DOMAIN: {
                // Braces scope `length` so later case labels do not skip it.
                ubyte[1] length;
                receiveBuffer(socket, length, dstaddr);
                readExactly(socket, dstport[]);

                return new InternetAddress(cast(char[]) dstaddr, dstport.bigEndianToNative!ushort);
            }

            case AddressType.IPV6:
                throw new RequestException(ReplyCode.ADDR_NOTSUPPORTED, "AddressType=ipv6 is not supported");

            default:
                // Render the byte as a number; `string ~ atyp[0]` would append
                // it as a raw character.
                throw new RequestException(ReplyCode.ADDR_NOTSUPPORTED,
                    "Unknown AddressType: " ~ (cast(ubyte) atyp[0]).to!string);
        }
    }

    /**
     * Fills `buf` completely from `socket`, looping over short reads.
     * Without this, an early close leaves the field defaults in place
     * (CONNECT / 0 / IPV4), which would pass validation silently.
     *
     * Throws: SocksException on peer close (receive returns 0) or socket error.
     */
    private void readExactly(Socket socket, void[] buf)
    {
        size_t received = 0;
        while (received < buf.length) {
            const n = socket.receive(buf[received .. $]);
            if (n == 0) {
                throw new SocksException("Connection closed while reading request");
            }
            if (n == Socket.ERROR) {
                throw new SocksException("Failed to read request: " ~ lastSocketError());
            }
            received += cast(size_t) n;
        }
    }

    /// test IPv4 address type
    unittest
    {
        auto packet = new RequestPacket;
        auto sp = socketPair();
        immutable ubyte[] input = [
            0x05,
            0x01,
            0x00,
            AddressType.IPV4,
            10, 0, 35, 94,
            0x00, 0x50 // port 80
        ];
        immutable ubyte[] expectedAddr = [10, 0, 35, 94];

        sp[0].send(input);
        packet.receive(sp[1]);

        assert(packet.getVersion() == 5);
        assert(packet.getDestinationAddress().toString() == "10.0.35.94:80");
        assert(packet.dstaddr == expectedAddr);
    }

    /// test domain address type
    unittest
    {
        auto packet = new RequestPacket;
        auto sp = socketPair();
        immutable ubyte[] input = [
            0x05,
            0x01,
            0x00,
            AddressType.DOMAIN,
            9, 'l', 'o', 'c', 'a', 'l', 'h', 'o', 's', 't',
            0x00, 0x50 // port 80
        ];

        sp[0].send(input);
        packet.receive(sp[1]);

        assert(packet.getVersion() == 5);
        assert(packet.getDestinationAddress().toString() == "127.0.0.1:80");
    }

    /// test unknown address type is reported numerically
    unittest
    {
        auto packet = new RequestPacket;
        auto sp = socketPair();
        immutable ubyte[] input = [0x05, 0x01, 0x00, 0x07];

        sp[0].send(input);
        try {
            packet.receive(sp[1]);
            assert(false, "expected RequestException");
        } catch (RequestException e) {
            assert(e.replyCode == ReplyCode.ADDR_NOTSUPPORTED);
            assert(e.msg == "Unknown AddressType: 7");
        }
    }

    /// test truncated request (peer closes early) is an error, not 0.0.0.0:0
    unittest
    {
        import std.exception : assertThrown;

        auto packet = new RequestPacket;
        auto sp = socketPair();
        immutable ubyte[] input = [0x05, 0x01, 0x00];

        sp[0].send(input);
        sp[0].close();
        assertThrown!SocksException(packet.receive(sp[1]));
    }
}
335 
/// Server reply to a request: version, reply code, reserved byte, address
/// type, 4-byte bound address and 2-byte bound port.
class ResponsePacket : OutgoingPacket
{
    ReplyCode   rep = ReplyCode.SUCCEEDED;
    ubyte[1]    rsv = [0x00];
    AddressType atyp;      // AddressType.init, i.e. IPV4, unless changed
    ubyte[4]    bndaddr;   // IPv4 address in network byte order
    ubyte[2]    bndport;   // port in network byte order

    /**
     * Writes the complete 10-byte reply with a single send call.
     *
     * Throws: SocksException if the socket does not accept the whole frame.
     */
    override void send(Socket s)
    {
        // Build the frame first: six tiny writes can split the reply across
        // several TCP segments, which naive clients reading it with a single
        // recv() may not handle.
        ubyte[10] frame;
        frame[0]       = ver[0];
        frame[1]       = rep;
        frame[2]       = rsv[0];
        frame[3]       = atyp;
        frame[4 .. 8]  = bndaddr[];
        frame[8 .. 10] = bndport[];

        const sent = s.send(frame[]);
        if (sent != cast(ptrdiff_t) frame.length) {
            throw new SocksException("Failed to send response packet");
        }
    }

    /**
     * Stores `address` as the bound address/port (network byte order) and
     * marks the reply's address type as IPv4.
     *
     * Returns: always true.
     */
    bool setBindAddress(InternetAddress address)
    {
        // bndaddr holds exactly 4 bytes, so the type must be IPV4 whatever
        // atyp held before.
        atyp = AddressType.IPV4;
        bndport = nativeToBigEndian(address.port);
        bndaddr = nativeToBigEndian(address.addr);

        return true;
    }

    unittest
    {
        auto packet = new ResponsePacket;
        auto sp = socketPair();
        immutable ubyte[] output = [
            0x05,
            ReplyCode.SUCCEEDED,
            0x00,
            AddressType.IPV4,
            127, 0, 0, 1, // 127.0.0.1
            0x00, 0x51    // port 81
        ];

        packet.setBindAddress(new InternetAddress("127.0.0.1", 81));

        packet.send(sp[0]);
        ubyte[output.length] buf;
        sp[1].receive(buf);

        assert(buf == output);
    }

    /// non-default reply code; atyp is reset to IPV4 by setBindAddress
    unittest
    {
        auto packet = new ResponsePacket;
        packet.rep = ReplyCode.HOST_UNREACHABLE;
        packet.atyp = AddressType.DOMAIN;
        auto sp = socketPair();
        immutable ubyte[] output = [
            0x05,
            ReplyCode.HOST_UNREACHABLE,
            0x00,
            AddressType.IPV4,
            10, 1, 2, 3,  // 10.1.2.3
            0x04, 0x38    // port 1080
        ];

        packet.setBindAddress(new InternetAddress("10.1.2.3", 1080));

        packet.send(sp[0]);
        ubyte[output.length] buf;
        sp[1].receive(buf);

        assert(buf == output);
    }
}