// WebSocket support for Tesses.WebServer. This file contains:
//   * WebSocketExtensions        - ServerContext helpers that run a WebSocket session,
//   * WebSocketMessage           - a text/binary/JSON message wrapper,
//   * WebSocketMessageEventArgs / WebSocketClosedEventArgs - event payloads,
//   * WebSocketServer            - upgrade handshake plus the frame read loop.
//
// NOTE(review): this copy of the file looks damaged. It is documented as-is, not repaired:
//   * line breaks appear to have been collapsed, so the `//Console.WriteLine(...)` comments
//     inside StartAsync now comment out the remainder of their physical line;
//   * generic type arguments appear to have been stripped (e.g. `Dictionary> dict`,
//     `Action,Action,CancellationToken> opened`, `public EventHandler MessageArrived`);
//   * the body of SendMessageAsync breaks off at `for(int i = 0;i` and resumes at
//     `get_long()`; MaskMessage, PongSend and get_Sec_WebSocketAccept are called but are
//     not visible here.
// Restore the file from version control before acting on the NOTE(review) remarks below.
using System; using Tesses.WebServer; using System.Threading.Tasks; using Tesses.WebServer.WebSocket; using System.Threading; using System.Text; using Newtonsoft.Json; using System.Security.Cryptography; using System.Collections.Generic; using System.Linq; using System.Collections; using System.IO; namespace Tesses.WebServer { /* Extension methods that run a WebSocket session on a ServerContext. */ public static class WebSocketExtensions { /* True when dict.TryGetFirst(t, ...) succeeds and the value it yields Equals t2. TryGetFirst is defined outside this file. */ internal static bool FirstEquals(this Dictionary> dict,T1 t,T2 t2) { T2 firstVal; return dict.TryGetFirst(t,out firstVal) && firstVal.Equals(t2); } /* True when key t is present and at least one item in its list Equals t2; false otherwise. */ internal static bool AnyEquals(this Dictionary> dict,T1 t,T2 t2) { List items; if(dict.TryGetValue(t,out items)) { foreach(var item in items) { if(item.Equals(t2)) return true; } } return false; } /* Blocking wrapper over StartWebSocketConnectionAsync. The synchronous opened/arrived callbacks are run through Task.Run, the send and ping delegates handed to opened block with Wait(), and this method itself waits for the whole session task before returning. */ public static void StartWebSocketConnection(this ServerContext ctx,Action,Action,CancellationToken> opened,Action arrived,Action closed) { var t=ctx.StartWebSocketConnectionAsync(async(s,p,c)=>await Task.Run(()=>opened( (mm)=>{ Task.Run(async()=>await s(mm)).Wait(); },(data)=>{ Task.Run(async()=>await p(data)).Wait(); },c)),async(m)=>await Task.Run(()=>arrived(m)),closed); Task.Run(()=>t).Wait(); } /* Creates a WebSocketServer for ctx, forwards MessageArrived to arrived(e.Message) and WebSocketClosed to closed(e.Clean), starts opened on a new Thread with the server's SendMessageAsync, Ping and a cancellation token, then awaits server.StartAsync(). When StartAsync returns, the token is cancelled and the thread is joined. */ public static async Task StartWebSocketConnectionAsync(this ServerContext ctx,Func,Func,CancellationToken,Task> opened,Func arrived,Action closed) { WebSocketServer server=new WebSocketServer(ctx); server.MessageArrived+=async(sender,e)=>{ try{ await arrived(e.Message); }catch(Exception ex) { /* NOTE(review): any exception thrown by arrived is discarded here. */ _=ex; } }; server.WebSocketClosed+=(sender,e)=>{ closed(e.Clean); }; using(var cts=new CancellationTokenSource()){ /* NOTE(review): Thread takes a void delegate, so this async lambda is presumably async void and t.Join() may return at its first incomplete await instead of when opened finishes - confirm. Exceptions from opened are discarded as well. */ Thread t=new Thread(async()=>{ try{ await opened(server.SendMessageAsync,server.Ping,cts.Token); }catch(Exception ex) { _=ex; } }); t.Start(); await server.StartAsync(); cts.Cancel(); t.Join(); } } } } namespace Tesses.WebServer.WebSocket { /* A WebSocket message stored as a byte array plus a Binary flag; instances are built through the Create overloads. */ public class WebSocketMessage { /* Text message: the Text setter stores the UTF-8 bytes of text and clears Binary. */ public static WebSocketMessage Create(string text) { WebSocketMessage msg=new WebSocketMessage(); msg.Text = text; return msg; } /* Binary message: the Data setter stores the array as-is (no copy) and sets Binary. */ public static WebSocketMessage Create(byte[] data) { WebSocketMessage 
msg=new WebSocketMessage(); msg.Data=data; return msg; } /* Text message whose Text is JsonConvert.SerializeObject(data). */ public static WebSocketMessage Create(object data) { WebSocketMessage msg=new WebSocketMessage(); msg.EncodeJson(data); return msg; } /* Starts as an empty payload; assigning Data also sets Binary to true until a Create overload changes it. */ private WebSocketMessage() { Data=new byte[0]; } /* Wraps an already assembled payload; WebSocketServer.StartAsync uses this for received messages. */ internal WebSocketMessage(byte[] message,bool binary) { data=message; Binary=binary; } private byte[] data; /* True for binary payloads, false for text. */ public bool Binary {get;private set;} /* Raw payload bytes; the private setter also marks the message as binary. */ public byte[] Data {get{return data;} private set{data=value; Binary=true;}} /* Deserializes Text with JsonConvert.DeserializeObject. */ public T DecodeJson() { return JsonConvert.DeserializeObject(Text); } /* Stores the JSON serialization of data through the Text setter. */ private void EncodeJson(object data) { Text=JsonConvert.SerializeObject(data); } /* Splits data into chunks of at most 4096 bytes, yielding (buffer, count). NOTE(review): the same buffer instance is yielded and overwritten on every iteration, and because the loop condition is tested after the yield a final (buffer, 0) entry is always produced. SendMessageAsync calls ToArray() on this sequence, so every materialized entry would reference that one array - confirm. */ internal IEnumerable<(byte[] array,int)> GetPackets() { int read=0; int offset=0; byte[] buffer=new byte[4096]; do { read = Math.Min(buffer.Length,data.Length-offset); Array.Copy(data,offset,buffer,0,read); yield return (buffer,read); offset+=read; }while(read>0); } /* Payload decoded as UTF-8 (regardless of Binary); the private setter stores UTF-8 bytes and clears Binary. */ public string Text {get{return Encoding.UTF8.GetString(Data);} private set{data=Encoding.UTF8.GetBytes(value); Binary=false;}} } /* Event payload carrying one received WebSocketMessage. */ public class WebSocketMessageEventArgs : EventArgs { public WebSocketMessageEventArgs(WebSocketMessage message) { Message=message; } public WebSocketMessage Message {get;private set;} } /* Event payload for connection close; StartAsync passes Clean=true when a close frame (opcode 0x8) was received and false when the read loop ended because context.Connected was false. */ public class WebSocketClosedEventArgs : EventArgs { public WebSocketClosedEventArgs(bool clean) { Clean=clean; } public bool Clean {get;private set;} } /* Performs the WebSocket upgrade handshake on a ServerContext and then reads frames from context.NetworkStream, invoking MessageArrived and WebSocketClosed. */ public class WebSocketServer { /* Set to true by StartAsync once the handshake response has been flushed; SendMessageAsync spins until then. */ bool hasInit=false; ServerContext context; public WebSocketServer(ServerContext ctx) { context=ctx; } /* NOTE(review): these are public delegate fields, not events, so outside code can reassign or invoke them. */ public EventHandler MessageArrived; public EventHandler WebSocketClosed; /* Encodes a payload length for a frame header with the mask bit clear: a single byte when len < 126; 126 followed by 2 big-endian bytes when len <= ushort.MaxValue; otherwise 127 followed by 8 big-endian bytes. */ private byte[] glenBytes(long len) { if(len < 126) { return new byte[]{(byte)len}; }else if(len <= ushort.MaxValue) { byte[] num = BitConverter.GetBytes((ushort)len); if(BitConverter.IsLittleEndian) { Array.Reverse(num); } return new byte[]{126,num[0],num[1]}; }else{ byte[] num = BitConverter.GetBytes(len); if(BitConverter.IsLittleEndian) { Array.Reverse(num); } return new byte[]{127,num[0],num[1],num[2],num[3],num[4],num[5],num[6],num[7]}; 
} } /* Sends msg using opcode 0x2 when msg.Binary is set and 0x1 otherwise, after splitting it with GetPackets. NOTE(review): waits for the handshake with an empty busy loop on hasInit, which is a plain non-volatile field. The remainder of this method, and whatever followed it up to get_long, is missing from this copy of the file. */ public async Task SendMessageAsync(WebSocketMessage msg) { while(!hasInit) ; int opCode = msg.Binary ? 0x2 : 0x1; (byte[] buff,int len)[] parts = msg.GetPackets().ToArray(); for(int i = 0;i get_long() { /* get_long: reads 8 bytes from the stream and interprets them as a big-endian Int64. NOTE(review): the ReadAsync return value is ignored, so a short read would go unnoticed. */ byte[] data = new byte[8]; await context.NetworkStream.ReadAsync(data,0,data.Length); if(BitConverter.IsLittleEndian) { Array.Reverse(data); } return BitConverter.ToInt64(data,0); } /* Writes one frame with the FIN bit set and opcode 0x9 (ping): header byte, glenBytes length, then the payload copied unmasked. */ public async Task Ping(byte[] ping) { int finField = 0b10000000 ; byte firstByte= (byte)(finField | 0x9); var b=glenBytes(ping.Length); byte[] message = new byte[1+b.Length + ping.Length]; message[0]=firstByte; Array.Copy(b,0,message,1,b.Length); Array.Copy(ping,0,message,1+b.Length,ping.Length); await context.NetworkStream.WriteAsync(message,0,message.Length); } /* Reads 2 bytes from the stream as a big-endian value. NOTE(review): BitConverter.ToInt16 is signed, so an extended length above 32767 comes back negative - ToUInt16 was presumably intended. The ReadAsync return value is ignored here too. */ private async Task get_short() { byte[] data = new byte[2]; await context.NetworkStream.ReadAsync(data,0,data.Length); if(BitConverter.IsLittleEndian) { Array.Reverse(data); } return BitConverter.ToInt16(data,0); } /* Reads one frame payload. len is the second header byte: bit 7 is the masked flag and the low 7 bits are the length, where 126 selects a 2-byte extended length (get_short) and 127 an 8-byte one (get_long). When masked, the 4-byte key is read and passed with the payload to MaskMessage (not visible in this file). Returns the payload and its length. NOTE(review): ReadAsync return values are ignored, so partially filled buffers are not detected. */ private async Task<(byte[] data,long len)> read_packet_async(byte len) { int realLen=len & 127; bool masked=(len & 0b10000000) > 0; long realLen2 = realLen >= 126 ? realLen > 126 ? 
await get_long() : await get_short() : realLen; byte[] maskingKey = new byte[4]; if(masked) { await context.NetworkStream.ReadAsync(maskingKey,0,maskingKey.Length); } byte[] data = new byte[realLen2]; await context.NetworkStream.ReadAsync(data,0,data.Length); if(masked) { MaskMessage(maskingKey,data); } return (data,realLen2); } /* Runs the server side of the connection. Returns without responding unless the request has a Sec-WebSocket-Key header, an Upgrade header value equal to "websocket" and a Sec-WebSocket-Version value equal to "13". Otherwise it answers with status 101 and the Upgrade, Connection and Sec-WebSocket-Accept headers (the accept value comes from get_Sec_WebSocketAccept, not visible in this file), sets hasInit, and reads frames while context.Connected is true, dispatching on the opcode in the low 4 bits of the first header byte. */ public async Task StartAsync() { /* GET /chatUrl HTTP/1.1 Host: server.example.com Upgrade: websocket */ string sec_websocket_accept=""; if(context.RequestHeaders.TryGetFirst("Sec-WebSocket-Key",out sec_websocket_accept)) { sec_websocket_accept=get_Sec_WebSocketAccept(sec_websocket_accept); }else{ return; } if(!context.RequestHeaders.AnyEquals("Upgrade","websocket")) { //Console.WriteLine("Doesn't contain Upgrade: websocket"); return; } if(!context.RequestHeaders.AnyEquals("Sec-WebSocket-Version", "13")) { //Console.WriteLine("Doesn't contain version 13"); return; } context.StatusCode = 101; context.ResponseHeaders.Add("Upgrade","websocket"); if(context.ResponseHeaders.ContainsKey("Connection")) { context.ResponseHeaders["Connection"].Clear(); } context.ResponseHeaders.Add("Connection","Upgrade"); context.ResponseHeaders.Add("Sec-WebSocket-Accept",sec_websocket_accept); await context.WriteHeadersAsync(); //await context.NetworkStream.FlushAsync(); await context.NetworkStream.FlushAsync(); hasInit=true; bool isBinary=false; MemoryStream strm=new MemoryStream(); while(context.Connected) { byte[] frame_start=new byte[2]; await context.NetworkStream.ReadAsync(frame_start,0,2); byte first= frame_start[0]; /* NOTE(review): hasMessage is declared inside the loop, so it is false again for every frame and the 0x0 (continuation) case below always breaks without reading its payload - confirm. */ bool hasMessage =false; int opcode = first & 0xF; /* NOTE(review): (first | 0b10000000) > 0 is always true; testing the FIN bit would need & - confirm. */ bool fin = (first | 0b10000000) > 0; switch(opcode) { case 0x0: if(!hasMessage) break; var (data,len)= await read_packet_async(frame_start[1]); strm.Write(data,0,(int)len); break; case 0x1: case 0x2: /* 0x1 text / 0x2 binary: start a fresh buffer for the new message */ hasMessage=true; strm.Dispose(); strm=new MemoryStream(); isBinary = opcode == 0x2; var (data2,len2)= await read_packet_async(frame_start[1]); strm.Write(data2,0,(int)len2); break; case 0x8: 
/* 0x8 close: report a clean close and stop; the frame payload is not read and no close frame is written here. */ WebSocketClosed?.Invoke(this,new WebSocketClosedEventArgs(true)); return; case 0x9: /* 0x9 ping: read the payload and hand it to PongSend (not visible in this file). */ var (data3,len3) =await read_packet_async(frame_start[1]); await PongSend(data3,len3); break; case 0xA: /* 0xA pong: payload is read and discarded. */ var (data4,len4) =await read_packet_async(frame_start[1]); break; } /* Deliver the buffered bytes as one message once a data frame has been seen and fin is set. */ if(fin && hasMessage) { hasMessage=false; WebSocketMessage msg=new WebSocketMessage(strm.ToArray(),isBinary); MessageArrived?.Invoke(this,new WebSocketMessageEventArgs(msg)); } } /* The loop ended because context.Connected was false: report an unclean close. */ WebSocketClosed?.Invoke(this,new WebSocketClosedEventArgs(false)); } } }