Note
Click here to download the full example code
Automatically use a function at insert or select
Sometimes the application wants to apply a function in an insert or in a select.
For example, the application might need the geometry in lat/lon coordinates while it
is stored projected in the DB. To avoid having to tweak every query with an explicit
ST_Transform() call, it is possible to define a TypeDecorator that applies the function automatically.
11 from sqlalchemy import Column
12 from sqlalchemy import Integer
13 from sqlalchemy import MetaData
14 from sqlalchemy import func
15 from sqlalchemy import text
16 from sqlalchemy.ext.declarative import declarative_base
17 from sqlalchemy.types import TypeDecorator
18
19 from geoalchemy2 import Geometry
20 from geoalchemy2 import shape
21
22 # Tests imports
23 from tests import test_only_with_dialects
24
# Metadata shared by the declarative models below; the tests call
# ``metadata.drop_all()`` / ``metadata.create_all()`` on it to reset the tables.
metadata = MetaData()

# Declarative base class whose models are registered in ``metadata``.
Base = declarative_base(metadata=metadata)
28
29
class TransformedGeometry(TypeDecorator):
    """Geometry type that wraps each insert and each select in ``ST_Transform()``.

    The underlying ``Geometry`` is configured with ``db_srid``; bound values are
    transformed to ``db_srid`` and selected columns are transformed to ``app_srid``.
    """

    impl = Geometry

    cache_ok = True

    def __init__(self, db_srid, app_srid, **kwargs):
        """Build the underlying geometry type with ``db_srid`` and remember both SRIDs.

        Any ``srid`` entry given in ``kwargs`` is overridden by ``db_srid``; the other
        keyword arguments are forwarded unchanged to the parent constructor.
        """
        impl_kwargs = dict(kwargs, srid=db_srid)
        super().__init__(**impl_kwargs)
        self.app_srid = app_srid
        self.db_srid = db_srid

    def column_expression(self, col):
        """Return the selected column transformed to ``app_srid`` with a decorated type.

        The returned element must be decorated as well, but it must not be transformed
        a second time afterwards, so the result type is created with ``app_srid`` for
        both the ``db_srid`` and ``app_srid`` arguments.
        Without this the SRID of the WKBElement would be wrong.
        """
        # Name of the SQL function is taken from the underlying type (``as_binary``).
        as_binary = getattr(func, self.impl.as_binary)
        in_app_srid = func.ST_Transform(col, self.app_srid)
        result_type = self.__class__(db_srid=self.app_srid, app_srid=self.app_srid)
        return as_binary(in_app_srid, type_=result_type)

    def bind_expression(self, bindvalue):
        """Wrap the bound value in ``ST_Transform()`` so it is converted to ``db_srid``."""
        impl_expression = self.impl.bind_expression(bindvalue)
        return func.ST_Transform(impl_expression, self.db_srid, type_=self)
60
61
class ThreeDGeometry(TypeDecorator):
    """This class is used to insert a ST_Force3D() in each insert."""

    impl = Geometry

    cache_ok = True

    def column_expression(self, col):
        """The column_expression() method is overridden to set the correct type.

        This is not needed in this example but it is needed if one wants to override other methods
        of the TypeDecorator class, like ``process_result_value()`` for example.
        """
        return getattr(func, self.impl.as_binary)(col, type_=self)

    def bind_expression(self, bindvalue):
        """Wrap the bound value in ``ST_Force3D()`` and type the result as this decorator.

        The return type of a SQL function is given with the ``type_`` keyword (with a
        trailing underscore), as in ``column_expression()`` above; the previous
        ``type=self`` spelling did not set the type of the expression.
        """
        return func.ST_Force3D(
            self.impl.bind_expression(bindvalue),
            type_=self,
        )
81
82
class Point(Base):
    """Model with one plain geometry column and two decorated geometry columns."""

    __tablename__ = "point"

    id = Column(Integer, primary_key=True)

    # Plain geometry column, no automatic function applied.
    raw_geom = Column(
        Geometry(srid=4326, geometry_type="POINT"),
    )

    # Declared with SRID 2154 on the DB side and SRID 4326 on the application side.
    geom = Column(
        TransformedGeometry(db_srid=2154, app_srid=4326, geometry_type="POINT"),
    )

    # Bound values are wrapped by ThreeDGeometry.bind_expression().
    three_d_geom = Column(
        ThreeDGeometry(srid=4326, geometry_type="POINTZ", dimension=3),
    )
92
93
def check_wkb(wkb, x, y):
    """Assert that ``wkb`` converts to a shape whose x/y, rounded to 5 decimals, equal ``x``/``y``."""
    geom = shape.to_shape(wkb)
    assert round(geom.x, 5) == x
    assert round(geom.y, 5) == y
98
99
@test_only_with_dialects("postgresql")
class TestTypeDecorator:
    """Tests of the decorated geometry columns of ``Point``."""

    def _create_one_point(self, session, conn):
        """Recreate the tables, insert a single point and return its id."""
        metadata.drop_all(conn, checkfirst=True)
        metadata.create_all(conn)

        # Create new point instance; the same 2D EWKT is used for every column,
        # including the 3D one (insert 2D geometry into 3D column).
        ewkt = "SRID=4326;POINT(5 45)"
        new_point = Point(raw_geom=ewkt, geom=ewkt, three_d_geom=ewkt)

        # Insert point, then expire it so that its attributes are loaded again
        # the next time they are accessed.
        session.add(new_point)
        session.flush()
        session.expire(new_point)

        return new_point.id

    def test_transform(self, session, conn):
        """Check the SRIDs and coordinates seen by the application and stored in the DB."""
        self._create_one_point(session, conn)

        # Query the point and check the result
        loaded = session.query(Point).one()
        assert loaded.id == 1
        assert loaded.raw_geom.srid == 4326
        check_wkb(loaded.raw_geom, 5, 45)

        assert loaded.geom.srid == 4326
        check_wkb(loaded.geom, 5, 45)

        # Check that the data is correct in DB using raw query
        raw_query = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
        raw_row = session.execute(raw_query).fetchone()
        assert raw_row.id == 1
        assert raw_row.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"

        # Compare geom, raw_geom with auto transform and explicit transform
        point, raw_geom, explicit_trans = session.query(
            Point,
            Point.raw_geom,
            func.ST_Transform(Point.raw_geom, 2154).label("trans"),
        ).one()

        assert point.id == 1

        assert point.geom.srid == 4326
        check_wkb(point.geom, 5, 45)

        assert point.raw_geom.srid == 4326
        check_wkb(point.raw_geom, 5, 45)

        assert raw_geom.srid == 4326
        check_wkb(raw_geom, 5, 45)

        assert explicit_trans.srid == 2154
        check_wkb(explicit_trans, 857581.89932, 6435414.74784)

    def test_force_3d(self, session, conn):
        """Check the value read back from the 3D column after inserting a 2D point."""
        self._create_one_point(session, conn)

        # Query the point and check the result
        loaded = session.query(Point).one()

        assert loaded.id == 1
        assert loaded.three_d_geom.srid == 4326
        expected_hex = '01010000a0e6100000000000000000144000000000008046400000000000000000'
        assert loaded.three_d_geom.desc.lower() == expected_hex