From d6f785dcb9aec57b010318acf899397a591c8d8e Mon Sep 17 00:00:00 2001
From: Tejun Heo <tj@kernel.org>
Date: Mon, 22 Apr 2024 12:20:59 -1000
Subject: [PATCH] scx: Implement scx_bpf_consume_task()

This allows consuming a specific task while iterating a DSQ using the BPF
iterator, significantly increasing DSQ's flexibility. It has to jump through
some hoops to work around BPF restrictions, which hopefully can be removed in
the future.

scx_qmap is updated with a silly priority boost mechanism to demonstrate the
usage.
---
 kernel/sched/ext.c                        | 83 +++++++++++++++++++++++-
 tools/sched_ext/include/scx/common.bpf.h  | 16 +++++
 tools/sched_ext/scx_qmap.bpf.c            | 51 ++++++++++++++-
 tools/sched_ext/scx_qmap.c                | 12 +++-
 4 files changed, 153 insertions(+), 9 deletions(-)

diff --git a/kernel/sched/ext.c b/kernel/sched/ext.c
index 8ee8c9a0e5271..2c0df278deafc 100644
--- a/kernel/sched/ext.c
+++ b/kernel/sched/ext.c
@@ -1148,6 +1148,12 @@ enum scx_dsq_iter_flags {
 };
 
 struct bpf_iter_scx_dsq_kern {
+	/*
+	 * Must be the first field. Used to work around a BPF restriction and
+	 * pass in the iterator pointer to scx_bpf_consume_task().
+	 */
+	struct bpf_iter_scx_dsq_kern	*self;
+
 	struct scx_dsq_node		cursor;
 	struct scx_dispatch_q		*dsq;
 	u64				dsq_seq;
@@ -1536,7 +1542,7 @@ static void dispatch_enqueue(struct scx_dispatch_q *dsq, struct task_struct *p,
 	p->scx.dsq_seq = dsq->seq;
 
 	dsq_mod_nr(dsq, 1);
-	p->scx.dsq = dsq;
+	WRITE_ONCE(p->scx.dsq, dsq);
 
 	/*
 	 * scx.ddsp_dsq_id and scx.ddsp_enq_flags are only relevant on the
@@ -1629,7 +1635,7 @@ static void dispatch_dequeue(struct scx_rq *scx_rq, struct task_struct *p)
 		WARN_ON_ONCE(task_linked_on_dsq(p));
 		p->scx.holding_cpu = -1;
 	}
-	p->scx.dsq = NULL;
+	WRITE_ONCE(p->scx.dsq, NULL);
 
 	if (!is_local)
 		raw_spin_unlock(&dsq->lock);
@@ -2133,7 +2139,7 @@ static void consume_local_task(struct rq *rq, struct scx_dispatch_q *dsq,
 	list_add_tail(&p->scx.dsq_node.list, &scx_rq->local_dsq.list);
 	dsq_mod_nr(dsq, -1);
 	dsq_mod_nr(&scx_rq->local_dsq, 1);
-	p->scx.dsq = &scx_rq->local_dsq;
+	WRITE_ONCE(p->scx.dsq, &scx_rq->local_dsq);
 	raw_spin_unlock(&dsq->lock);
 }
 
@@ -5731,12 +5737,82 @@ __bpf_kfunc bool scx_bpf_consume(u64 dsq_id)
 	}
 }
 
+/**
+ * __scx_bpf_consume_task - Transfer a task from DSQ iteration to the local DSQ
+ * @it: DSQ iterator in progress
+ * @p: task to consume
+ *
+ * Transfer @p which is on the DSQ currently iterated by @it to the current
+ * CPU's local DSQ. For the transfer to be successful, @p must still be on the
+ * DSQ and have been queued before the DSQ iteration started. This function
+ * doesn't care whether @p was obtained from the DSQ iteration. @p just has to
+ * be on the DSQ and have been queued before the iteration started.
+ *
+ * Returns %true if @p has been consumed, %false if @p had already been consumed
+ * or dequeued.
+ */
+__bpf_kfunc bool __scx_bpf_consume_task(unsigned long it, struct task_struct *p)
+{
+	struct bpf_iter_scx_dsq_kern *kit = (void *)it;
+	struct scx_dispatch_q *dsq, *kit_dsq;
+	struct scx_dsp_ctx *dspc = this_cpu_ptr(&scx_dsp_ctx);
+	struct rq *task_rq;
+	u64 kit_dsq_seq;
+
+	/* can't trust @kit, carefully fetch the values we need */
+	if (get_kernel_nofault(kit_dsq, &kit->dsq) ||
+	    get_kernel_nofault(kit_dsq_seq, &kit->dsq_seq)) {
+		scx_ops_error("invalid @it 0x%lx", it);
+		return false;
+	}
+
+	/*
+	 * @kit can't be trusted and we can only get the DSQ from @p. As we
+	 * don't know @p's rq is locked, use READ_ONCE() to access the field.
+	 * Dereferencing is safe as DSQs are RCU protected.
+	 */
+	dsq = READ_ONCE(p->scx.dsq);
+	if (!dsq || dsq != kit_dsq)
+		return false;
+
+	if (!scx_kf_allowed(SCX_KF_DISPATCH))
+		return false;
+
+	flush_dispatch_buf(dspc->rq, dspc->rf);
+
+	raw_spin_lock(&dsq->lock);
+
+	/*
+	 * Did someone else get to it? @p could have already left $dsq, got
+	 * re-enqueued, or be in the process of being consumed by someone else.
+	 */
+	if (unlikely(p->scx.dsq != dsq ||
+		     time_after64(p->scx.dsq_seq, kit_dsq_seq) ||
+		     p->scx.holding_cpu >= 0))
+		goto out_unlock;
+
+	task_rq = task_rq(p);
+
+	if (dspc->rq == task_rq) {
+		consume_local_task(dspc->rq, dsq, p);
+		return true;
+	}
+
+	if (task_can_run_on_remote_rq(p, dspc->rq))
+		return consume_remote_task(dspc->rq, dspc->rf, dsq, p, task_rq);
+
+out_unlock:
+	raw_spin_unlock(&dsq->lock);
+	return false;
+}
+
 __bpf_kfunc_end_defs();
 
 BTF_KFUNCS_START(scx_kfunc_ids_dispatch)
 BTF_ID_FLAGS(func, scx_bpf_dispatch_nr_slots)
 BTF_ID_FLAGS(func, scx_bpf_dispatch_cancel)
 BTF_ID_FLAGS(func, scx_bpf_consume)
+BTF_ID_FLAGS(func, __scx_bpf_consume_task)
 BTF_KFUNCS_END(scx_kfunc_ids_dispatch)
 
 static const struct btf_kfunc_id_set scx_kfunc_set_dispatch = {
@@ -5945,6 +6021,7 @@ __bpf_kfunc int bpf_iter_scx_dsq_new(struct bpf_iter_scx_dsq *it, u64 dsq_id,
 	INIT_LIST_HEAD(&kit->cursor.list);
 	RB_CLEAR_NODE(&kit->cursor.priq);
 	kit->cursor.flags = SCX_TASK_DSQ_CURSOR;
+	kit->self = kit;
 	kit->dsq_seq = kit->dsq->seq;
 	kit->flags = flags;
 
diff --git a/tools/sched_ext/include/scx/common.bpf.h b/tools/sched_ext/include/scx/common.bpf.h
index 567d8761a7ca6..e255c6725d46f 100644
--- a/tools/sched_ext/include/scx/common.bpf.h
+++ b/tools/sched_ext/include/scx/common.bpf.h
@@ -35,6 +35,7 @@ void scx_bpf_dispatch_vtime(struct task_struct *p, u64 dsq_id, u64 slice, u64 vt
 u32 scx_bpf_dispatch_nr_slots(void) __ksym;
 void scx_bpf_dispatch_cancel(void) __ksym;
 bool scx_bpf_consume(u64 dsq_id) __ksym;
+bool __scx_bpf_consume_task(unsigned long it, struct task_struct *p) __ksym;
 u32 scx_bpf_reenqueue_local(void) __ksym;
 void scx_bpf_kick_cpu(s32 cpu, u64 flags) __ksym;
 s32 scx_bpf_dsq_nr_queued(u64 dsq_id) __ksym;
@@ -64,6 +65,21 @@ struct cgroup *scx_bpf_task_cgroup(struct task_struct *p) __ksym;
 static inline __attribute__((format(printf, 1, 2)))
 void ___scx_bpf_exit_format_checker(const char *fmt, ...) {}
 
+/* hopefully temporary wrapper to work around a BPF restriction */
+static inline bool scx_bpf_consume_task(struct bpf_iter_scx_dsq *it,
+					struct task_struct *p)
+{
+	unsigned long ptr;
+	bpf_probe_read_kernel(&ptr, sizeof(ptr), it);
+	return __scx_bpf_consume_task(ptr, p);
+}
+
+/*
+ * Use the following as @it when calling scx_bpf_consume_task() from within
+ * bpf_for_each() loops.
+ */
+#define BPF_FOR_EACH_ITER	(&___it)
+
 /*
  * Helper macro for initializing the fmt and variadic argument inputs to both
  * bstr exit kfuncs. Callers to this function should use ___fmt and ___param to
diff --git a/tools/sched_ext/scx_qmap.bpf.c b/tools/sched_ext/scx_qmap.bpf.c
index 1a85991fe13f8..624c76b8a2611 100644
--- a/tools/sched_ext/scx_qmap.bpf.c
+++ b/tools/sched_ext/scx_qmap.bpf.c
@@ -23,6 +23,7 @@
  * Copyright (c) 2022 David Vernet <dvernet@meta.com>
  */
 #include <scx/common.bpf.h>
+#include <string.h>
 
 enum consts {
 	ONE_SEC_IN_NS		= 1000000000,
@@ -37,6 +38,7 @@ const volatile u32 stall_kernel_nth;
 const volatile u32 dsp_inf_loop_after;
 const volatile u32 dsp_batch;
 const volatile bool print_shared_dsq;
+const volatile char exp_prefix[17];
 const volatile s32 disallow_tgid;
 const volatile bool switch_partial;
 
@@ -121,7 +123,7 @@ struct {
 
 /* Statistics */
 u64 nr_enqueued, nr_dispatched, nr_reenqueued, nr_dequeued;
-u64 nr_core_sched_execed;
+u64 nr_core_sched_execed, nr_expedited;
 u32 cpuperf_min, cpuperf_avg, cpuperf_max;
 u32 cpuperf_target_min, cpuperf_target_avg, cpuperf_target_max;
 
@@ -260,6 +262,49 @@ static void update_core_sched_head_seq(struct task_struct *p)
 		scx_bpf_error("task_ctx lookup failed");
 }
 
+static bool consume_shared_dsq(void)
+{
+	struct task_struct *p;
+	bool consumed;
+	s32 i;
+
+	if (exp_prefix[0] == '\0')
+		return scx_bpf_consume(SHARED_DSQ);
+
+	/*
+	 * To demonstrate the use of scx_bpf_consume_task(), implement a silly
+	 * selective priority boosting mechanism by scanning SHARED_DSQ looking
+	 * for matching comms and consuming them first. This makes a difference
+	 * only when dsp_batch is larger than 1.
+	 */
+	consumed = false;
+	bpf_for_each(scx_dsq, p, SHARED_DSQ, 0) {
+		bool match = true;
+		char comm[sizeof(exp_prefix)];
+
+		memcpy(comm, p->comm, sizeof(exp_prefix) - 1);
+
+		bpf_for(i, 0, sizeof(exp_prefix)) {
+			if (exp_prefix[i] == '\0')
+				break;
+			if (exp_prefix[i] != comm[i]) {
+				match = false;
+				break;
+			}
+		}
+
+		if (match && scx_bpf_consume_task(BPF_FOR_EACH_ITER, p)) {
+			consumed = true;
+			__sync_fetch_and_add(&nr_expedited, 1);
+		}
+	}
+
+	if (consumed)
+		return true;
+	else
+		return scx_bpf_consume(SHARED_DSQ);
+}
+
 void BPF_STRUCT_OPS(qmap_dispatch, s32 cpu, struct task_struct *prev)
 {
 	struct task_struct *p;
@@ -268,7 +313,7 @@ void BPF_STRUCT_OPS(qmap_dispatch, s32 cpu, struct task_struct *prev)
 	void *fifo;
 	s32 i, pid;
 
-	if (scx_bpf_consume(SHARED_DSQ))
+	if (consume_shared_dsq())
 		return;
 
 	if (dsp_inf_loop_after && nr_dispatched > dsp_inf_loop_after) {
@@ -319,7 +364,7 @@ void BPF_STRUCT_OPS(qmap_dispatch, s32 cpu, struct task_struct *prev)
 		batch--;
 		cpuc->dsp_cnt--;
 		if (!batch || !scx_bpf_dispatch_nr_slots()) {
-			scx_bpf_consume(SHARED_DSQ);
+			consume_shared_dsq();
 			return;
 		}
 		if (!cpuc->dsp_cnt)
diff --git a/tools/sched_ext/scx_qmap.c b/tools/sched_ext/scx_qmap.c
index a2bd98b7283ac..4c5957417d102 100644
--- a/tools/sched_ext/scx_qmap.c
+++ b/tools/sched_ext/scx_qmap.c
@@ -29,6 +29,8 @@ const char help_fmt[] =
 " -l COUNT     Trigger dispatch infinite looping after COUNT dispatches\n"
 " -b COUNT     Dispatch upto COUNT tasks together\n"
 " -P           Print out DSQ content to trace_pipe every second, use with -b\n"
+" -E PREFIX    Expedite consumption of threads w/ matching comm, use with -b\n"
+"              (e.g. match shell on a loaded system)\n"
 " -d PID       Disallow a process from switching into SCHED_EXT (-1 for self)\n"
 " -D LEN       Set scx_exit_info.dump buffer length\n"
 " -p           Switch only tasks on SCHED_EXT policy intead of all\n"
@@ -55,7 +57,7 @@ int main(int argc, char **argv)
 	skel = scx_qmap__open();
 	SCX_BUG_ON(!skel, "Failed to open skel");
 
-	while ((opt = getopt(argc, argv, "s:e:t:T:l:b:Pd:D:ph")) != -1) {
+	while ((opt = getopt(argc, argv, "s:e:t:T:l:b:PE:d:D:ph")) != -1) {
 		switch (opt) {
 		case 's':
 			skel->rodata->slice_ns = strtoull(optarg, NULL, 0) * 1000;
@@ -78,6 +80,10 @@ int main(int argc, char **argv)
 		case 'P':
 			skel->rodata->print_shared_dsq = true;
 			break;
+		case 'E':
+			strncpy(skel->rodata->exp_prefix, optarg,
+				sizeof(skel->rodata->exp_prefix) - 1);
+			break;
 		case 'd':
 			skel->rodata->disallow_tgid = strtol(optarg, NULL, 0);
 			if (skel->rodata->disallow_tgid < 0)
@@ -103,10 +109,10 @@ int main(int argc, char **argv)
 		long nr_enqueued = skel->bss->nr_enqueued;
 		long nr_dispatched = skel->bss->nr_dispatched;
 
-		printf("stats  : enq=%lu dsp=%lu delta=%ld reenq=%" PRIu64 " deq=%" PRIu64 " core=%" PRIu64 "\n",
+		printf("stats  : enq=%lu dsp=%lu delta=%ld reenq=%"PRIu64" deq=%"PRIu64" core=%"PRIu64" exp=%"PRIu64"\n",
 		       nr_enqueued, nr_dispatched, nr_enqueued - nr_dispatched,
 		       skel->bss->nr_reenqueued, skel->bss->nr_dequeued,
-		       skel->bss->nr_core_sched_execed);
+		       skel->bss->nr_core_sched_execed, skel->bss->nr_expedited);
 		printf("cpuperf: cur min/avg/max=%u/%u/%u target min/avg/max=%u/%u/%u\n",
 		       skel->bss->cpuperf_min, skel->bss->cpuperf_avg,
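
Usage note (not part of the patch): the pattern the new kfunc enables is to
walk a DSQ with the BPF iterator and selectively pull tasks onto the local
DSQ before falling back to plain FIFO consumption, as scx_qmap does above. A
minimal sketch of a dispatch callback doing that follows; MY_DSQ,
mysched_dispatch and needs_boost() are hypothetical placeholders, and the
sketch assumes the BPF_FOR_EACH_ITER wrapper added to common.bpf.h by this
patch.

	#include <scx/common.bpf.h>

	void BPF_STRUCT_OPS(mysched_dispatch, s32 cpu, struct task_struct *prev)
	{
		struct task_struct *p;

		/*
		 * Scan MY_DSQ in queue order and consume matching tasks first.
		 * scx_bpf_consume_task() returns false for tasks that were
		 * dequeued, already consumed, or queued after the iteration
		 * started, so such tasks are simply skipped.
		 */
		bpf_for_each(scx_dsq, p, MY_DSQ, 0) {
			if (needs_boost(p))
				scx_bpf_consume_task(BPF_FOR_EACH_ITER, p);
		}

		/* everything else is consumed in FIFO order */
		scx_bpf_consume(MY_DSQ);
	}

Because tasks queued after bpf_iter_scx_dsq_new() snapshotted dsq_seq are not
eligible for consumption, such a scan stays bounded even while new tasks keep
arriving on the DSQ.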