본문 바로가기
-----ETC-----/C++ 성능 최적화 및 고급 테크닉 시리즈

[C++ 성능 최적화 및 고급 테크닉] Day 11: 표현식 템플릿 (Expression Templates)

by cogito21_cpp 2024. 8. 1.
반응형

표현식 템플릿이란?

표현식 템플릿(Expression Templates)은 C++ 템플릿 메타프로그래밍 기법으로, 수식의 계산을 최적화하여 성능을 향상시키는 방법입니다. 주로 수치 연산 라이브러리에서 사용되며, 연산 중간 결과를 저장하지 않고 최적화된 코드로 변환합니다.

 

표현식 템플릿의 작동 방식

표현식 템플릿은 연산자 오버로딩과 템플릿을 사용하여 연산의 중간 결과를 표현식 트리로 변환합니다. 그런 다음 이 트리를 평가하여 최적화된 코드를 생성합니다.

 

기본 예제

간단한 벡터 클래스를 사용하여 표현식 템플릿을 구현해보겠습니다.

 

1. 벡터 클래스 정의

먼저 기본 벡터 클래스를 정의합니다.

#include <iostream>
#include <vector>

class Vector {
public:
    std::vector<double> data;

    Vector(size_t size) : data(size) {}

    size_t size() const { return data.size(); }

    double& operator[](size_t index) { return data[index]; }
    const double& operator[](size_t index) const { return data[index]; }

    Vector& operator=(const Vector& other) {
        if (this != &other) {
            data = other.data;
        }
        return *this;
    }
};

 

2. 연산자 오버로딩

벡터 연산을 위한 연산자를 오버로딩합니다.

Vector operator+(const Vector& lhs, const Vector& rhs) {
    Vector result(lhs.size());
    for (size_t i = 0; i < lhs.size(); ++i) {
        result[i] = lhs[i] + rhs[i];
    }
    return result;
}

 

이 방식은 중간 결과를 저장하기 때문에 비효율적입니다. 이를 표현식 템플릿으로 개선해보겠습니다.

 

3. 표현식 템플릿 구현

표현식 템플릿을 사용하여 중간 결과를 저장하지 않고 연산을 최적화합니다.

 

3.1. 표현식 템플릿 클래스

template <typename L, typename R>
class VectorAdd {
public:
    const L& lhs;
    const R& rhs;

    VectorAdd(const L& lhs, const R& rhs) : lhs(lhs), rhs(rhs) {}

    double operator[](size_t index) const {
        return lhs[index] + rhs[index];
    }

    size_t size() const { return lhs.size(); }
};

 

3.2. 연산자 오버로딩

벡터 연산을 위한 연산자를 오버로딩합니다.

template <typename L, typename R>
VectorAdd<L, R> operator+(const L& lhs, const R& rhs) {
    return VectorAdd<L, R>(lhs, rhs);
}

 

3.3. 벡터 클래스에서 표현식 지원

벡터 클래스에서 표현식을 처리하도록 연산자를 오버로딩합니다.

class Vector {
public:
    std::vector<double> data;

    Vector(size_t size) : data(size) {}

    size_t size() const { return data.size(); }

    double& operator[](size_t index) { return data[index]; }
    const double& operator[](size_t index) const { return data[index]; }

    Vector& operator=(const Vector& other) {
        if (this != &other) {
            data = other.data;
        }
        return *this;
    }

    template <typename Expr>
    Vector& operator=(const Expr& expr) {
        for (size_t i = 0; i < size(); ++i) {
            data[i] = expr[i];
        }
        return *this;
    }
};

 

전체 코드

전체 코드는 다음과 같습니다.

#include <iostream>
#include <vector>

class Vector {
public:
    std::vector<double> data;

    Vector(size_t size) : data(size) {}

    size_t size() const { return data.size(); }

    double& operator[](size_t index) { return data[index]; }
    const double& operator[](size_t index) const { return data[index]; }

    Vector& operator=(const Vector& other) {
        if (this != &other) {
            data = other.data;
        }
        return *this;
    }

    template <typename Expr>
    Vector& operator=(const Expr& expr) {
        for (size_t i = 0; i < size(); ++i) {
            data[i] = expr[i];
        }
        return *this;
    }
};

template <typename L, typename R>
class VectorAdd {
public:
    const L& lhs;
    const R& rhs;

    VectorAdd(const L& lhs, const R& rhs) : lhs(lhs), rhs(rhs) {}

    double operator[](size_t index) const {
        return lhs[index] + rhs[index];
    }

    size_t size() const { return lhs.size(); }
};

template <typename L, typename R>
VectorAdd<L, R> operator+(const L& lhs, const R& rhs) {
    return VectorAdd<L, R>(lhs, rhs);
}

int main() {
    Vector v1(3), v2(3), v3(3);
    v1[0] = 1.0; v1[1] = 2.0; v1[2] = 3.0;
    v2[0] = 4.0; v2[1] = 5.0; v2[2] = 6.0;

    v3 = v1 + v2;

    std::cout << "v3: ";
    for (size_t i = 0; i < v3.size(); ++i) {
        std::cout << v3[i] << " ";
    }
    std::cout << std::endl;

    return 0;
}

 

이 코드는 표현식 템플릿을 사용하여 벡터 덧셈을 최적화합니다.

 

실습 문제

문제 1: 표현식 템플릿 확장

다음 코드에서 표현식 템플릿을 확장하여 벡터의 덧셈과 뺄셈을 지원하도록 수정하세요.

 

main.cpp

#include <iostream>
#include <vector>

class Vector {
public:
    std::vector<double> data;

    Vector(size_t size) : data(size) {}

    size_t size() const { return data.size(); }

    double& operator[](size_t index) { return data[index]; }
    const double& operator[](size_t index) const { return data[index]; }

    Vector& operator=(const Vector& other) {
        if (this != &other) {
            data = other.data;
        }
        return *this;
    }

    template <typename Expr>
    Vector& operator=(const Expr& expr) {
        for (size_t i = 0; i < size(); ++i) {
            data[i] = expr[i];
        }
        return *this;
    }
};

template <typename L, typename R>
class VectorAdd {
public:
    const L& lhs;
    const R& rhs;

    VectorAdd(const L& lhs, const R& rhs) : lhs(lhs), rhs(rhs) {}

    double operator[](size_t index) const {
        return lhs[index] + rhs[index];
    }

    size_t size() const { return lhs.size(); }
};

template <typename L, typename R>
VectorAdd<L, R> operator+(const L& lhs, const R& rhs) {
    return VectorAdd<L, R>(lhs, rhs);
}

int main() {
    Vector v1(3), v2(3), v3(3);
    v1[0] = 1.0; v1[1] = 2.0; v1[2] = 3.0;
    v2[0] = 4.0; v2[1] = 5.0; v2[2] = 6.0;

    v3 = v1 + v2;

    std::cout << "v3: ";
    for (size_t i = 0; i < v3.size(); ++i) {
        std::cout << v3[i] << " ";
    }
    std::cout << std::endl;

    return 0;
}

 

해답:

main.cpp (벡터 뺄셈 지원 추가)

#include <iostream>
#include <vector>

class Vector {
public:
    std::vector<double> data;

    Vector(size_t size) : data(size) {}

    size_t size() const { return data.size(); }

    double& operator[](size_t index) { return data[index]; }
    const double& operator[](size_t index) const { return data[index]; }

    Vector& operator=(const Vector& other) {
        if (this != &other) {
            data = other.data;
        }
        return *this;
    }

    template <typename Expr>
    Vector& operator=(const Expr& expr

) {
        for (size_t i = 0; i < size(); ++i) {
            data[i] = expr[i];
        }
        return *this;
    }
};

template <typename L, typename R>
class VectorAdd {
public:
    const L& lhs;
    const R& rhs;

    VectorAdd(const L& lhs, const R& rhs) : lhs(lhs), rhs(rhs) {}

    double operator[](size_t index) const {
        return lhs[index] + rhs[index];
    }

    size_t size() const { return lhs.size(); }
};

template <typename L, typename R>
class VectorSub {
public:
    const L& lhs;
    const R& rhs;

    VectorSub(const L& lhs, const R& rhs) : lhs(lhs), rhs(rhs) {}

    double operator[](size_t index) const {
        return lhs[index] - rhs[index];
    }

    size_t size() const { return lhs.size(); }
};

template <typename L, typename R>
VectorAdd<L, R> operator+(const L& lhs, const R& rhs) {
    return VectorAdd<L, R>(lhs, rhs);
}

template <typename L, typename R>
VectorSub<L, R> operator-(const L& lhs, const R& rhs) {
    return VectorSub<L, R>(lhs, rhs);
}

int main() {
    Vector v1(3), v2(3), v3(3);
    v1[0] = 1.0; v1[1] = 2.0; v1[2] = 3.0;
    v2[0] = 4.0; v2[1] = 5.0; v2[2] = 6.0;

    v3 = v1 + v2 - v1;

    std::cout << "v3: ";
    for (size_t i = 0; i < v3.size(); ++i) {
        std::cout << v3[i] << " ";
    }
    std::cout << std::endl;

    return 0;
}

 

이제 열한 번째 날의 학습을 마쳤습니다. 표현식 템플릿의 중요성과 이를 구현하는 방법을 이해하고, 실습 문제를 통해 이를 적용하는 방법을 학습했습니다.

질문이나 피드백이 있으면 언제든지 댓글로 남겨 주세요. 내일은 "C++11/14/17/20의 새로운 기능 활용"에 대해 학습하겠습니다.

반응형