const std = @import("std");
const eql = std.mem.eql;

/// Standard-library options read from the root source file.
/// Pins the `std.log` threshold at `.info` so that the `std.log.info`
/// connection and progress messages in this file are emitted.
/// NOTE(review): std's default log level depends on the build mode, so this
/// presumably matters mostly for release builds — confirm for the targeted Zig version.
pub const std_options = .{
    .log_level = .info,
};

/// Writes a complete plain-text HTTP/1.1 response for `status_code` to `writer`.
///
/// The body is the reason phrase itself (e.g. "Not Found") with a matching
/// `Content-Length`. The response announces `Connection: close` because the
/// caller closes the socket after answering; HTTP/1.1 requires a server that
/// will not keep the connection open to say so.
///
/// Only 200, 400, 404 and 405 are supported. Any other status is a
/// programming error and panics.
fn http_error(writer: std.net.Stream.Writer, status_code: u16) !void {
    const reason = switch (status_code) {
        200 => "OK",
        400 => "Bad Request",
        404 => "Not Found",
        405 => "Method Not Allowed",
        else => @panic("Invalid status"),
    };

    try writer.print(
        "HTTP/1.1 {d} {s}\r\n" ++
            "Server: zig\r\n" ++
            "Content-Type: text/plain\r\n" ++
            "Content-Length: {d}\r\n" ++
            "Connection: close\r\n" ++
            "\r\n" ++
            "{s}",
        .{ status_code, reason, reason.len, reason },
    );
}

/// Formats a byte count as a short human-readable string using binary units.
///
/// Counts below 1 KiB print as whole bytes ("512 B"). Larger counts are
/// scaled to the largest of KiB, MiB or GiB that does not exceed the count
/// and printed with one decimal place ("1.0 KiB", "1.5 MiB"). GiB is the
/// largest unit used, so huge counts are still shown in GiB.
///
/// The text is written into `buf` and the returned slice aliases it.
/// Returns `error.NoSpaceLeft` if `buf` is too small for the text.
fn format_bytes(buf: []u8, bytes: usize) ![]u8 {
    const kib = 1024;
    const mib = 1024 * kib;
    const gib = 1024 * mib;
    // f64 represents counts exactly up to 2^53; f32 stops being exact at 2^24 (16 MiB).
    const fbytes: f64 = @floatFromInt(bytes);

    // Strict `<` so that exactly 1024 bytes becomes "1.0 KiB", not "1024 B".
    if (bytes < kib) return std.fmt.bufPrint(buf, "{d} B", .{bytes});
    if (bytes < mib) return std.fmt.bufPrint(buf, "{d:.1} KiB", .{fbytes / kib});
    if (bytes < gib) return std.fmt.bufPrint(buf, "{d:.1} MiB", .{fbytes / mib});
    return std.fmt.bufPrint(buf, "{d:.1} GiB", .{fbytes / gib});
}

/// Decides whether the first bytes of an HTTP request ask to download `path`.
///
/// Returns `null` when `request` starts with `GET ` followed by exactly `path`
/// and then a space or a `?query`. Otherwise it returns the status code to
/// reject the request with:
///   - 400 if `request` is too short to hold a request line,
///   - 405 if the method is not `GET`,
///   - 404 if the target is anything other than `path`.
fn validate_request(request: []const u8, path: []const u8) ?u16 {
    if (request.len < 12) return 400;
    if (!eql(u8, request[0..4], "GET ")) return 405;
    // Need room for "GET " + path + at least " HTTP/1" (11 + path.len bytes).
    if (request.len < 11 + path.len) return 404;
    const target = request[4..];
    if (!eql(u8, target[0..path.len], path)) return 404;
    // Only an exact match counts: serving "/file" must not also serve "/file.bak".
    return switch (target[path.len]) {
        ' ', '?' => null,
        else => 404,
    };
}

/// Serves one client connection.
///
/// Reads the start of the request and validates it against `path` (see
/// `validate_request`). Rejected requests get a short error response.
/// Accepted requests get a download named `filename` whose body is an
/// endless stream of pseudo-random bytes, sent until the client disconnects.
///
/// The socket is closed on every exit path. Read/seed errors, and write
/// errors other than a broken pipe or connection reset, are returned.
fn handle_connection(connection: std.net.Server.Connection, path: []const u8, filename: []const u8) !void {
    // Registered first so the socket is also closed for rejected requests
    // and when reading the request or seeding the PRNG fails.
    defer connection.stream.close();
    const writer = connection.stream.writer();

    {
        var request_buf: [255]u8 = undefined;
        const n = try connection.stream.read(&request_buf);
        if (validate_request(request_buf[0..n], path)) |status| {
            return http_error(writer, status);
        }
    }

    var prng = std.rand.DefaultPrng.init(blk: {
        var seed: u64 = undefined;
        try std.posix.getrandom(std.mem.asBytes(&seed));
        break :blk seed;
    });
    const random = prng.random();

    std.log.info("Client {} connected", .{connection.address});
    defer std.log.info("Client {} disconnected", .{connection.address});

    // The body has no Content-Length: it ends when the connection closes,
    // which an HTTP/1.1 server must announce with `Connection: close`.
    try writer.print("HTTP/1.1 200 OK\r\nServer: zig\r\nContent-Type: application/octet-stream\r\nContent-Disposition: attachment; filename=\"{s}\"\r\nConnection: close\r\n\r\n", .{filename});

    // Total bytes sent. Adds saturate so a long download cannot overflow
    // (and panic) on targets where usize is 32-bit.
    var downloaded: usize = 0;
    // Bytes sent since the last progress log line.
    var since_log: usize = 0;
    const log_interval = 100 * 1024 * 1024;

    while (true) {
        var write_buffer: [1024]u8 = undefined;
        random.bytes(&write_buffer);
        const wrote = connection.stream.write(&write_buffer) catch |err| switch (err) {
            // The client went away: the normal end of an endless download.
            error.BrokenPipe, error.ConnectionResetByPeer => break,
            else => return err,
        };
        downloaded +|= wrote;
        since_log += wrote;

        if (since_log >= log_interval) {
            var fmt_buf: [32]u8 = undefined;
            const downloaded_fmt = try format_bytes(&fmt_buf, downloaded);
            std.log.info("Client {} downloaded {s}", .{ connection.address, downloaded_fmt });
            since_log = 0;
        }
    }
}

/// Entry point: `<program> DOWNLOAD_PATH DOWNLOAD_FILENAME ADDRESS PORT`.
///
/// Listens on ADDRESS:PORT and serves each accepted connection on its own
/// detached thread via `handle_connection`. On a wrong argument count it
/// prints usage to stderr and exits with status 1.
///
/// It otherwise runs forever, returning only on startup errors (bad port or
/// address, bind failure) or an accept error.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    const allocator = gpa.allocator();
    defer {
        const check = gpa.deinit();
        if (check == .leak) @panic("memory leaks");
    }

    const args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, args);

    if (args.len != 5) {
        // argv can legitimately be empty, so don't index args[0] blindly.
        const program: []const u8 = if (args.len > 0) args[0] else "program";
        const writer = std.io.getStdErr().writer();
        try writer.print("Usage: {s} DOWNLOAD_PATH DOWNLOAD_FILENAME ADDRESS PORT\n", .{program});
        std.process.exit(1);
    }

    const port = try std.fmt.parseInt(u16, args[4], 10);
    const listen_address = try std.net.Address.parseIp(args[3], port);
    var server = try listen_address.listen(.{ .reuse_address = true });
    defer server.deinit();
    // Log the bound address: it brackets IPv6 hosts and shows the real port.
    std.log.info("Listening on http://{}", .{server.listen_address});

    while (true) {
        const connection = try server.accept();
        const thread = std.Thread.spawn(.{}, handle_connection, .{ connection, args[1], args[2] }) catch |err| {
            // Keep serving other clients; drop only this connection.
            std.log.err("Failed to spawn handler for {}: {s}", .{ connection.address, @errorName(err) });
            connection.stream.close();
            continue;
        };
        // Handler threads are never joined; detaching releases their resources on exit.
        thread.detach();
    }
}
