This article describes an advanced data structure called the Segment Tree and provides an implementation in Go (Golang). We will then solve a problem from SPOJ to show how a Segment Tree can be used in competitive programming.

General Idea

A Segment Tree is an advanced tree data structure that answers range queries and applies point updates to an array in O(log n) time each. It is commonly used in competitive programming: after an O(n) build, m queries can be processed in O(n + m log n) total time. Typical queries include the sum of consecutive array elements, the minimum or maximum on a range, and the maximum sum of a contiguous subarray within a range. It is more flexible than a Fenwick Tree and can solve more complex problems, but it is also more challenging to implement.

Segment Tree Structure

Let’s review the simplest problem that can be solved with a Segment Tree. You are given an array arr, and you should handle two types of queries:

  • find the sum of the array elements between the indices l and r;
  • modify an element of the array.

Each query should be processed in O(log n) time. A brute-force solution updates an element in O(1) time but needs O(n) time to compute a sum, so a Segment Tree, which handles both operations in O(log n), is much more efficient when there are many queries.
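For comparison, here is a minimal sketch of that brute-force baseline. The type NaiveArray and its methods are hypothetical names used only for illustration; positions are 1-based to match the rest of the article.

```go
package main

import "fmt"

// NaiveArray is a hypothetical brute-force baseline for the two queries:
// a point update costs O(1), a range sum costs O(r - l + 1), i.e. O(n) worst case.
type NaiveArray struct {
	values []int
}

// Update sets the element at 1-based position i to value in O(1).
func (a *NaiveArray) Update(i, value int) {
	a.values[i-1] = value
}

// Sum returns the sum of positions l..r (1-based, inclusive) by scanning them.
func (a *NaiveArray) Sum(l, r int) int {
	total := 0
	for i := l - 1; i < r; i++ {
		total += a.values[i]
	}
	return total
}

func main() {
	a := &NaiveArray{values: []int{7, 14, -5, 2}}
	fmt.Println(a.Sum(1, 4)) // 18
	a.Update(3, 5)           // array becomes [7, 14, 5, 2]
	fmt.Println(a.Sum(2, 3)) // 19
}
```

A baseline like this is also handy as a reference when testing a Segment Tree on small inputs.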

Let’s take the array [7, 14, -5, 2]. The root node contains the sum of the whole array, 18. Its left child holds the sum of the first two elements, 7 + 14 = 21, and its right child holds -5 + 2 = -3. Each of these two nodes in turn has two leaf children, one per array element. You can see an example Segment Tree for this problem below.

(Figure: Segment Tree for the array [7, 14, -5, 2])

To represent such a tree in a program, we can use a one-dimensional array of length 4n, where n is the length of the input array. Parents and children are related by a simple formula: the root is stored at index 1, and for a node at index i, its left child is at index i*2 and its right child at index i*2+1. This layout needs no pointers, because children are located by arithmetic, and it keeps all nodes in one contiguous block of memory. Not every cell is used, but 4n cells are always enough: the tree has height ⌈log₂ n⌉, so no index exceeds 2^(⌈log₂ n⌉+1) - 1, which is less than 4n. The example is presented below.

(Figure: the same Segment Tree stored in an array)
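To make the layout concrete, here is a minimal sketch of a Segment Tree for the sum problem, assuming the 1-based layout described above (root at index 1, children at i*2 and i*2+1, 4n slots). SumTree, NewSumTree, Sum, and Update are hypothetical names and are not part of the SPOJ solution later in the article. Building it for [7, 14, -5, 2] fills indices 1 to 7 with 18, 21, -3, 7, 14, -5, and 2.

```go
package main

import "fmt"

// SumTree is a hypothetical sum Segment Tree stored in a flat slice:
// the root is at index 1 and the children of node i are at 2*i and 2*i+1.
type SumTree struct {
	n    int
	data []int // 4n slots; index 0 is unused
}

// NewSumTree builds the tree in O(n).
func NewSumTree(arr []int) *SumTree {
	tree := &SumTree{n: len(arr), data: make([]int, 4*len(arr))}
	if tree.n > 0 {
		tree.build(arr, 1, 1, tree.n)
	}
	return tree
}

// build fills node index, which covers the 1-based positions left..right.
func (t *SumTree) build(arr []int, index, left, right int) {
	if left == right {
		t.data[index] = arr[left-1]
		return
	}
	middle := (left + right) / 2
	t.build(arr, 2*index, left, middle)
	t.build(arr, 2*index+1, middle+1, right)
	t.data[index] = t.data[2*index] + t.data[2*index+1]
}

// Sum returns the sum of positions l..r (1-based, inclusive) in O(log n).
func (t *SumTree) Sum(l, r int) int { return t.sum(1, 1, t.n, l, r) }

func (t *SumTree) sum(index, left, right, l, r int) int {
	if r < left || right < l {
		return 0 // node range and query range do not overlap
	}
	if l <= left && right <= r {
		return t.data[index] // node range lies fully inside the query
	}
	middle := (left + right) / 2
	return t.sum(2*index, left, middle, l, r) + t.sum(2*index+1, middle+1, right, l, r)
}

// Update sets position pos (1..n) to value and recomputes its ancestors in O(log n).
func (t *SumTree) Update(pos, value int) { t.update(1, 1, t.n, pos, value) }

func (t *SumTree) update(index, left, right, pos, value int) {
	if left == right {
		t.data[index] = value
		return
	}
	middle := (left + right) / 2
	if pos <= middle {
		t.update(2*index, left, middle, pos, value)
	} else {
		t.update(2*index+1, middle+1, right, pos, value)
	}
	t.data[index] = t.data[2*index] + t.data[2*index+1]
}

func main() {
	tree := NewSumTree([]int{7, 14, -5, 2})
	fmt.Println(tree.data[1:8]) // [18 21 -3 7 14 -5 2]
	fmt.Println(tree.Sum(2, 4)) // 11
	tree.Update(3, 5)           // array becomes [7, 14, 5, 2]
	fmt.Println(tree.Sum(1, 4)) // 28
}
```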

Golang implementation

Problem description

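In this section, we will solve a more complex problem from the online judge SPOJ in Go. You can find the problem here (link). The input is an array of n integers followed by m queries of the form t x y. If t is 0, the x-th element is set to y; otherwise, we must output the maximum sum of a non-empty contiguous subarray among the elements with indices x to y. Indices are 1-based.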

Data structure definition

To find the maximum range sum, each Segment Tree node must store more than a single number. For the segment it covers, a node keeps four values: the best (maximum) prefix sum, the best suffix sum, the total sum, and the best sum of any contiguous subarray:

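type SegmentTreeData struct {
	prefix int // maximum sum of a non-empty prefix of the segment
	suffix int // maximum sum of a non-empty suffix of the segment
	best   int // maximum sum of a non-empty contiguous subarray
	total  int // sum of all elements in the segment
}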

Let’s start with a node that covers a single element. All four fields equal that element’s value, because a segment of length 1 has exactly one non-empty prefix, one suffix, and one subarray, and each of them is the whole segment.

func createData(value int) SegmentTreeData {
	return SegmentTreeData{
		prefix: value,
		suffix: value,
		best:   value,
		total:  value,
	}
}

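The algorithm becomes more interesting for longer segments. Two adjacent nodes, left and right, are merged as follows (max and min are small variadic helpers defined in the final code at the end of the article):

  • total = left.total + right.total;
  • prefix = max(left.prefix, left.total + right.prefix), since the best prefix either stays inside the left part or covers all of it and continues into the right part;
  • suffix = max(right.suffix, right.total + left.suffix), symmetrically;
  • best = max(left.best, right.best, prefix, suffix, left.suffix + right.prefix), since the best subarray lies entirely in the left part, entirely in the right part, or crosses the middle, in which case it is a suffix of the left part followed by a prefix of the right part. The prefix and suffix terms are already covered by the other three, but including them does no harm.
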
func merge(left, right SegmentTreeData) SegmentTreeData {
	total := left.total + right.total
	prefix := max(left.prefix, left.total+right.prefix)
	suffix := max(right.suffix, right.total+left.suffix)
	best := max(
		left.best,
		right.best,
		prefix,
		suffix,
		left.suffix+right.prefix,
	)

	return SegmentTreeData{
		prefix,
		suffix,
		best,
		total,
	}
}
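As a quick sanity check of these rules, the sketch below merges the array [-1, 3, 4, -2] from its halves [-1, 3] and [4, -2]. It restates the article’s SegmentTreeData, createData, merge, and variadic max so that it compiles on its own. The best sum, 7, comes from the crossing term left.suffix + right.prefix = 3 + 4, which neither half can find alone (their best sums are 3 and 4).

```go
package main

import "fmt"

// SegmentTreeData, createData, merge and max are restated from the article.
type SegmentTreeData struct {
	prefix int
	suffix int
	best   int
	total  int
}

func createData(value int) SegmentTreeData {
	return SegmentTreeData{prefix: value, suffix: value, best: value, total: value}
}

func merge(left, right SegmentTreeData) SegmentTreeData {
	total := left.total + right.total
	prefix := max(left.prefix, left.total+right.prefix)
	suffix := max(right.suffix, right.total+left.suffix)
	best := max(left.best, right.best, prefix, suffix, left.suffix+right.prefix)
	return SegmentTreeData{prefix, suffix, best, total}
}

// max returns the largest argument (this definition shadows Go 1.21's built-in max).
func max(x int, rest ...int) int {
	for _, value := range rest {
		if value > x {
			x = value
		}
	}
	return x
}

func main() {
	left := merge(createData(-1), createData(3))  // segment [-1, 3]
	right := merge(createData(4), createData(-2)) // segment [4, -2]
	whole := merge(left, right)                   // segment [-1, 3, 4, -2]

	fmt.Printf("%+v\n", left)  // {prefix:2 suffix:3 best:3 total:2}
	fmt.Printf("%+v\n", right) // {prefix:4 suffix:2 best:4 total:2}
	fmt.Printf("%+v\n", whole) // {prefix:6 suffix:5 best:7 total:4}
}
```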

The Segment Tree itself stores the input length n and a slice of n * 4 nodes, indexed from 1 as described earlier.

type SegmentTree struct {
	n    int
	data []SegmentTreeData
}

Build initial data structure

To build the initial tree, we recurse from the root down to the leaves. A leaf node is constructed with the createData function defined earlier. For an inner node, we first build its left and right children and then combine their data with merge. Every node is visited exactly once, and a tree over n elements has exactly 2n - 1 nodes (all of which fit into the 4n slots), so the build runs in O(n) time.

func Build(arr []int) *SegmentTree {
	n := len(arr)
	length := n * 4
	data := make([]SegmentTreeData, length)
	tree := &SegmentTree{
		n,
		data,
	}
	tree.build(arr, 1, 1, n)

	return tree
}

func (tree *SegmentTree) build(arr []int, index int, left int, right int) {
	if left > right {
		return
	} else if left == right {
		tree.data[index] = createData(arr[left-1])
	} else {
		middle := (left + right) / 2

		tree.build(arr, index*2, left, middle)
		tree.build(arr, index*2+1, middle+1, right)
		tree.data[index] = merge(tree.data[index*2], tree.data[index*2+1])
	}
}
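The build recursion creates exactly 2n - 1 nodes, and the largest index it touches is always below 4n. The sketch below checks both facts with a hypothetical helper, countNodes, that follows the same recursion and index arithmetic as build without storing any data. For n = 6 it reports 11 nodes and a largest index of 13, which already exceeds 2n = 12; that is why 2n slots are not enough in general and 4n is the usual safe choice.

```go
package main

import "fmt"

// countNodes is a hypothetical helper that repeats build's recursion (same
// middle split, children at 2*index and 2*index+1) without storing any data.
// For left <= right it returns how many nodes cover left..right and the
// largest array index used.
func countNodes(index, left, right int) (count, maxIndex int) {
	if left == right {
		return 1, index // a leaf
	}
	middle := (left + right) / 2
	leftCount, leftMax := countNodes(2*index, left, middle)
	rightCount, rightMax := countNodes(2*index+1, middle+1, right)
	maxIndex = leftMax
	if rightMax > maxIndex {
		maxIndex = rightMax
	}
	return 1 + leftCount + rightCount, maxIndex
}

func main() {
	for _, n := range []int{1, 4, 6, 1000} {
		count, maxIndex := countNodes(1, 1, n)
		fmt.Printf("n=%d: %d nodes (2n-1=%d), largest index %d (4n=%d)\n",
			n, count, 2*n-1, maxIndex, 4*n)
	}
}
```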

Update single element

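The update operation is very similar to the build and runs in O(log n) time. First, we descend to the leaf that holds the element and replace its data with createData of the new value. Then, as the recursion unwinds, we recompute every ancestor of that leaf by merging its two children, exactly as in the build. The recursive call into a child that does not contain the updated position returns immediately, so only one root-to-leaf path, of length O(log n), is processed.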

func (tree *SegmentTree) Update(x, y int) {
	tree.update(1, 1, tree.n, x, y)
}

func (tree *SegmentTree) update(index int, left int, right int, updateIndex int, updateValue int) {
	if left > right || left > updateIndex || right < updateIndex {
		return
	} else if left == right {
		tree.data[index] = createData(updateValue)
	} else {
		middle := (left + right) / 2

		tree.update(index*2, left, middle, updateIndex, updateValue)
		tree.update(index*2+1, middle+1, right, updateIndex, updateValue)
		tree.data[index] = merge(tree.data[index*2], tree.data[index*2+1])
	}
}
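To see why an update costs O(log n), it helps to list the nodes that actually change. The hypothetical helper updatePath below walks the same ranges as update but follows only the child that contains the position, recording each node index from the root to the leaf. For the four-element example [7, 14, -5, 2], updating position 3 rewrites leaf 6 and then re-merges nodes 3 and 1; for n = 2^20 the path has just 21 nodes.

```go
package main

import "fmt"

// updatePath is a hypothetical helper: it walks the same ranges as update,
// but only into the child containing pos (1-based, 1..n), and records the
// index of every node on the way from the root to the leaf.
func updatePath(n, pos int) []int {
	var path []int
	index, left, right := 1, 1, n
	for {
		path = append(path, index)
		if left == right {
			return path // reached the leaf for pos
		}
		middle := (left + right) / 2
		if pos <= middle {
			index, right = 2*index, middle
		} else {
			index, left = 2*index+1, middle+1
		}
	}
}

func main() {
	fmt.Println(updatePath(4, 3))              // [1 3 6]
	fmt.Println(len(updatePath(1<<20, 12345))) // 21
}
```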

Find max range sum

Finally, we answer queries in O(log n) time using the tree. For a node covering [left, right] and a requested range [findLeft, findRight] inside it, there are three cases:

  • The requested range matches the node’s range exactly. We return the node’s stored data.
  • The requested range lies entirely within one of the node’s children. We recurse into that child.
  • The requested range spans both children. We split it at the middle, query each half recursively, and combine the two results with merge, left part first, because merge is not symmetric.

func (tree *SegmentTree) Find(x, y int) int {
	return tree.find(1, 1, tree.n, x, y).best
}

func (tree *SegmentTree) find(index int, left int, right int, findLeft int, findRight int) SegmentTreeData {
	if left == findLeft && right == findRight {
		return tree.data[index]
	} else {
		middle := (left + right) / 2

		if findRight <= middle {
			return tree.find(index*2, left, middle, findLeft, findRight)
		} else if findLeft > middle {
			return tree.find(index*2+1, middle+1, right, findLeft, findRight)
		} else {
			leftResult := tree.find(index*2, left, middle, findLeft, min(middle, findRight))
			rightResult := tree.find(index*2+1, middle+1, right, max(findLeft, middle+1), findRight)
			mergedResult := merge(leftResult, rightResult)
			return mergedResult
		}
	}
}
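Because merge and find are easy to get subtly wrong, a randomized comparison against a brute-force answer is a cheap safety net. The sketch below restates the article’s SegmentTreeData, merge, Build, and find compactly, builds trees over random arrays, and checks Find on every range [x, y] against Kadane’s algorithm. The helper bruteForce and the test parameters (200 arrays of length 1 to 20 with values in [-10, 10], fixed seed 1) are arbitrary choices for illustration. In the split case it passes middle and middle+1 directly, which is what the original min and max calls evaluate to in that branch; Update is left out to keep the sketch short, but it can be tested the same way.

```go
package main

import (
	"fmt"
	"math/rand"
)

type SegmentTreeData struct{ prefix, suffix, best, total int }

func createData(v int) SegmentTreeData { return SegmentTreeData{v, v, v, v} }

// max is the article's variadic helper; it shadows Go 1.21's built-in max.
func max(x int, rest ...int) int {
	for _, v := range rest {
		if v > x {
			x = v
		}
	}
	return x
}

func merge(l, r SegmentTreeData) SegmentTreeData {
	prefix := max(l.prefix, l.total+r.prefix)
	suffix := max(r.suffix, r.total+l.suffix)
	best := max(l.best, r.best, prefix, suffix, l.suffix+r.prefix)
	return SegmentTreeData{prefix, suffix, best, l.total + r.total}
}

type SegmentTree struct {
	n    int
	data []SegmentTreeData
}

func Build(arr []int) *SegmentTree {
	tree := &SegmentTree{len(arr), make([]SegmentTreeData, 4*len(arr))}
	tree.build(arr, 1, 1, tree.n)
	return tree
}

func (t *SegmentTree) build(arr []int, index, left, right int) {
	if left > right {
		return // only happens for an empty array
	}
	if left == right {
		t.data[index] = createData(arr[left-1])
		return
	}
	middle := (left + right) / 2
	t.build(arr, 2*index, left, middle)
	t.build(arr, 2*index+1, middle+1, right)
	t.data[index] = merge(t.data[2*index], t.data[2*index+1])
}

func (t *SegmentTree) Find(x, y int) int { return t.find(1, 1, t.n, x, y).best }

func (t *SegmentTree) find(index, left, right, x, y int) SegmentTreeData {
	if left == x && right == y {
		return t.data[index]
	}
	middle := (left + right) / 2
	if y <= middle {
		return t.find(2*index, left, middle, x, y)
	}
	if x > middle {
		return t.find(2*index+1, middle+1, right, x, y)
	}
	// Here x <= middle < y: query both halves and merge them left to right.
	return merge(t.find(2*index, left, middle, x, middle), t.find(2*index+1, middle+1, right, middle+1, y))
}

// bruteForce is Kadane's algorithm: the best non-empty subarray sum of arr.
func bruteForce(arr []int) int {
	best, current := arr[0], arr[0]
	for _, v := range arr[1:] {
		current = max(v, current+v)
		best = max(best, current)
	}
	return best
}

func main() {
	rng := rand.New(rand.NewSource(1)) // fixed seed so any failure is reproducible
	for trial := 0; trial < 200; trial++ {
		arr := make([]int, 1+rng.Intn(20)) // length 1..20
		for i := range arr {
			arr[i] = rng.Intn(21) - 10 // values in [-10, 10]
		}
		tree := Build(arr)
		for x := 1; x <= len(arr); x++ {
			for y := x; y <= len(arr); y++ {
				if got, want := tree.Find(x, y), bruteForce(arr[x-1:y]); got != want {
					panic(fmt.Sprintf("Find(%d, %d) on %v = %d, want %d", x, y, arr, got, want))
				}
			}
		}
	}
	fmt.Println("all random checks passed")
}
```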

IO processing

To solve the problem on the judge, we read the input from stdin and write the result to stdout. Calling fmt.Scan directly on os.Stdin is slow for large inputs because os.Stdin is not buffered, so we write a little boilerplate ourselves. First, we wrap stdin and stdout in buffered readers and writers. Second, to read a whole array, we read the entire line as a string, split it into whitespace-separated fields, and convert each field to an int. Single numbers are read and written with the fmt package.

var reader *bufio.Reader = bufio.NewReader(os.Stdin)
var writer *bufio.Writer = bufio.NewWriter(os.Stdout)

func readInt() int {
	var value int
	fmt.Fscanf(reader, "%d\n", &value)

	return value
}

func writeInt(value int) {
	fmt.Fprintln(writer, value)
}

func readArray(n int) []int {
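	line, err := reader.ReadString('\n')
	// ReadString returns io.EOF together with the data when the last line has no
	// trailing newline, so only treat the error as fatal if nothing was read.
	if err != nil && len(line) == 0 {
		panic(err)
	}

	// strings.Fields splits on any run of whitespace, so repeated spaces or tabs
	// between numbers do not produce empty fields.
	stringArray := strings.Fields(line)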
	if len(stringArray) != n {
		panic(fmt.Errorf("Expected input array to be of size %d, but was %d", n, len(stringArray)))
	}

	arr := make([]int, n)
	for i := 0; i < n; i++ {
		value, err := strconv.Atoi(stringArray[i])
		if err != nil {
			panic(err)
		}

		arr[i] = value
	}

	return arr
}

Main function

The last part is the main function. First, we defer writer.Flush() so that everything buffered in writer is written to stdout when main returns. Then we read the input array and build the Segment Tree with the Build function defined earlier.

Each query is read as three numbers t, x, and y. If t is 0, we call tree.Update(x, y) to set the x-th element to y; otherwise, we call tree.Find(x, y) and write the result to stdout.

func main() {
	defer writer.Flush()

	n := readInt()
	arr := readArray(n)
	tree := Build(arr)

	m := readInt()
	for i := 0; i < m; i++ {
		query := readArray(3)
		t := query[0]
		x := query[1]
		y := query[2]

		if t == 0 {
			tree.Update(x, y)
		} else {
			value := tree.Find(x, y)
			writeInt(value)
		}
	}
}

Final code

You can now submit your solution to SPOJ and check the verdict. The judge should accept it if your implementation matches the code below. My solution ran in 0.09 seconds using 14 MB of memory, well within the problem’s limits of 1 second and 1536 MB.

package main

import (
	"bufio"
	"fmt"
	"os"
	"strconv"
	"strings"
)

type SegmentTreeData struct {
	prefix int
	suffix int
	best   int
	total  int
}

func createData(value int) SegmentTreeData {
	return SegmentTreeData{
		prefix: value,
		suffix: value,
		best:   value,
		total:  value,
	}
}

func merge(left, right SegmentTreeData) SegmentTreeData {
	total := left.total + right.total
	prefix := max(left.prefix, left.total+right.prefix)
	suffix := max(right.suffix, right.total+left.suffix)
	best := max(
		left.best,
		right.best,
		prefix,
		suffix,
		left.suffix+right.prefix,
	)

	return SegmentTreeData{
		prefix,
		suffix,
		best,
		total,
	}
}

type SegmentTree struct {
	n    int
	data []SegmentTreeData
}

func Build(arr []int) *SegmentTree {
	n := len(arr)
	length := n * 4
	data := make([]SegmentTreeData, length)
	tree := &SegmentTree{
		n,
		data,
	}
	tree.build(arr, 1, 1, n)

	return tree
}

func (tree *SegmentTree) build(arr []int, index int, left int, right int) {
	if left > right {
		return
	} else if left == right {
		tree.data[index] = createData(arr[left-1])
	} else {
		middle := (left + right) / 2

		tree.build(arr, index*2, left, middle)
		tree.build(arr, index*2+1, middle+1, right)
		tree.data[index] = merge(tree.data[index*2], tree.data[index*2+1])
	}
}

func (tree *SegmentTree) Update(x, y int) {
	tree.update(1, 1, tree.n, x, y)
}

func (tree *SegmentTree) update(index int, left int, right int, updateIndex int, updateValue int) {
	if left > right || left > updateIndex || right < updateIndex {
		return
	} else if left == right {
		tree.data[index] = createData(updateValue)
	} else {
		middle := (left + right) / 2

		tree.update(index*2, left, middle, updateIndex, updateValue)
		tree.update(index*2+1, middle+1, right, updateIndex, updateValue)
		tree.data[index] = merge(tree.data[index*2], tree.data[index*2+1])
	}
}

func (tree *SegmentTree) Find(x, y int) int {
	return tree.find(1, 1, tree.n, x, y).best
}

func (tree *SegmentTree) find(index int, left int, right int, findLeft int, findRight int) SegmentTreeData {
	if left == findLeft && right == findRight {
		return tree.data[index]
	} else {
		middle := (left + right) / 2

		if findRight <= middle {
			return tree.find(index*2, left, middle, findLeft, findRight)
		} else if findLeft > middle {
			return tree.find(index*2+1, middle+1, right, findLeft, findRight)
		} else {
			leftResult := tree.find(index*2, left, middle, findLeft, min(middle, findRight))
			rightResult := tree.find(index*2+1, middle+1, right, max(findLeft, middle+1), findRight)
			mergedResult := merge(leftResult, rightResult)
			return mergedResult
		}
	}
}

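// max and min return the largest and smallest of their arguments. Go 1.21 added
// built-in functions with the same names; these package-level definitions shadow
// them, which is legal, and also compile with older Go versions.
func max(x int, rest ...int) int {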
	mx := x

	for _, value := range rest {
		if mx < value {
			mx = value
		}
	}

	return mx
}

func min(x int, rest ...int) int {
	mn := x

	for _, value := range rest {
		if mn > value {
			mn = value
		}
	}

	return mn
}

var reader *bufio.Reader = bufio.NewReader(os.Stdin)
var writer *bufio.Writer = bufio.NewWriter(os.Stdout)

func readInt() int {
	var value int
	fmt.Fscanf(reader, "%d\n", &value)

	return value
}

func writeInt(value int) {
	fmt.Fprintln(writer, value)
}

func readArray(n int) []int {
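	line, err := reader.ReadString('\n')
	// A last line without a trailing newline comes back with io.EOF; only fail if nothing was read.
	if err != nil && len(line) == 0 {
		panic(err)
	}

	stringArray := strings.Fields(line) // splits on any run of whitespace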
	if len(stringArray) != n {
		panic(fmt.Errorf("Expected input array to be of size %d, but was %d", n, len(stringArray)))
	}

	arr := make([]int, n)
	for i := 0; i < n; i++ {
		value, err := strconv.Atoi(stringArray[i])
		if err != nil {
			panic(err)
		}

		arr[i] = value
	}

	return arr
}

func main() {
	defer writer.Flush()

	n := readInt()
	arr := readArray(n)
	tree := Build(arr)

	m := readInt()
	for i := 0; i < m; i++ {
		query := readArray(3)
		t := query[0]
		x := query[1]
		y := query[2]

		if t == 0 {
			tree.Update(x, y)
		} else {
			value := tree.Find(x, y)
			writeInt(value)
		}
	}
}