import {
    AbstractLogger,
    ErrorBuilder,
    Errors,
    MessageType,
    SocketMessage,
    DisconnectionHandler,
    IClientSocket,
    PlainMessageHandler,
    ReconnectionHandler,
} from "@clairejs/core";

import { AbstractClientSocketManager } from "./AbstractClientSocketManager";
import { SocketConfig } from "./SocketConfig";
import { DefaultClientSocket } from "./DefaultClientSocket";
import { IWebSocket } from "./IWebSocket";

//-- maximum number of ping/pong samples retained in `pingpong`; the oldest sample is dropped once exceeded
const pingpongMaxCount = 10;

/**
 * Multiplexes any number of logical `DefaultClientSocket` instances over one physical websocket.
 *
 * - The physical link is opened when the first logical socket is created and force-closed when the last
 *   one is removed.
 * - After the link opens, a READY message is exchanged. Once the server answers READY, a periodic
 *   PING_PONG keep-alive runs and closes the link after too many unanswered pings.
 * - Channel membership is tracked. Messages for a channel are queued until the server acknowledges the
 *   join, and channels are re-joined after an unintended reconnect.
 * - Unintended disconnections are retried with a linearly growing delay.
 */
export class DefaultClientSocketManager extends AbstractClientSocketManager {
    /** The current physical websocket, once the provider has resolved one. */
    private socket?: IWebSocket;

    /** Logical sockets sharing the physical link, with the channels and callbacks each one registered. */
    private allSockets: {
        socket: DefaultClientSocket;
        registeredChannels: string[];
        reconnectionHandler?: ReconnectionHandler;
        disconnectionHandler?: DisconnectionHandler;
        messageHandler?: PlainMessageHandler;
    }[] = [];

    /**
     * Channels known to the manager. `connected` becomes true when the server acknowledges the join;
     * until then outgoing channel messages are queued in `pendingMessages`.
     */
    private allChannels: { channel: string; connected: boolean; pendingMessages: any[] }[] = [];

    /** Set when the physical link's `onopen` fires; reset when it closes. */
    private socketConnected?: boolean;
    /** Keep-alive interval handle, started when READY is received. */
    private pingIntervalId?: any;
    /** Handle of the pending retry timer, if a retry cycle is in progress. */
    private retryTimeoutId?: any;
    /** Keep-alive ticks elapsed since the last PING_PONG message was received. */
    private accumulatedPing = 0;
    /** Delay before the next retry attempt; grows by `reconnectTimeDeltaMs` (default 3000) per attempt. */
    private retryDelay = 0;
    /** When true, a closed link is not retried. */
    private intendedDisconnection?: boolean;
    /** Rolling id (0-99) attached to outgoing PING_PONG messages. */
    private pingpongId = 0;

    /** The most recent ping/pong samples (at most `pingpongMaxCount`), used by `getPingMs`. */
    public pingpong: { id: number; sentTimestamp: number; receivedTimestamp?: number }[] = [];

    /**
     * @param wsProvider creates a new physical websocket for each connection attempt
     * @param logger receives debug/error diagnostics
     * @param config optional tuning: `keepAlive.pingIntervalMs` (default 10000),
     *               `keepAlive.deadThreshold` (default 3), `reconnectTimeDeltaMs` (default 3000)
     */
    constructor(
        private readonly wsProvider: () => Promise<IWebSocket>,
        private readonly logger: AbstractLogger,
        private readonly config?: SocketConfig,
    ) {
        super();
    }

    /**
     * Average round-trip time, in milliseconds, over the retained ping/pong samples that got a reply.
     * @returns 0 when no sample has been answered yet
     */
    getPingMs(): number {
        const candidates = this.pingpong.filter((pp) => pp.sentTimestamp && pp.receivedTimestamp);
        if (!candidates.length) {
            return 0;
        }
        return (
            candidates.map((pp) => pp.receivedTimestamp! - pp.sentTimestamp).reduce((total, diff) => total + diff, 0) /
            candidates.length
        );
    }

    /**
     * Registers `channels` for `socket` and joins those the manager is not yet connected to.
     * Does nothing (besides a debug log) when `socket` is unknown.
     */
    subChannels(socket: IClientSocket, channels: string[]) {
        const info = this.allSockets.find((s) => s.socket === socket);
        if (!info) {
            this.logger.debug("Socket not found");
            return;
        }

        //-- un-connected channels
        const unconnectedChannels: string[] = [];
        for (const channel of channels) {
            if (!info.registeredChannels.includes(channel)) {
                info.registeredChannels.push(channel);
            }
            const foundChannel = this.allChannels.find((c) => c.channel === channel);
            if (!foundChannel || !foundChannel.connected) {
                unconnectedChannels.push(channel);
            }
        }

        this.joinChannels(unconnectedChannels);
    }

    /**
     * Unregisters `channels` for `socket`; channels no logical socket listens to any more are left.
     * Does nothing (besides a debug log) when `socket` is unknown.
     */
    unsubChannels(socket: IClientSocket, channels: string[]) {
        const info = this.allSockets.find((s) => s.socket === socket);
        if (!info) {
            this.logger.debug("Socket not found");
            return;
        }

        info.registeredChannels = info.registeredChannels.filter((c) => !channels.includes(c));

        //-- check if no one is subscribing these channel then remove
        const removedChannels: string[] = [];
        for (const channel of channels) {
            //-- no one is listening
            if (!this.allSockets.find((socket) => socket.registeredChannels.includes(channel))) {
                removedChannels.push(channel);
            }
        }

        this.leaveChannels(removedChannels);
    }

    /**
     * Detaches a logical socket, invoking its disconnection handler first.
     * The physical link is force-closed when no logical socket remains. Otherwise, channels that only
     * the removed socket was subscribed to are left.
     */
    removeSocket(socket: DefaultClientSocket) {
        const socketInfo = this.allSockets.find((s) => s.socket === socket);
        if (socketInfo && socketInfo.disconnectionHandler) {
            socketInfo.disconnectionHandler();
        }

        this.allSockets = this.allSockets.filter((s) => s.socket !== socket);
        if (!this.allSockets.length) {
            this.forceDisconnect();
            return;
        }

        //-- leave channels nobody else listens to, so the server stops pushing them to us
        if (socketInfo) {
            const orphanedChannels = socketInfo.registeredChannels.filter(
                (channel) => !this.allSockets.some((s) => s.registeredChannels.includes(channel)),
            );
            this.leaveChannels(orphanedChannels);
        }
    }

    /** Stores the callback run for `socket` each time the READY handshake completes. */
    registerReconnectionHandler(socket: DefaultClientSocket, handler: ReconnectionHandler) {
        const info = this.allSockets.find((s) => s.socket === socket);
        if (info) {
            info.reconnectionHandler = handler;
        }
    }

    /** Stores the callback run for `socket` when the physical link closes or the socket is removed. */
    registerDisconnectionHandler(socket: DefaultClientSocket, handler: DisconnectionHandler) {
        const info = this.allSockets.find((s) => s.socket === socket);
        if (info) {
            info.disconnectionHandler = handler;
        }
    }

    /** Stores the callback that receives every PLAIN message delivered to the manager. */
    registerMessageHandler(socket: DefaultClientSocket, handler: PlainMessageHandler) {
        const info = this.allSockets.find((s) => s.socket === socket);
        if (info) {
            info.messageHandler = handler;
        }
    }

    /**
     * Sends a PLAIN message, optionally scoped to `channel`.
     * Channel messages are queued until the channel join is acknowledged.
     * Messages without a channel are sent immediately.
     * @throws BAD_STATE (via `sendRawMessage`) when sending immediately while the link is down
     */
    sendPlainMessageToChannel(message: any, channel?: string) {
        if (channel) {
            //-- check if channel has been connected
            let foundChannel = this.allChannels.find((c) => c.channel === channel);
            if (!foundChannel) {
                foundChannel = { channel, connected: false, pendingMessages: [] };
                this.allChannels.push(foundChannel);
            }
            if (foundChannel.connected) {
                //-- send
                this.sendRawMessage({
                    type: MessageType.PLAIN,
                    data: { channel, message },
                });
            } else {
                foundChannel.pendingMessages.push(message);
            }
        } else {
            this.sendRawMessage({
                type: MessageType.PLAIN,
                data: {
                    message,
                },
            });
        }
    }

    /**
     * Requests to join `channels`.
     * Every channel is tracked (as not yet connected) so that it is re-joined by `handleConnect` if the
     * link drops before the server acknowledges the join. The request is sent only while the link is up.
     */
    joinChannels(channels: string[]) {
        if (!channels.length) {
            return;
        }

        for (const channel of channels) {
            if (!this.allChannels.some((c) => c.channel === channel)) {
                this.allChannels.push({ channel, connected: false, pendingMessages: [] });
            }
        }

        if (this.socket && this.socketConnected) {
            this.sendRawMessage({
                type: MessageType.CHANNEL_JOIN,
                data: channels,
            });
        }
    }

    /**
     * Forgets `channels` locally, dropping any queued messages for them.
     * A CHANNEL_LEAVE is sent when the link is up.
     */
    leaveChannels(channels: string[]) {
        if (!channels.length) {
            return;
        }

        this.handleChannelLeave(channels);

        if (this.socket && this.socketConnected) {
            this.sendRawMessage({
                type: MessageType.CHANNEL_LEAVE,
                data: channels,
            });
        }
    }

    /** Records a new ping sample (bounded by `pingpongMaxCount`) and sends its id as a PING_PONG. */
    private sendPingPong() {
        this.pingpongId += 1;
        this.pingpongId %= 100;
        const pingpongInfo = { id: this.pingpongId, sentTimestamp: Date.now() };
        this.pingpong.push(pingpongInfo);

        if (this.pingpong.length > pingpongMaxCount) {
            this.pingpong.shift();
        }

        this.sendRawMessage({
            type: MessageType.PING_PONG,
            data: pingpongInfo.id,
        });
    }

    /**
     * Writes `message` to the physical link.
     * @throws BAD_STATE when there is no open link
     */
    private sendRawMessage(message: SocketMessage) {
        if (!this.socket || !this.socketConnected) {
            throw ErrorBuilder.error(Errors.BAD_STATE, "Socket not available");
        }
        this.socket.send(message);
        this.logger.debug("Raw send", message);
    }

    /** Marks acknowledged channels as connected and flushes their queued messages. */
    private handleChannelJoin(channels: string[]) {
        this.logger.debug("Joinning channels", channels);
        for (const channel of channels) {
            let foundChannel = this.allChannels.find((c) => c.channel === channel);

            if (!foundChannel) {
                foundChannel = { channel, connected: false, pendingMessages: [] };
                this.allChannels.push(foundChannel);
            }
            foundChannel.connected = true;

            //-- flush messages (the channel is connected now, so these are sent rather than re-queued)
            if (foundChannel.pendingMessages.length) {
                this.logger.debug(`Flushing ${foundChannel.pendingMessages.length} message`);
                for (const message of foundChannel.pendingMessages) {
                    this.sendPlainMessageToChannel(message, foundChannel.channel);
                }
                foundChannel.pendingMessages = [];
            }
        }
    }

    /** Removes `channels` (and their queued messages) from local tracking. */
    private handleChannelLeave(channels: string[]) {
        this.logger.debug("Leaving channels", channels);
        this.allChannels = this.allChannels.filter((c) => !channels.includes(c.channel));
    }

    /** Dispatches an incoming PLAIN message to every registered message handler. */
    private handlePlainMessage(message: any, channel?: string) {
        for (const socket of this.allSockets) {
            socket.messageHandler && socket.messageHandler(message, channel);
        }
    }

    /**
     * Runs when the server answers READY.
     * Starts the keep-alive, stops any retry cycle, notifies reconnection handlers and joins every
     * channel that is not connected yet.
     */
    private handleConnect() {
        this.logger.debug("Socket connected");

        //-- socket open, set interval
        if (this.pingIntervalId) {
            clearInterval(this.pingIntervalId);
        }

        this.accumulatedPing = 0;
        this.pingIntervalId = setInterval(() => {
            this.accumulatedPing += 1;
            if (this.accumulatedPing > (this.config?.keepAlive?.deadThreshold || 3)) {
                //-- socket connection lost, not intended
                this.intendedDisconnection = false;
                this.socket?.close();
            } else {
                this.sendPingPong();
            }
        }, this.config?.keepAlive?.pingIntervalMs || 10000);

        this.intendedDisconnection = false;
        if (this.retryTimeoutId) {
            clearTimeout(this.retryTimeoutId);
            this.retryTimeoutId = undefined;
        }

        for (const socket of this.allSockets) {
            socket.reconnectionHandler && socket.reconnectionHandler();
        }

        //-- try join pending channels
        const pendingChannels = this.allChannels.filter((c) => !c.connected).map((c) => c.channel);
        this.joinChannels(pendingChannels);
    }

    /**
     * Runs when the physical link closes.
     * A truthy `err` is treated as a deliberate termination: all logical sockets and channels are
     * dropped and no retry happens. Otherwise channels are marked for re-join and a retry cycle starts
     * if none is pending.
     */
    private handleDisconnect(err?: any) {
        if (err) {
            this.intendedDisconnection = true;
        }

        this.logger.debug("Socket connnection closed, error: ", err);

        if (this.pingIntervalId) {
            clearInterval(this.pingIntervalId);
            this.pingIntervalId = undefined;
        }

        this.socket = undefined;
        this.socketConnected = false;

        for (const socket of this.allSockets) {
            socket.disconnectionHandler && socket.disconnectionHandler(err);
        }

        if (this.intendedDisconnection) {
            //-- remove all channels
            this.allSockets = [];
            this.allChannels = [];
            this.logger.debug("Socket connection terminated");
        } else {
            //-- disconnect all channels to be reconnected when the socket is connected again
            this.allChannels = this.allChannels.map((c) => ({ ...c, connected: false }));
            if (!this.retryTimeoutId) {
                this.retryDelay = 0;
                this.retry();
            }
        }
    }

    /** Routes a decoded incoming message by its type; unknown types are ignored. */
    private handleMessage(message: SocketMessage) {
        this.logger.debug("Raw receive", message);
        switch (message.type) {
            case MessageType.READY:
                this.handleConnect();
                break;
            case MessageType.PING_PONG: {
                this.accumulatedPing = 0;
                const id = message.data;
                //-- ids cycle through 0..99, so 0 is a valid id and must not be dropped by a truthiness check
                if (id !== undefined && id !== null) {
                    const pong = this.pingpong.find((pp) => pp.id === id);
                    if (pong) {
                        pong.receivedTimestamp = Date.now();
                    }
                }
                break;
            }
            case MessageType.CHANNEL_JOIN:
                if (message.data?.error) {
                    this.logger.error("Join channel error", message.data.error);
                } else {
                    this.handleChannelJoin(message.data);
                }
                break;
            case MessageType.CHANNEL_LEAVE:
                this.handleChannelLeave(message.data);
                break;
            case MessageType.PLAIN:
                this.handlePlainMessage(message.data.message, message.data.channel);
                break;
        }
    }

    /** Creates a logical socket; the first one created triggers the physical connection. */
    create(): IClientSocket {
        const socket = new DefaultClientSocket(this);

        this.allSockets.push({ socket, registeredChannels: [] });

        if (this.allSockets.length === 1) {
            this.forceReconnect();
        }

        return socket;
    }

    /** Closes the physical link on purpose (no retry) and cancels any pending retry. */
    forceDisconnect() {
        this.intendedDisconnection = true;
        if (this.socket) {
            this.socket.close();
            this.socket = undefined;
        }

        if (this.retryTimeoutId) {
            clearTimeout(this.retryTimeoutId);
            this.retryTimeoutId = undefined;
        }
    }

    /**
     * Re-establishes the link.
     * When connected, the socket is closed and the resulting unintended disconnection triggers a retry.
     * Otherwise, with no socket at all, a connection attempt starts immediately.
     */
    forceReconnect(): void {
        this.intendedDisconnection = false;
        if (this.socketConnected) {
            this.socket?.close();
        } else {
            this.retryDelay = 0;
            if (this.socket === undefined) {
                if (this.retryTimeoutId) {
                    clearTimeout(this.retryTimeoutId);
                    this.retryTimeoutId = undefined;
                }
                this.physicConnect();
            }
        }
    }

    /**
     * Obtains a new physical websocket and wires its events to the manager.
     * A socket that is no longer wanted by the time it resolves (a disconnect was requested, or a link
     * is already up) is closed immediately. Events from a socket superseded by a newer attempt are
     * ignored. A provider failure is logged and, unless the disconnection was intended, retried.
     */
    private physicConnect() {
        this.wsProvider()
            .then((socket) => {
                if (this.intendedDisconnection || this.socketConnected) {
                    socket.close();
                    return;
                }

                //-- a slower previous attempt may still be pending; only the newest socket is kept
                const previous = this.socket;
                this.socket = socket;
                if (previous && previous !== socket) {
                    previous.close();
                }

                socket.onopen(() => {
                    if (this.socket !== socket) {
                        socket.close();
                        return;
                    }
                    this.socketConnected = true;
                    this.logger.debug("Physic link connected, sending & waiting for READY message");
                    this.sendRawMessage({
                        type: MessageType.READY,
                        data: "",
                    });
                });

                socket.onmessage((data) => {
                    if (this.socket !== socket) {
                        return;
                    }

                    let message: SocketMessage | undefined;
                    try {
                        message = JSON.parse(data);
                    } catch {
                        message = undefined;
                    }
                    if (!message || !message.type) {
                        this.logger.debug("Invalid mesasge structure", data);
                        return;
                    }

                    this.handleMessage(message);
                });

                socket.onclose((err) => {
                    //-- a superseded socket closing must not tear down the current link
                    if (this.socket && this.socket !== socket) {
                        return;
                    }
                    this.handleDisconnect(err);
                });
            })
            .catch((err) => {
                this.logger.error("Cannot create websocket connection", err);
                if (!this.intendedDisconnection && !this.retryTimeoutId) {
                    this.retry();
                }
            });
    }

    /**
     * Starts a connection attempt now and schedules the next one after a delay that grows by
     * `reconnectTimeDeltaMs` (default 3000) per attempt. Retrying continues while the link is not open
     * and the disconnection was not intended.
     */
    private retry() {
        this.logger.debug(`Socket connection retrying in ${this.retryDelay}ms`);
        this.physicConnect();
        this.retryDelay += this.config?.reconnectTimeDeltaMs || 3000;
        this.retryTimeoutId = setTimeout(() => {
            //-- the timer has fired, so clear its handle; otherwise handleDisconnect would think a retry
            //-- is still pending and never start a new cycle after a later drop
            this.retryTimeoutId = undefined;
            if (!this.socketConnected && !this.intendedDisconnection) {
                this.retry();
            }
        }, this.retryDelay);
    }
}
