diff pytouhou/ui/shader.pyx @ 424:f4d76d3d6f2a

Make the Shader class use cython too.
author Emmanuel Gil Peyrot <linkmauve@linkmauve.fr>
date Tue, 16 Jul 2013 21:07:15 +0200
parents pytouhou/ui/shader.py@346614f788f1
children 878273a984c4
line wrap: on
line diff
copy from pytouhou/ui/shader.py
copy to pytouhou/ui/shader.pyx
--- a/pytouhou/ui/shader.py
+++ b/pytouhou/ui/shader.pyx
@@ -1,3 +1,4 @@
+# -*- encoding: utf-8 -*-
 #
 # Copyright Tristam Macdonald 2008.
 # Copyright Emmanuel Gil Peyrot 2012.
@@ -8,25 +9,28 @@
 # Source: https://swiftcoder.wordpress.com/2008/12/19/simple-glsl-wrapper-for-pyglet/
 #
 
-from pyglet.gl import (glCreateProgram, glCreateShader, GL_VERTEX_SHADER,
-                       GL_FRAGMENT_SHADER, glShaderSource, glCompileShader,
-                       glGetShaderiv, GL_COMPILE_STATUS, GL_INFO_LOG_LENGTH,
-                       glGetShaderInfoLog, glAttachShader, glLinkProgram,
-                       glGetProgramiv, glGetProgramInfoLog, GL_LINK_STATUS,
-                       glUseProgram, glGetUniformLocation, glUniform1f,
-                       glUniform2f, glUniform3f, glUniform4f, glUniform1i,
-                       glUniform2i, glUniform3i, glUniform4i,
-                       glUniformMatrix4fv, glBindAttribLocation)
+from pytouhou.lib.opengl cimport \
+         (glCreateProgram, glCreateShader, GL_VERTEX_SHADER,
+          GL_FRAGMENT_SHADER, glShaderSource, glCompileShader, glGetShaderiv,
+          GL_COMPILE_STATUS, GL_INFO_LOG_LENGTH, glGetShaderInfoLog,
+          glAttachShader, glLinkProgram, glGetProgramiv, glGetProgramInfoLog,
+          GL_LINK_STATUS, glUseProgram, glGetUniformLocation, glUniform1fv,
+          glUniform4fv, glUniformMatrix4fv, glBindAttribLocation, GLint,
+          GLuint, GLchar, GLfloat, GLenum)
 
-from ctypes import (c_char, c_char_p, c_int, POINTER, byref, cast,
-                    create_string_buffer)
+from libc.stdlib cimport malloc, free
+from pytouhou.utils.matrix cimport Matrix, matrix_to_floats
 
 
 class GLSLException(Exception):
     pass
 
 
-class Shader(object):
+cdef class Shader:
+    cdef GLuint handle
+    cdef bint linked
+    cdef dict location_cache
+
     # vert and frag take arrays of source strings the arrays will be
     # concattenated into one string by OpenGL
     def __init__(self, vert=None, frag=None):
@@ -39,9 +43,9 @@ class Shader(object):
         self.location_cache = {}
 
         # create the vertex shader
-        self.createShader(vert, GL_VERTEX_SHADER)
+        self.create_shader(vert[0], GL_VERTEX_SHADER)
         # create the fragment shader
-        self.createShader(frag, GL_FRAGMENT_SHADER)
+        self.create_shader(frag[0], GL_FRAGMENT_SHADER)
 
         #TODO: put those elsewhere.
         glBindAttribLocation(self.handle, 0, 'in_position')
@@ -51,119 +55,94 @@ class Shader(object):
         # attempt to link the program
         self.link()
 
-    def load_source(self, path):
-        with open(path, 'rb') as file:
-            source = file.read()
-        return source
-
-    def createShader(self, strings, type):
-        count = len(strings)
-        # if we have no source code, ignore this shader
-        if count < 1:
-            return
+    cdef void create_shader(self, const GLchar *string, GLenum shader_type):
+        cdef GLint temp
+        cdef const GLchar **strings = &string
 
         # create the shader handle
-        shader = glCreateShader(type)
+        shader = glCreateShader(shader_type)
 
-        # convert the source strings into a ctypes pointer-to-char array, and upload them
-        # this is deep, dark, dangerous black magick - don't try stuff like this at home!
-        src = (c_char_p * count)(*strings)
-        glShaderSource(shader, count, cast(byref(src), POINTER(POINTER(c_char))), None)
+        # upload the source strings
+        glShaderSource(shader, 1, strings, NULL)
 
         # compile the shader
         glCompileShader(shader)
 
-        temp = c_int(0)
         # retrieve the compile status
-        glGetShaderiv(shader, GL_COMPILE_STATUS, byref(temp))
+        glGetShaderiv(shader, GL_COMPILE_STATUS, &temp)
 
         # if compilation failed, print the log
         if not temp:
             # retrieve the log length
-            glGetShaderiv(shader, GL_INFO_LOG_LENGTH, byref(temp))
+            glGetShaderiv(shader, GL_INFO_LOG_LENGTH, &temp)
             # create a buffer for the log
-            buffer = create_string_buffer(temp.value)
+            temp_buf = <GLchar*>malloc(temp * sizeof(GLchar))
             # retrieve the log text
-            glGetShaderInfoLog(shader, temp, None, buffer)
+            glGetShaderInfoLog(shader, temp, NULL, temp_buf)
+            buf = temp_buf[:temp]
+            free(temp_buf)
             # print the log to the console
-            raise GLSLException(buffer.value)
+            raise GLSLException(buf)
         else:
             # all is well, so attach the shader to the program
-            glAttachShader(self.handle, shader);
+            glAttachShader(self.handle, shader)
 
-    def link(self):
+    cdef void link(self):
+        cdef GLint temp
+
         # link the program
         glLinkProgram(self.handle)
 
-        temp = c_int(0)
         # retrieve the link status
-        glGetProgramiv(self.handle, GL_LINK_STATUS, byref(temp))
+        glGetProgramiv(self.handle, GL_LINK_STATUS, &temp)
 
         # if linking failed, print the log
         if not temp:
             #   retrieve the log length
-            glGetProgramiv(self.handle, GL_INFO_LOG_LENGTH, byref(temp))
+            glGetProgramiv(self.handle, GL_INFO_LOG_LENGTH, &temp)
             # create a buffer for the log
-            buffer = create_string_buffer(temp.value)
+            temp_buf = <GLchar*>malloc(temp * sizeof(GLchar))
             # retrieve the log text
-            glGetProgramInfoLog(self.handle, temp, None, buffer)
+            glGetProgramInfoLog(self.handle, temp, NULL, temp_buf)
+            buf = temp_buf[:temp]
+            free(temp_buf)
             # print the log to the console
-            raise GLSLException(buffer.value)
+            raise GLSLException(buf)
         else:
             # all is well, so we are linked
             self.linked = True
 
+    cdef GLint get_uniform_location(self, name):
+        if name not in self.location_cache:
+            loc = glGetUniformLocation(self.handle, name)
+            if loc == -1:
+                raise GLSLException('Undefined {} uniform.'.format(name))
+            self.location_cache[name] = loc
+        return self.location_cache[name]
+
     def bind(self):
         # bind the program
         glUseProgram(self.handle)
 
-    @classmethod
-    def unbind(self):
-        # unbind whatever program is currently bound
-        glUseProgram(0)
-
-    def get_uniform_location(self, name):
-        try:
-            return self.location_cache[name]
-        except KeyError:
-            loc = glGetUniformLocation(self.handle, name)
-            if loc == -1:
-                raise GLSLException #TODO
-            self.location_cache[name] = loc
-            return loc
-
     # upload a floating point uniform
     # this program must be currently bound
-    def uniformf(self, name, *vals):
-        # check there are 1-4 values
-        if len(vals) in range(1, 5):
-            # select the correct function
-            { 1 : glUniform1f,
-                2 : glUniform2f,
-                3 : glUniform3f,
-                4 : glUniform4f
-                # retrieve the uniform location, and set
-            }[len(vals)](self.get_uniform_location(name), *vals)
+    def uniform_1(self, name, GLfloat val):
+        glUniform1fv(self.get_uniform_location(name), 1, &val)
 
-    # upload an integer uniform
-    # this program must be currently bound
-    def uniformi(self, name, *vals):
-        # check there are 1-4 values
-        if len(vals) in range(1, 5):
-            # select the correct function
-            { 1 : glUniform1i,
-                2 : glUniform2i,
-                3 : glUniform3i,
-                4 : glUniform4i
-                # retrieve the uniform location, and set
-            }[len(vals)](self.get_uniform_location(name), *vals)
+    # upload a vec4 uniform
+    def uniform_4(self, name, GLfloat a, GLfloat b, GLfloat c, GLfloat d):
+        cdef GLfloat vals[4]
+        vals[0] = a
+        vals[1] = b
+        vals[2] = c
+        vals[3] = d
+        glUniform4fv(self.get_uniform_location(name), 1, vals)
 
     # upload a uniform matrix
     # works with matrices stored as lists,
     # as well as euclid matrices
-    def uniform_matrixf(self, name, mat):
-        # obtian the uniform location
+    def uniform_matrix(self, name, Matrix mat):
+        # obtain the uniform location
         loc = self.get_uniform_location(name)
         # uplaod the 4x4 floating point matrix
-        glUniformMatrix4fv(loc, 1, False, mat)
-
+        glUniformMatrix4fv(loc, 1, False, matrix_to_floats(mat))