Problem Statement
Design a multi-stage task scheduler engine that executes a graph of tasks connected by dependency edges (similar to Apache Airflow or Argo Workflows). The engine must accept a group of tasks, validate that the graph is a Directed Acyclic Graph (DAG) by detecting cycles, execute tasks that have no dependencies in parallel using a thread pool, and dynamically trigger child tasks as their parent dependencies successfully finish.
Design Decisions & Patterns Used
Scheduling multi-stage jobs requires ordering tasks topologically by their dependencies. If Task B and Task C depend on Task A, Task A must run and complete first, after which B and C can run concurrently. A robust low-level design (LLD) must validate the dependency graph to reject cycles, because a task in a cycle waits on itself forever and deadlocks the pipeline. It must also react to task completions that arrive asynchronously from worker threads.
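To make topological ordering concrete, the sketch below groups task ids into stages with Kahn's algorithm. Every task in a stage depends only on tasks in earlier stages. This technique is swapped in for illustration and is not the engine's own code: the DagScheduler later in this post never builds stages up front and discovers the same order dynamically with per-task counters. The class name `StagePlanner` and the `Map<String, List<String>>` input (task id mapped to the ids it depends on) are assumptions.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class StagePlanner {
    /**
     * Groups tasks into stages using Kahn's algorithm. Stage 0 holds tasks with no
     * dependencies; each later stage holds tasks whose last dependency finished in
     * the previous stage. deps maps a task id to the ids it depends on.
     * Returns null if some tasks can never be scheduled (the graph has a cycle).
     */
    public static List<List<String>> plan(Map<String, List<String>> deps) {
        Map<String, Integer> inDegree = new HashMap<>();        // unfinished parents per task
        Map<String, List<String>> children = new HashMap<>();   // reverse edges: parent -> children
        for (Map.Entry<String, List<String>> e : deps.entrySet()) {
            inDegree.put(e.getKey(), e.getValue().size());
        }
        for (Map.Entry<String, List<String>> e : deps.entrySet()) {
            for (String parent : e.getValue()) {
                inDegree.putIfAbsent(parent, 0);                // parents missing from deps have no deps
                children.computeIfAbsent(parent, k -> new ArrayList<>()).add(e.getKey());
            }
        }

        List<String> current = new ArrayList<>();
        for (Map.Entry<String, Integer> e : inDegree.entrySet()) {
            if (e.getValue() == 0) current.add(e.getKey());
        }
        List<List<String>> stages = new ArrayList<>();
        int placed = 0;
        while (!current.isEmpty()) {
            Collections.sort(current);                          // deterministic order within a stage
            stages.add(current);
            placed += current.size();
            List<String> next = new ArrayList<>();
            for (String id : current) {
                for (String child : children.getOrDefault(id, List.of())) {
                    // The child joins the next stage when its last parent finishes
                    if (inDegree.merge(child, -1, Integer::sum) == 0) next.add(child);
                }
            }
            current = next;
        }
        // Tasks on (or downstream of) a cycle never reach in-degree 0
        return placed == inDegree.size() ? stages : null;
    }

    public static void main(String[] args) {
        Map<String, List<String>> diamond = Map.of(
                "A", List.of(),
                "B", List.of("A"),
                "C", List.of("A"),
                "D", List.of("B", "C"));
        System.out.println(plan(diamond)); // [[A], [B, C], [D]]
    }
}
```

For the diamond graph used in the driver, A forms the first stage, B and C the second, and D the third. A null result signals that some tasks sit on or behind a cycle.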
We will utilize the following Design Patterns:
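- Observer Pattern: Child tasks react to the completion events of their parents. In the code below, the scheduler plays the observer role: when a task completes, it notifies each child by decrementing that child's pending-dependency counter.
- Visitor Pattern: Traversing the dependency nodes to perform graph validation checks (cycle detection). The implementation below keeps this as a plain recursive depth-first search inside TaskGraph.
- Template Method: Defining the standard execution lifecycle (e.g., pre-run validation, execution, status logging, child notification) while letting tasks define custom execution logic. Here the scheduler fixes the lifecycle, and each task supplies its logic as a Runnable payload.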
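None of these patterns appears as a dedicated class in the implementation that follows. The sketch below shows one way Template Method and Observer could be made explicit. `AbstractTask`, `execute()`, `TaskListener`, and `onFinished()` are hypothetical names, and the pre-run step is only a placeholder print.

```java
import java.util.ArrayList;
import java.util.List;

public class LifecycleSketch {
    // Observer: notified when a task finishes (a child task or the scheduler could implement this)
    public interface TaskListener {
        void onFinished(String taskId, boolean succeeded);
    }

    // Template Method: run() is final and fixes the lifecycle order;
    // subclasses supply only the execute() step
    public abstract static class AbstractTask {
        private final String id;
        private final List<TaskListener> listeners = new ArrayList<>();

        public AbstractTask(String id) { this.id = id; }

        public void addListener(TaskListener listener) { listeners.add(listener); }

        public final void run() {
            System.out.println("[" + id + "] pre-run checks");            // 1. pre-run hook (placeholder)
            boolean succeeded;
            try {
                execute();                                                  // 2. task-specific logic
                succeeded = true;
            } catch (RuntimeException e) {
                succeeded = false;
            }
            System.out.println("[" + id + "] " + (succeeded ? "COMPLETED" : "FAILED")); // 3. status logging
            for (TaskListener listener : listeners) {                      // 4. notify observers
                listener.onFinished(id, succeeded);
            }
        }

        protected abstract void execute();
    }

    public static void main(String[] args) {
        List<String> events = new ArrayList<>();
        AbstractTask fetchLogs = new AbstractTask("TaskA") {
            @Override
            protected void execute() { events.add("fetching logs"); }
        };
        fetchLogs.addListener((id, ok) -> events.add(id + " finished, ok=" + ok));
        fetchLogs.run();
        System.out.println(events); // [fetching logs, TaskA finished, ok=true]
    }
}
```

Declaring run() final stops subclasses from reordering the lifecycle, which is the point of Template Method. The listener list lets a scheduler or child task react to completion without the task knowing who is listening.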
Functional Requirements
- Model tasks as nodes with custom executable payloads and dependency sets.
- Validate the graph to detect circular dependencies (e.g., A -> B -> C -> A) and reject invalid graphs.
- Maintain execution statuses: PENDING, RUNNING, COMPLETED, FAILED.
- Execute independent tasks concurrently using a thread pool.
- Dynamically evaluate child tasks, starting them automatically once all their parent dependencies finish successfully.
Objects Required
- TaskStatus (enum of execution states)
- Task (node holding the execution payload and dependency links)
- TaskGraph (registry of tasks plus cycle validation)
- DagScheduler (execution engine that coordinates worker threads and reacts to parent completions)
TaskStatus Enum & Task Class
The TaskStatus enum tracks execution phases. The Task class encapsulates the executable payload, its parent dependencies, and its child links.
public enum TaskStatus {
    PENDING,
    RUNNING,
    COMPLETED,
    FAILED
}
Let's define the Task class:
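import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class Task {
    private final String id;
    private final Runnable payload;        // the work this task performs
    private final List<Task> dependencies; // parents: tasks that must complete first
    private final List<Task> children;     // tasks that depend on this one
    private TaskStatus status;

    public Task(String id, Runnable payload) {
        this.id = id;
        this.payload = payload;
        this.dependencies = new ArrayList<>();
        this.children = new ArrayList<>();
        this.status = TaskStatus.PENDING;
    }

    // Declares that this task depends on 'parent' and records the reverse edge on the parent.
    // Edges are expected to be added before the graph is executed.
    public void addDependency(Task parent) {
        dependencies.add(parent);
        parent.addChild(this);
    }

    private void addChild(Task child) {
        children.add(child);
    }

    public String getId() { return id; }
    public Runnable getPayload() { return payload; }
    // Read-only views, so edges can only be added through addDependency()
    public List<Task> getDependencies() { return Collections.unmodifiableList(dependencies); }
    public List<Task> getChildren() { return Collections.unmodifiableList(children); }

    // Status is read and written from worker threads, so access is synchronized
    public synchronized TaskStatus getStatus() { return status; }
    public synchronized void setStatus(TaskStatus status) { this.status = status; }
}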
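The constructor stores the task's identifier and payload and starts the task in the PENDING state. addDependency() records each edge on both ends: the parent is appended to this task's dependency list, and this task is appended to the parent's child list. The dependency links let TaskGraph walk upward for cycle detection. The child links let the scheduler walk downward to trigger tasks once their parents finish.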
TaskGraph Class (Cycle Validation)
The TaskGraph maintains the set of tasks and verifies that the graph is acyclic before execution using a Depth-First Search (DFS) traversal.
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class TaskGraph {
    private final List<Task> tasks = new ArrayList<>();

    public void addTask(Task task) {
        tasks.add(task);
    }

    public List<Task> getTasks() { return tasks; }

    // Returns true if the graph is acyclic (a valid DAG)
    public boolean validate() {
        Set<Task> visited = new HashSet<>();   // tasks whose DFS has already started
        Set<Task> recStack = new HashSet<>();  // tasks on the current DFS path
        for (Task task : tasks) {
            if (detectCycle(task, visited, recStack)) {
                System.err.println("Validation Error: Circular dependency detected in the task graph!");
                return false;
            }
        }
        return true;
    }

    // DFS along dependency edges; reaching a task already on the current path means a cycle
    private boolean detectCycle(Task task, Set<Task> visited, Set<Task> recStack) {
        if (recStack.contains(task)) return true;  // back edge: cycle found
        if (visited.contains(task)) return false;  // fully explored earlier, no cycle through it
        visited.add(task);
        recStack.add(task);
        for (Task dependency : task.getDependencies()) {
            if (detectCycle(dependency, visited, recStack)) {
                return true;
            }
        }
        recStack.remove(task);                     // backtrack
        return false;
    }
}
Here is an explanation of the core operations in the TaskGraph class:
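- addTask() registers a task with the graph.
- validate() initializes the search state and starts a DFS from every registered task, so disconnected parts of the graph are checked too.
- detectCycle() performs the DFS along dependency edges. It adds each task to the recursion stack recStack on the way down and removes it when backtracking. Reaching a task that is already on the stack means the path has looped back on itself, which is a cycle. Tasks in visited are never re-explored, so validation runs in O(V + E) time for V tasks and E dependency edges.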
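validate() only reports that a cycle exists. To name the offending tasks in the error message, the same recursion-stack DFS can keep the current path as an ordered list and slice it when it hits a back edge. This sketch works on a hypothetical `Map<String, List<String>>` of task ids to dependency ids rather than on Task objects, and `CycleFinder` and `findCycle()` are illustrative names.

```java
import java.util.*;

public class CycleFinder {
    /**
     * Returns one cycle as a list of task ids (the first id repeated at the end),
     * or an empty list if the graph is acyclic. deps maps a task id to the ids it depends on.
     */
    public static List<String> findCycle(Map<String, List<String>> deps) {
        Set<String> visited = new HashSet<>();
        List<String> path = new ArrayList<>();  // current DFS path, in order (the recursion stack)
        Set<String> onPath = new HashSet<>();   // same ids as path, for O(1) membership checks
        for (String start : deps.keySet()) {
            List<String> cycle = dfs(start, deps, visited, path, onPath);
            if (!cycle.isEmpty()) return cycle;
        }
        return List.of();
    }

    private static List<String> dfs(String id, Map<String, List<String>> deps,
                                    Set<String> visited, List<String> path, Set<String> onPath) {
        if (onPath.contains(id)) {
            // Back edge: the cycle runs from id's position on the path to the current node
            List<String> cycle = new ArrayList<>(path.subList(path.indexOf(id), path.size()));
            cycle.add(id);
            return cycle;
        }
        if (!visited.add(id)) return List.of();  // fully explored before, no cycle through it
        path.add(id);
        onPath.add(id);
        for (String dep : deps.getOrDefault(id, List.of())) {
            List<String> cycle = dfs(dep, deps, visited, path, onPath);
            if (!cycle.isEmpty()) return cycle;
        }
        path.remove(path.size() - 1);            // backtrack
        onPath.remove(id);
        return List.of();
    }

    public static void main(String[] args) {
        Map<String, List<String>> deps = new LinkedHashMap<>();  // insertion order keeps output stable
        deps.put("A", List.of("B"));
        deps.put("B", List.of("C"));
        deps.put("C", List.of("A"));
        System.out.println(findCycle(deps)); // [A, B, C, A]
    }
}
```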
DagScheduler Class
The DagScheduler handles execution logic. It tracks dependency counters dynamically and executes ready tasks concurrently using a thread pool.
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

public class DagScheduler {
    private final ExecutorService executorService;
    // Unfinished parents per task, keyed by Task identity so that separate
    // graphs with overlapping ids cannot interfere with each other
    private final Map<Task, AtomicInteger> remainingDependencies;

    public DagScheduler(int poolSize) {
        this.executorService = Executors.newFixedThreadPool(poolSize);
        this.remainingDependencies = new ConcurrentHashMap<>();
    }

    // Validates the graph and submits its entry tasks; returns without waiting for completion
    public void executeGraph(TaskGraph graph) {
        if (!graph.validate()) {
            throw new IllegalArgumentException("Cannot execute: Invalid graph containing circular dependencies.");
        }
        System.out.println("Graph validated. Starting execution...");
        // Initialize dependency counters for each task
        for (Task task : graph.getTasks()) {
            remainingDependencies.put(task, new AtomicInteger(task.getDependencies().size()));
        }
        // Find and submit tasks that have no dependencies (counter == 0)
        for (Task task : graph.getTasks()) {
            if (task.getDependencies().isEmpty()) {
                submitTask(task);
            }
        }
    }

    private void submitTask(Task task) {
        task.setStatus(TaskStatus.RUNNING);
        executorService.submit(() -> {
            System.out.println("[Scheduler] Starting execution of task: " + task.getId());
            try {
                task.getPayload().run();
            } catch (RuntimeException e) {
                // A failed task never notifies its children, so they stay PENDING
                task.setStatus(TaskStatus.FAILED);
                System.err.println("[Scheduler] Task " + task.getId() + " failed: " + e.getMessage());
                return;
            }
            task.setStatus(TaskStatus.COMPLETED);
            System.out.println("[Scheduler] Completed task: " + task.getId());
            // Trigger downstream dependencies only after a successful run
            processTaskCompletion(task);
        });
    }

    private void processTaskCompletion(Task parent) {
        for (Task child : parent.getChildren()) {
            AtomicInteger counter = remainingDependencies.get(child);
            if (counter != null) {
                // decrementAndGet() is atomic, so exactly one parent observes 0 and submits the child
                int remaining = counter.decrementAndGet();
                if (remaining == 0) {
                    System.out.println("[Scheduler] All dependencies met for task: " + child.getId());
                    submitTask(child);
                }
            }
        }
    }

    public void shutdown() {
        executorService.shutdown();
    }
}
Here is an explanation of the core operations in the DagScheduler class:
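- The constructor creates a fixed-size thread pool and a ConcurrentHashMap of AtomicInteger counters that track each task's unfinished parents safely across threads.
- executeGraph() validates the graph, populates the dependency counters, identifies the entry tasks (zero dependencies), and submits them to the thread pool. It returns immediately, and execution continues asynchronously.
- submitTask() runs the task's payload on a pool thread, updates its status, and triggers downstream checks. If the payload throws, the task is marked FAILED and its children are never notified, so they and their descendants remain PENDING.
- processTaskCompletion() decrements the dependency counter of every child task. When a child's counter reaches zero, all of its parents have completed, so the child is submitted for execution. Because decrementAndGet() is atomic, exactly one parent sees the counter reach zero, so a child is never submitted twice.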
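Two gaps remain in this scheduler. Callers cannot tell when a graph has finished (the driver simply sleeps), and the children of a failed task stay PENDING forever. The sketch below closes both gaps under stated assumptions:

- `AwaitableScheduler`, `Node`, `runAndWait()`, and `failDescendants()` are hypothetical names.
- `Node` is a stripped-down stand-in for Task.
- A failed task's descendants are marked FAILED rather than given a separate SKIPPED state.

A CountDownLatch sized to the number of tasks is counted down once per task as it reaches a terminal state.

```java
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

public class AwaitableScheduler {
    public enum Status { PENDING, RUNNING, COMPLETED, FAILED }

    // Hypothetical stand-in for the article's Task class, kept minimal so the sketch compiles alone
    public static final class Node {
        final String id;
        final Runnable payload;
        final List<Node> parents = new ArrayList<>();
        final List<Node> children = new ArrayList<>();
        volatile Status status = Status.PENDING;

        public Node(String id, Runnable payload) { this.id = id; this.payload = payload; }

        public void dependsOn(Node parent) { parents.add(parent); parent.children.add(this); }
    }

    private final ExecutorService pool;
    private final Map<Node, AtomicInteger> remaining = new ConcurrentHashMap<>();
    private final Set<Node> skipped = ConcurrentHashMap.newKeySet();
    private volatile CountDownLatch done;

    public AwaitableScheduler(int threads) { pool = Executors.newFixedThreadPool(threads); }

    // Runs one already-validated graph (nodes must list every node) and blocks until each node
    // is COMPLETED or FAILED. Returns false on timeout or interruption. Use one instance per run.
    public boolean runAndWait(List<Node> nodes, long timeout, TimeUnit unit) {
        done = new CountDownLatch(nodes.size());
        for (Node n : nodes) remaining.put(n, new AtomicInteger(n.parents.size()));
        for (Node n : nodes) if (n.parents.isEmpty()) submit(n);
        try {
            return done.await(timeout, unit);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }

    private void submit(Node n) {
        n.status = Status.RUNNING;
        pool.submit(() -> {
            try {
                n.payload.run();
            } catch (RuntimeException e) {
                n.status = Status.FAILED;
                done.countDown();
                failDescendants(n);   // its children can never start, so resolve them now
                return;
            }
            n.status = Status.COMPLETED;
            done.countDown();
            for (Node child : n.children) {
                if (remaining.get(child).decrementAndGet() == 0) submit(child);
            }
        });
    }

    // Marks every downstream node FAILED; the set ensures each one is counted down only once,
    // even when it is reachable from several failed parents
    private void failDescendants(Node failed) {
        Deque<Node> stack = new ArrayDeque<>(failed.children);
        while (!stack.isEmpty()) {
            Node d = stack.pop();
            if (skipped.add(d)) {
                d.status = Status.FAILED;
                done.countDown();
                stack.addAll(d.children);
            }
        }
    }

    public void shutdown() { pool.shutdown(); }

    public static void main(String[] args) {
        Node a = new Node("A", () -> {});
        Node b = new Node("B", () -> { throw new IllegalStateException("bad metrics"); });
        Node c = new Node("C", () -> {});
        Node d = new Node("D", () -> {});
        b.dependsOn(a);
        c.dependsOn(a);
        d.dependsOn(b);
        d.dependsOn(c);
        AwaitableScheduler scheduler = new AwaitableScheduler(3);
        boolean finished = scheduler.runAndWait(List.of(a, b, c, d), 5, TimeUnit.SECONDS);
        scheduler.shutdown();
        System.out.println(finished + " " + a.status + " " + b.status + " " + c.status + " " + d.status);
        // prints: true COMPLETED FAILED COMPLETED FAILED
    }
}
```

Marking descendants is safe because a child of a failed task can never reach a zero counter. The concurrent skipped set keeps a task reachable from several failed parents from being counted twice.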
Main Driver Class
This class tests our DAG scheduler by constructing a dependency graph, running executions, and validating cycle-detection logic with an invalid graph.
public class Main {
    public static void main(String[] args) throws InterruptedException {
        DagScheduler scheduler = new DagScheduler(3);

        System.out.println("==========================================");
        System.out.println("Scenario 1: Testing Successful DAG Execution");
        System.out.println("==========================================");
        // Define tasks:
        //      A
        //     / \
        //    B   C
        //     \ /
        //      D
        Task taskA = new Task("TaskA", () -> {
            System.out.println("Executing Task A: Fetching raw telemetry logs.");
            try { Thread.sleep(500); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
        });
        Task taskB = new Task("TaskB", () -> {
            System.out.println("Executing Task B: Extracting metrics from logs.");
            try { Thread.sleep(500); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
        });
        Task taskC = new Task("TaskC", () -> {
            System.out.println("Executing Task C: Indexing log entries.");
            try { Thread.sleep(100); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
        });
        Task taskD = new Task("TaskD", () -> {
            System.out.println("Executing Task D: Consolidating final report.");
        });

        // Set up dependencies
        taskB.addDependency(taskA);
        taskC.addDependency(taskA);
        taskD.addDependency(taskB);
        taskD.addDependency(taskC);

        TaskGraph successGraph = new TaskGraph();
        successGraph.addTask(taskA);
        successGraph.addTask(taskB);
        successGraph.addTask(taskC);
        successGraph.addTask(taskD);
        scheduler.executeGraph(successGraph);

        // executeGraph() returns immediately, so wait long enough for the demo tasks
        // (A: 500 ms, then B and C in parallel: up to 500 ms, then D)
        Thread.sleep(2000);

        System.out.println("\n==========================================");
        System.out.println("Scenario 2: Testing Cycle Detection (Invalid Graph)");
        System.out.println("==========================================");
        Task t1 = new Task("T1", () -> System.out.println("Executing T1"));
        Task t2 = new Task("T2", () -> System.out.println("Executing T2"));
        // Circular edge: T1 depends on T2, T2 depends on T1
        t1.addDependency(t2);
        t2.addDependency(t1);

        TaskGraph invalidGraph = new TaskGraph();
        invalidGraph.addTask(t1);
        invalidGraph.addTask(t2);
        try {
            scheduler.executeGraph(invalidGraph);
        } catch (IllegalArgumentException e) {
            System.out.println("Caught Expected Exception: " + e.getMessage());
        }

        // Stop the worker threads so the JVM can exit
        scheduler.shutdown();
    }
}
The main() driver builds a diamond-shaped graph and shows that B and C run in parallel once A completes, with D starting only after both finish. It then confirms that a cyclic graph is detected and rejected with an IllegalArgumentException before any of its tasks run.