Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements for task_group API #1498

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/oneapi/tbb/detail/_task.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ TBB_EXPORT d1::slot_id __TBB_EXPORTED_FUNC execution_slot(const d1::execution_da
TBB_EXPORT d1::slot_id __TBB_EXPORTED_FUNC execution_slot(const d1::task_arena_base&);
TBB_EXPORT d1::task_group_context* __TBB_EXPORTED_FUNC current_context();
TBB_EXPORT d1::wait_tree_vertex_interface* get_thread_reference_vertex(d1::wait_tree_vertex_interface* wc);
TBB_EXPORT d1::task* __TBB_EXPORTED_FUNC current_task();

// Do not place under __TBB_RESUMABLE_TASKS. It is a stub for unsupported platforms.
struct suspend_point_type;
Expand Down Expand Up @@ -213,6 +214,7 @@ class reference_vertex : public wait_tree_vertex_interface {
}
private:
wait_tree_vertex_interface* my_parent;
protected:
std::atomic<std::uint64_t> m_ref_count;
};

Expand Down Expand Up @@ -268,6 +270,7 @@ inline void wait(wait_context& wait_ctx, task_group_context& ctx) {
call_itt_task_notify(destroy, &wait_ctx);
}

using r1::current_task;
using r1::current_context;

class task_traits {
Expand Down
178 changes: 167 additions & 11 deletions include/oneapi/tbb/detail/_task_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
#include "_task.h"
#include "_small_object_pool.h"
#include "_utils.h"
#include "oneapi/tbb/mutex.h"

#include <memory>
#include <forward_list>
#include <iostream>

namespace tbb {
namespace detail {
Expand All @@ -31,12 +35,89 @@ namespace d1 { class task_group_context; class wait_context; struct execution_da
namespace d2 {

class task_handle;
class task_handle_task;
class task_state_handler;

class task_handle_task : public d1::task {
std::uint64_t m_version_and_traits{};
d1::wait_tree_vertex_interface* m_wait_tree_vertex;
// Reference-counted vertex that stands between a dependent task and the work
// it waits on. The out-of-line release() spawns `continuation_task` into the
// context and frees the vertex once the counter reaches zero.
class continuation_vertex : public d1::reference_vertex {
public:
    // Stores the task to spawn later, its context, and a copy of the allocator
    // used to free this vertex.
    // NOTE(review): the base is built with (nullptr, 1) - presumably "no parent,
    // initial reference count of one"; confirm against d1::reference_vertex.
    continuation_vertex(task_handle_task* continuation_task,
                        d1::task_group_context& ctx,
                        d1::small_object_allocator& alloc)
        : d1::reference_vertex(nullptr, 1)
        , m_continuation_task(continuation_task)
        , m_ctx(ctx)
        , m_allocator(alloc)
    {}

    void release(std::uint32_t delta = 1) override;

private:
    task_handle_task* m_continuation_task;
    d1::task_group_context& m_ctx;
    d1::small_object_allocator m_allocator;
};

// Vertex inserted into another task's successor list when successors are
// transferred. Its out-of-line release() reports completion to the originating
// task_state_handler, releases every successor collected here, and frees the
// vertex.
class transfer_vertex : public d1::reference_vertex {
public:
    // `handler` is the state handler to notify on release; `alloc` is copied so
    // the vertex can free itself later.
    transfer_vertex(task_state_handler* handler, d1::small_object_allocator& alloc)
        : d1::reference_vertex(nullptr, 1)
        , m_handler(handler)
        , m_allocator(alloc)
    {}

    // The delta argument is unnamed here and ignored by the definition.
    void release(std::uint32_t) override;

    // Remembers a vertex that has to be released together with this one.
    void add_successor(d1::wait_tree_vertex_interface* successor) {
        m_wait_tree_vertex_successors.push_front(successor);
    }

private:
    task_state_handler* m_handler{nullptr};
    std::forward_list<d1::wait_tree_vertex_interface*> m_wait_tree_vertex_successors;
    d1::small_object_allocator m_allocator;
};

// Shared bookkeeping object for one task_handle_task.
// m_num_references starts at 2. In the visible code one reference is dropped by
// task_handle's destructor (release()) and one by the task's destructor
// (complete_task()). create_transfer_vertex() adds a further reference that is
// dropped when the transfer vertex calls complete_task(true).
// The object deletes itself through m_alloc when the count reaches zero.
class task_state_handler {
public:
// Records the task this handler describes and copies the allocator used for self-deletion.
task_state_handler(task_handle_task* task, d1::small_object_allocator& alloc) : m_task(task), m_alloc(alloc) {}
// Drops one reference under m_mutex; the handler is destroyed if it was the last one.
void release() {
d1::mutex::scoped_lock lock(m_mutex);
release_impl(lock);
}

// Drops one reference under m_mutex. The task is marked finished only when no
// transfer vertex has been created, or when the call comes from the transfer
// vertex itself (is_from_transfer == true).
void complete_task(bool is_from_transfer = false) {
d1::mutex::scoped_lock lock(m_mutex);
if (m_transfer == nullptr || is_from_transfer) {
m_is_finished = true;
}
release_impl(lock);
}

// Both are defined out of line, after task_handle_task is complete.
void add_successor(task_handle_task& successor);
void transfer_successors_to(task_handle_task* target);

// Takes an extra reference and allocates a transfer_vertex that points back at
// this handler; the new vertex replaces any previous value of m_transfer.
// Does not lock m_mutex itself: the visible caller (transfer_successors_to)
// already holds it.
transfer_vertex* create_transfer_vertex() {
d1::small_object_allocator alloc;
++m_num_references;
m_transfer = alloc.new_object<transfer_vertex>(this, alloc);
return m_transfer;
}

private:
// Decrements the reference count; must be called with `lock` held on m_mutex.
void release_impl(d1::mutex::scoped_lock& lock) {
if (--m_num_references == 0) {
// NOTE(review): presumably unlocks m_mutex so it is not held while the
// object that owns it is destroyed - confirm scoped_lock::release semantics.
lock.release();
m_alloc.delete_object(this);
}
}

task_handle_task* m_task;
bool m_is_finished{false};
transfer_vertex* m_transfer{nullptr};
// Modified only by release_impl() and create_transfer_vertex().
int m_num_references{2};
d1::mutex m_mutex;
d1::small_object_allocator m_alloc;
};

class task_handle_task : public d1::task {
public:
void finalize(const d1::execution_data* ed = nullptr) {
if (ed) {
Expand All @@ -47,18 +128,56 @@ class task_handle_task : public d1::task {
}

task_handle_task(d1::wait_tree_vertex_interface* vertex, d1::task_group_context& ctx, d1::small_object_allocator& alloc)
: m_wait_tree_vertex(vertex)
: m_wait_tree_vertex_successors{vertex}
, m_ctx(ctx)
, m_allocator(alloc) {
, m_allocator(alloc)
{
suppress_unused_warning(m_version_and_traits);
m_wait_tree_vertex->reserve();
vertex->reserve();
}

~task_handle_task() override {
m_wait_tree_vertex->release();
~task_handle_task() {
if (m_state_holder) {
m_state_holder->complete_task();
}
release_successors();
}

// Returns the task_group_context this task was constructed with.
d1::task_group_context& ctx() const { return m_ctx; }

// True while a continuation vertex is attached to this task (m_continuation is non-null).
bool has_dependency() const { return m_continuation != nullptr; }

// Releases one reference on the continuation vertex. m_continuation is dereferenced
// unconditionally, so this is only valid while has_dependency() is true.
void release_continuation() { m_continuation->release(); }

// Forgets the continuation vertex without releasing or freeing it.
void unset_continuation() { m_continuation = nullptr; }

// Forwards the request to the state holder, which moves this task's successors
// over to `target`. Does nothing when the task has no state holder.
void transfer_successors_to(task_handle_task* target) {
    // TODO: What if we set current task as a dependency later?
    if (!m_state_holder) {
        return;
    }
    m_state_holder->transfer_successors_to(target);
}

// Allocates a fresh task_state_handler for this task, stores it in
// m_state_holder and returns it. Despite the name, every call creates a new
// handler and overwrites the stored pointer.
task_state_handler* get_state_holder() {
    d1::small_object_allocator handler_alloc;
    task_state_handler* const created = handler_alloc.new_object<task_state_handler>(this, handler_alloc);
    m_state_holder = created;
    return created;
}

private:
void release_successors() {
for (auto successor : m_wait_tree_vertex_successors) {
successor->release();
}
}

friend task_state_handler;
std::uint64_t m_version_and_traits{};
task_state_handler* m_state_holder{nullptr};
std::forward_list<d1::wait_tree_vertex_interface*> m_wait_tree_vertex_successors;
continuation_vertex* m_continuation{nullptr};
d1::task_group_context& m_ctx;
d1::small_object_allocator m_allocator;
};


Expand All @@ -69,10 +188,39 @@ class task_handle {
using handle_impl_t = std::unique_ptr<task_handle_task, task_handle_task_finalizer_t>;

handle_impl_t m_handle = {nullptr};
task_state_handler* m_state_holder = {nullptr};
public:
task_handle() = default;
task_handle(task_handle&&) = default;
task_handle& operator=(task_handle&&) = default;
// Move constructor: takes over both the owned task and the state-holder
// pointer. The source's state holder is nulled so that only one task_handle
// releases it.
task_handle(task_handle&& th)
    : m_handle(std::move(th.m_handle))
    , m_state_holder(th.m_state_holder)
{
    th.m_state_holder = nullptr;
}

// Move assignment: gives up whatever this handle currently tracks, then takes
// over the task and state holder of `th`. Self-assignment is a no-op.
task_handle& operator=(task_handle&& th) {
    if (this != &th) {
        // Drop this handle's reference to the state holder it is about to stop
        // tracking, exactly as the destructor does. Without this the old
        // handler would keep one reference forever and never be freed.
        if (m_state_holder) {
            m_state_holder->release();
        }
        // Replacing m_handle disposes of the previously owned task, if any,
        // through the unique_ptr deleter.
        m_handle = std::move(th.m_handle);
        m_state_holder = th.m_state_holder;
        th.m_state_holder = nullptr;
    }
    return *this;
}

// Gives up this handle's reference to the shared state holder, if it has one.
// The owned task (if still held) is disposed of afterwards by m_handle's deleter.
~task_handle() {
    if (m_state_holder != nullptr) {
        m_state_holder->release();
    }
}

// Registers the task owned by this handle as a successor of the task tracked
// by `th`.
// Silently does nothing when this handle no longer owns a task (empty,
// moved-from or already submitted) or when `th` tracks no state holder.
// Previously only this handle's own state holder was checked, so both
// `th.m_state_holder` and `*m_handle` could be dereferenced while null.
void add_predecessor(task_handle& th) {
    if (m_state_holder && m_handle && th.m_state_holder) {
        th.m_state_holder->add_successor(*m_handle);
    }
}

// Registers the task owned by `th` as a successor of the task tracked by this
// handle.
// Silently does nothing when this handle tracks no state holder or when `th`
// no longer owns a task (empty, moved-from or already submitted). Previously
// `*th.m_handle` was dereferenced without a check.
void add_successor(task_handle& th) {
    if (m_state_holder && th.m_handle) {
        m_state_holder->add_successor(*th.m_handle);
    }
}

// True when the handle currently owns a task (m_handle is non-null).
explicit operator bool() const noexcept { return static_cast<bool>(m_handle); }

Expand All @@ -85,7 +233,7 @@ class task_handle {
private:
friend struct task_handle_accessor;

task_handle(task_handle_task* t) : m_handle {t}{};
task_handle(task_handle_task* t) : m_handle{t}, m_state_holder(t->get_state_holder()) {};

d1::task* release() {
return m_handle.release();
Expand All @@ -99,6 +247,14 @@ static d1::task_group_context& ctx_of(task_handle& th) {
__TBB_ASSERT(th.m_handle, "ctx_of does not expect empty task_handle.");
return th.m_handle->ctx();
}
// Reports whether the task owned by `th` has a continuation vertex attached.
// Dereferences th.m_handle without a check, so `th` must not be empty.
static bool has_dependency(task_handle& th) { return th.m_handle->has_dependency(); }
// Releases one reference on the continuation vertex of the task owned by `th`,
// then makes `th` give up ownership of the task without destroying it.
// The pointer returned by th.release() is intentionally discarded here.
static void release_continuation(task_handle& th) {
th.m_handle->release_continuation();
th.release();
}
// Asks `task` to move its successors to the task owned by `th`.
// `th` keeps ownership; only the raw pointer is passed on.
static void transfer_successors_to(task_handle& th, task_handle_task* task) {
task->transfer_successors_to(th.m_handle.get());
}
};

inline bool operator==(task_handle const& th, std::nullptr_t) noexcept {
Expand Down
70 changes: 67 additions & 3 deletions include/oneapi/tbb/task_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ class task_group_context : no_copy {
friend struct r1::task_arena_impl;
friend struct r1::task_group_context_impl;
friend class d2::task_group_base;
friend class d2::continuation_vertex;
}; // class task_group_context

static_assert(sizeof(task_group_context) == 128, "Wrong size of task_group_context");
Expand Down Expand Up @@ -445,7 +446,9 @@ class function_stack_task : public d1::task {
d1::wait_tree_vertex_interface* m_wait_tree_vertex;

void finalize() {
m_wait_tree_vertex->release();
if (m_wait_tree_vertex) {
m_wait_tree_vertex->release();
}
}
task* execute(d1::execution_data&) override {
task* res = d2::task_ptr_or_nullptr(m_func);
Expand Down Expand Up @@ -581,13 +584,16 @@ class task_group : public task_group_base {
using acs = d2::task_handle_accessor;
__TBB_ASSERT(&acs::ctx_of(h) == &context(), "Attempt to schedule task_handle into different task_group");

d1::spawn(*acs::release(h), context());
if (!acs::has_dependency(h)) {
d1::spawn(*acs::release(h), context());
} else {
acs::release_continuation(h);
}
}

template<typename F>
d2::task_handle defer(F&& f) {
return prepare_task_handle(std::forward<F>(f));

}

template<typename F>
Expand All @@ -600,6 +606,57 @@ class task_group : public task_group_base {
}
}; // class task_group

// Subtracts `delta` from the reference counter. The caller that brings it to
// zero detaches the vertex from the continuation task, spawns that task into
// the actual context, and finally frees the vertex.
inline void continuation_vertex::release(std::uint32_t delta) {
    const std::uint64_t decrement = static_cast<std::uint64_t>(delta);
    // fetch_sub returns the value held before the subtraction, so the counter
    // has just reached zero exactly when that value equals the decrement.
    if (m_ref_count.fetch_sub(decrement) != decrement) {
        return;
    }
    m_continuation_task->unset_continuation();
    d1::spawn(*m_continuation_task, m_ctx.actual_context());
    m_allocator.delete_object(this);
}

// Makes `successor` wait for the task described by this handler.
// A continuation vertex is created for the successor on first use. Unless the
// task has already finished, the vertex gets an extra reference and is queued
// either on the pending transfer vertex or on the task's own successor list.
inline void task_state_handler::add_successor(task_handle_task& successor) {
    if (!successor.m_continuation) {
        d1::small_object_allocator alloc;
        successor.m_continuation = alloc.new_object<continuation_vertex>(&successor, successor.m_ctx, alloc);
    }

    d1::mutex::scoped_lock lock(m_mutex);
    if (m_is_finished) {
        // Nothing left to wait for; the successor keeps only its initial reference.
        return;
    }

    continuation_vertex* const gate = successor.m_continuation;
    gate->reserve();
    if (m_transfer) {
        m_transfer->add_successor(gate);
    } else {
        m_task->m_wait_tree_vertex_successors.push_front(gate);
    }
}

// Moves every successor of this handler's task onto `target`, under m_mutex.
// A newly created transfer vertex is placed at the front of target's list so
// that this handler is notified when `target` releases its successors.
inline void task_state_handler::transfer_successors_to(task_handle_task* target) {
    d1::mutex::scoped_lock lock(m_mutex);

    auto& receiving = target->m_wait_tree_vertex_successors;
    auto& donated = m_task->m_wait_tree_vertex_successors;

    receiving.push_front(create_transfer_vertex());
    // Insert the donated vertices right after the transfer vertex.
    receiving.splice_after(receiving.begin(), donated);
    donated.clear();
}

inline void transfer_vertex::release(std::uint32_t) {
m_handler->complete_task(true);
for (auto successor : m_wait_tree_vertex_successors) {
successor->release();
}
m_allocator.delete_object(this);
}

// Transfers the successors of the currently executing task to the task owned
// by `h`. The current task must be a task_handle_task (checked with a release
// assertion); a debug assertion checks that both tasks share one context.
inline void transfer_successors_to(d2::task_handle& h) {
    using acs = d2::task_handle_accessor;

    task_handle_task* task = dynamic_cast<task_handle_task*>(d1::current_task());
    __TBB_ASSERT_RELEASE(task, "Attempt to transfer successors from non-task_handle_task");
    __TBB_ASSERT(&acs::ctx_of(h) == &task->ctx(), "Attempt to transfer successors to task_handle into different task_group");

    acs::transfer_successors_to(h, task);
}

#if TBB_PREVIEW_ISOLATED_TASK_GROUP
class spawn_delegate : public d1::delegate_base {
d1::task* task_to_spawn;
Expand Down Expand Up @@ -701,6 +758,13 @@ using detail::d1::is_current_task_group_canceling;
using detail::r1::missing_wait;

using detail::d2::task_handle;

namespace this_task_group {
namespace current_task {
using detail::d2::transfer_successors_to;
}
}

}

} // namespace tbb
Expand Down
1 change: 1 addition & 0 deletions src/tbb/def/lin32-tbb.def
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ _ZN3tbb6detail2r16resumeEPNS1_18suspend_point_typeE;
_ZN3tbb6detail2r121current_suspend_pointEv;
_ZN3tbb6detail2r114notify_waitersEj;
_ZN3tbb6detail2r127get_thread_reference_vertexEPNS0_2d126wait_tree_vertex_interfaceE;
_ZN3tbb6detail2r112current_taskEv;

/* Task dispatcher (task_dispatcher.cpp) */
_ZN3tbb6detail2r114execution_slotEPKNS0_2d114execution_dataE;
Expand Down
1 change: 1 addition & 0 deletions src/tbb/def/lin64-tbb.def
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ _ZN3tbb6detail2r16resumeEPNS1_18suspend_point_typeE;
_ZN3tbb6detail2r121current_suspend_pointEv;
_ZN3tbb6detail2r114notify_waitersEm;
_ZN3tbb6detail2r127get_thread_reference_vertexEPNS0_2d126wait_tree_vertex_interfaceE;
_ZN3tbb6detail2r112current_taskEv;

/* Task dispatcher (task_dispatcher.cpp) */
_ZN3tbb6detail2r114execution_slotEPKNS0_2d114execution_dataE;
Expand Down
1 change: 1 addition & 0 deletions src/tbb/def/mac64-tbb.def
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ __ZN3tbb6detail2r16resumeEPNS1_18suspend_point_typeE
__ZN3tbb6detail2r121current_suspend_pointEv
__ZN3tbb6detail2r114notify_waitersEm
__ZN3tbb6detail2r127get_thread_reference_vertexEPNS0_2d126wait_tree_vertex_interfaceE
__ZN3tbb6detail2r112current_taskEv

# Task dispatcher (task_dispatcher.cpp)
__ZN3tbb6detail2r114execution_slotEPKNS0_2d114execution_dataE
Expand Down
1 change: 1 addition & 0 deletions src/tbb/def/win32-tbb.def
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ EXPORTS
?suspend@r1@detail@tbb@@YAXP6AXPAXPAUsuspend_point_type@123@@Z0@Z
?notify_waiters@r1@detail@tbb@@YAXI@Z
?get_thread_reference_vertex@r1@detail@tbb@@YAPAVwait_tree_vertex_interface@d1@23@PAV4523@@Z
?current_task@r1@detail@tbb@@YAPAVtask@d1@23@XZ

; Task dispatcher (task_dispatcher.cpp)
?spawn@r1@detail@tbb@@YAXAAVtask@d1@23@AAVtask_group_context@523@G@Z
Expand Down
Loading
Loading