using System.Net;
using Cysharp.Threading.Tasks;
using MongoDB.Bson;

namespace KYFramework.Network
{
    /// <summary>
    /// Object system that forwards the Awake of a <see cref="Session"/> (created with an
    /// <see cref="AChannel"/> argument) to <see cref="Session.Awake(AChannel)"/>.
    /// </summary>
    [ObjectSystem]
    public class SessionAwakeSystem : AwakeSystem<Session, AChannel>
    {
        public override void Awake(Session self, AChannel b)
        {
            self.Awake(b);
        }
    }

    /// <summary>
    /// A network session bound to a single <see cref="AChannel"/>. Incoming packets are decoded
    /// and either dispatched through <c>NetworkComponent.MessageDispatcher</c> or used to complete
    /// a pending RPC started with <see cref="Call(IRequest, CancellationToken)"/>; outgoing
    /// messages are serialized with an opcode header and written to the channel.
    /// </summary>
    public sealed class Session : Entity
    {
        // Shared by every session; pre-incremented on each Call to tag a request with an id
        // that the matching IResponse echoes back.
        private static int RpcId { get; set; }

        private AChannel channel;

        // Pending RPC continuations keyed by RpcId. An entry is consumed by Run when the matching
        // IResponse arrives, or invoked with an error response when the session is disposed.
        private readonly Dictionary<int, Action<IResponse>> requestCallback = new Dictionary<int, Action<IResponse>>();

        // Scratch buffer for the 2-byte opcode written into each outgoing packet.
        private readonly byte[] opcodeBytes = new byte[2];

        /// <summary>
        /// Invoked with a user-facing message when an RPC response carries an error code for which
        /// <c>ErrorCode.IsRpcNeedThrowException</c> returns true.
        /// </summary>
        public Action<string> OnError;

        /// <summary>Invoked with the remote address text when the channel reports an error.</summary>
        public Action<string> OnClose;

        /// <summary>The owning <see cref="NetworkComponent"/> (this entity's parent).</summary>
        public NetworkComponent Network
        {
            get { return this.GetParent<NetworkComponent>(); }
        }

        /// <summary>Error code of the underlying channel.</summary>
        public int Error
        {
            get { return this.channel.Error; }
            set { this.channel.Error = value; }
        }

        /// <summary>
        /// Binds this session to <paramref name="aChannel"/>: clears pending RPCs, removes the session
        /// from its <see cref="Network"/> when the channel reports an error, and routes channel reads
        /// to <see cref="OnRead"/>.
        /// </summary>
        public void Awake(AChannel aChannel)
        {
            this.channel = aChannel;
            this.requestCallback.Clear();

            // The id is captured at Awake time so the error handler removes exactly this session's entry.
            long id = this.Id;
            channel.ErrorCallback += (c, e) =>
            {
                OnClose?.Invoke(c.RemoteAddress.ToString());
                this.Network.Remove(id);
            };
            channel.ReadCallback += this.OnRead;
        }

        /// <summary>
        /// Removes the session from its network, completes every pending RPC with an error response,
        /// and disposes the channel. Safe to call more than once.
        /// </summary>
        public override void Dispose()
        {
            if (this.IsDisposed)
            {
                return;
            }

            this.Network.Remove(this.Id);

            base.Dispose();

            // Complete every RPC still waiting for a reply with a ResponseMessage carrying the
            // channel's error code. Iterate a snapshot so the dictionary is not enumerated while the
            // callbacks run.
            foreach (Action<IResponse> action in this.requestCallback.Values.ToArray())
            {
                action.Invoke(new ResponseMessage { Error = this.Error });
            }

            this.channel.Dispose();
            this.requestCallback.Clear();
        }

        /// <summary>Starts the underlying channel.</summary>
        public void Start()
        {
            this.channel.Start();
        }

        public IPEndPoint RemoteAddress
        {
            get { return this.channel.RemoteAddress; }
        }

        public ChannelType ChannelType
        {
            get { return this.channel.ChannelType; }
        }

        /// <summary>The channel's stream, reused as the send buffer by <see cref="Send(ushort, object)"/>.</summary>
        public MemoryStream Stream
        {
            get { return this.channel.Stream; }
        }

        /// <summary>Channel read handler; any exception from processing the packet is logged, not propagated.</summary>
        public void OnRead(MemoryStream memoryStream)
        {
            try
            {
                this.Run(memoryStream);
            }
            catch (Exception e)
            {
                Log.Error(e);
            }
        }

        /// <summary>
        /// Decodes one packet (ushort opcode at <c>Packet.OpcodeIndex</c>, message body from
        /// <c>Packet.MessageIndex</c>). Non-response messages are dispatched; responses complete the
        /// pending RPC with the same RpcId.
        /// </summary>
        /// <exception cref="Exception">No pending RPC matches the response's RpcId.</exception>
        private void Run(MemoryStream memoryStream)
        {
            memoryStream.Seek(Packet.MessageIndex, SeekOrigin.Begin);
            ushort opcode = BitConverter.ToUInt16(memoryStream.GetBuffer(), Packet.OpcodeIndex);

            object message;
            try
            {
                OpcodeTypeComponent opcodeTypeComponent = this.Network.Entity.GetComponent<OpcodeTypeComponent>();
                object instance = opcodeTypeComponent.GetInstance(opcode);
                message = this.Network.MessagePacker.DeserializeFrom(instance, memoryStream);

                if (OpcodeHelper.IsNeedDebugLogMessage(opcode))
                {
                    Log.Info("receive -->" + message.ToJson());
                }
            }
            catch (Exception e)
            {
                // Any message-parsing failure disconnects the session, to stop clients from forging messages.
                Log.Error($"opcode: {opcode} {this.Network.Count} {e} ");
                this.Error = ErrorCode.ERR_PacketParserError;
                this.Network.Remove(this.Id);
                return;
            }

            IResponse response = message as IResponse;
            if (response == null)
            {
                this.Network.MessageDispatcher.Dispatch(this, opcode, message);
                return;
            }

            Action<IResponse> action;
            if (!this.requestCallback.TryGetValue(response.RpcId, out action))
            {
                throw new Exception($"not found rpc, response message: {StringHelper.MessageToStr(response)}");
            }
            this.requestCallback.Remove(response.RpcId);

            action(response);
        }

        /// <summary>
        /// Sends <paramref name="request"/> and completes when the matching response arrives (or
        /// with an error response when the session is disposed). Equivalent to
        /// <see cref="Call(IRequest, CancellationToken)"/> with <see cref="CancellationToken.None"/>.
        /// </summary>
        public UniTask<IResponse> Call(IRequest request)
        {
            return this.Call(request, CancellationToken.None);
        }

        /// <summary>
        /// Sends <paramref name="request"/> tagged with a fresh RpcId and completes with the matching
        /// response. If the response's error code is flagged by <c>ErrorCode.IsRpcNeedThrowException</c>,
        /// <see cref="OnError"/> is raised and the task still completes with that response.
        /// </summary>
        /// <param name="request">Request to send; its RpcId is overwritten.</param>
        /// <param name="cancellationToken">
        /// Cancels the wait: the pending callback is dropped and the task is cancelled. If already
        /// cancelled on entry, nothing is sent.
        /// </param>
        public UniTask<IResponse> Call(IRequest request, CancellationToken cancellationToken)
        {
            if (cancellationToken.IsCancellationRequested)
            {
                // Don't send a request whose reply would only be reported as "not found rpc".
                return UniTask.FromCanceled<IResponse>(cancellationToken);
            }

            int rpcId = ++RpcId;
            var tcs = new UniTaskCompletionSource<IResponse>();
            CancellationTokenRegistration registration = default;

            this.requestCallback[rpcId] = response =>
            {
                try
                {
                    // A response (or dispose-time error response) arrived: stop listening for cancellation.
                    registration.Dispose();

                    if (ErrorCode.IsRpcNeedThrowException(response.Error))
                    {
                        // TODO: show a message box (e.g. MessageBoxHelper.Show with a "connection failed"
                        // title) whose confirm button calls LoginHelper.Reset().
                        OnError?.Invoke("网络异常!请检查...");
                    }

                    // Always complete the task; previously the cancellable overload returned here
                    // without completing it, leaving the awaiter suspended forever.
                    tcs.TrySetResult(response);
                }
                catch (Exception e)
                {
                    tcs.TrySetException(new Exception($"Rpc Error: {request.GetType().FullName}", e));
                }
            };

            if (cancellationToken.CanBeCanceled)
            {
                registration = cancellationToken.Register(() =>
                {
                    this.requestCallback.Remove(rpcId);
                    // Complete the awaiter; removing the callback alone would leave it waiting forever.
                    tcs.TrySetCanceled(cancellationToken);
                });
            }

            request.RpcId = rpcId;
            this.Send(request);
            return tcs.Task;
        }

        /// <summary>Sends a response message.</summary>
        /// <exception cref="Exception">The session has been disposed.</exception>
        public void Reply(IResponse message)
        {
            if (this.IsDisposed)
            {
                throw new Exception("session已经被Dispose了");
            }

            this.Send(message);
        }

        /// <summary>Looks up the opcode for the message's runtime type and sends it.</summary>
        public void Send(IMessage message)
        {
            OpcodeTypeComponent opcodeTypeComponent = this.Network.Entity.GetComponent<OpcodeTypeComponent>();
            ushort opcode = opcodeTypeComponent.GetOpcode(message.GetType());

            Send(opcode, message);
        }

        /// <summary>
        /// Serializes <paramref name="message"/> into the channel stream after a
        /// <c>Packet.MessageIndex</c>-byte header, writes <paramref name="opcode"/> into the first
        /// two bytes of the buffer, and sends the stream.
        /// </summary>
        /// <exception cref="Exception">The session has been disposed.</exception>
        public void Send(ushort opcode, object message)
        {
            if (this.IsDisposed)
            {
                throw new Exception("session已经被Dispose了");
            }

            if (OpcodeHelper.IsNeedDebugLogMessage(opcode))
            {
                Log.Info("send --> " + message.ToJson());
            }

            MemoryStream stream = this.Stream;

            // Reserve the header, then append the serialized body after it.
            stream.Seek(Packet.MessageIndex, SeekOrigin.Begin);
            stream.SetLength(Packet.MessageIndex);
            this.Network.MessagePacker.SerializeTo(message, stream);
            stream.Seek(0, SeekOrigin.Begin);

            // NOTE(review): the opcode is written at offset 0 while Run reads it from
            // Packet.OpcodeIndex; this assumes Packet.OpcodeIndex == 0 — confirm.
            opcodeBytes.WriteTo(0, opcode);
            Array.Copy(opcodeBytes, 0, stream.GetBuffer(), 0, opcodeBytes.Length);

            this.Send(stream);
        }

        /// <summary>Writes an already-framed stream to the channel.</summary>
        public void Send(MemoryStream stream)
        {
            channel.Send(stream);
        }
    }
}