1+ package com .devil .mabatis .service ;
2+
3+ import java .lang .annotation .Documented ;
4+ import java .lang .annotation .ElementType ;
5+ import java .lang .annotation .Retention ;
6+ import java .lang .annotation .RetentionPolicy ;
7+ import java .lang .annotation .Target ;
8+ import java .lang .reflect .InvocationTargetException ;
9+ import java .lang .reflect .Method ;
10+ import java .util .ArrayList ;
11+ import java .util .Arrays ;
12+ import java .util .Collection ;
13+ import java .util .Collections ;
14+ import java .util .HashMap ;
15+ import java .util .HashSet ;
16+ import java .util .List ;
17+ import java .util .Map ;
18+ import java .util .Map .Entry ;
19+ import java .util .Properties ;
20+ import java .util .Set ;
21+
22+ import org .apache .ibatis .binding .MapperMethod .ParamMap ;
23+ import org .apache .ibatis .builder .StaticSqlSource ;
24+ import org .apache .ibatis .executor .Executor ;
25+ import org .apache .ibatis .mapping .BoundSql ;
26+ import org .apache .ibatis .mapping .MappedStatement ;
27+ import org .apache .ibatis .mapping .ResultMap ;
28+ import org .apache .ibatis .mapping .ResultMapping ;
29+ import org .apache .ibatis .mapping .SqlCommandType ;
30+ import org .apache .ibatis .mapping .SqlSource ;
31+ import org .apache .ibatis .mapping .StatementType ;
32+ import org .apache .ibatis .plugin .Interceptor ;
33+ import org .apache .ibatis .plugin .Intercepts ;
34+ import org .apache .ibatis .plugin .Invocation ;
35+ import org .apache .ibatis .plugin .Plugin ;
36+ import org .apache .ibatis .plugin .Signature ;
37+ import org .apache .ibatis .scripting .xmltags .OgnlCache ;
38+ import org .apache .ibatis .session .ResultHandler ;
39+ import org .apache .ibatis .session .RowBounds ;
40+ import org .slf4j .Logger ;
41+ import org .slf4j .LoggerFactory ;
42+
43+ import com .devil .utils .ClzUtil ;
44+ import com .devil .utils .CommUtil ;
45+
46+ @ Intercepts ({
47+ @ Signature (type = Executor .class , method = "query" , args = { MappedStatement .class , Object .class ,
48+ RowBounds .class , ResultHandler .class }),
49+ @ Signature (type = Executor .class , method = "update" , args = { MappedStatement .class , Object .class }) })
50+
51+ public class ShardHelper implements Interceptor {
52+ private static final Logger log = LoggerFactory .getLogger (ShardHelper .class );
53+ private ShardStrategy strategy ;
54+ private Map <Class <?>, String > includeMap = new HashMap <>();
55+ private Set <String > includeSet = new HashSet <>();
56+
57+ @ Override
58+ public Object intercept (Invocation invocation ) throws Throwable {
59+ // 获取原始的ms
60+ Object [] args = invocation .getArgs ();
61+ MappedStatement ms = (MappedStatement ) args [0 ];
62+ String msid = ms .getId ();
63+ int lastDotIdx = msid .lastIndexOf ("." );
64+ String className = msid .substring (0 , lastDotIdx );
65+ boolean valid = false ;
66+ for (String str : includeSet ) {
67+ if (msid .startsWith (str )) {
68+ valid = true ;
69+ break ;
70+ }
71+ }
72+ if (!valid ) {
73+ return invocation .proceed ();
74+ }
75+ String methodName = msid .substring (lastDotIdx + 1 );
76+ Class <?> daoClass = Class .forName (className );
77+ Method method = ClzUtil .getMethod (daoClass , methodName );
78+ Shard shard = method .getAnnotation (Shard .class );
79+ if (shard != null ) {
80+ ShardType type = shard .value ();
81+ SqlSource sqlSource = ms .getSqlSource ();
82+ try {
83+ String tableName = includeMap .get (daoClass );
84+ if (type == ShardType .OneById ) {
85+ Object paramobj = args [1 ];
86+
87+ long id ;
88+ if (paramobj instanceof ParamMap ) {
89+ Object value = OgnlCache .getValue (shard .idKey (), paramobj );
90+ id =Long .valueOf (value .toString ());
91+ } else {
92+ id = (long ) paramobj ;
93+ }
94+
95+ String newTableName = strategy .buildTableNameById (tableName , id );
96+
97+ ShardSqlSource shardSqlSource = new ShardSqlSource (sqlSource , tableName , newTableName );
98+ ClzUtil .setField (ms , "sqlSource" , shardSqlSource );
99+ Object result =null ;
100+ try {
101+ result = invocation .proceed ();
102+ } catch (InvocationTargetException e ) {
103+ Throwable newe = CommUtil .getException (e , "com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException" );
104+ if (newe != null ) {
105+ if (ms .getSqlCommandType ()==SqlCommandType .INSERT ){
106+ StaticSqlSource createSqlSource = new StaticSqlSource (ms .getConfiguration (), "CREATE TABLE " +newTableName +" LIKE " +tableName );
107+ MappedStatement .Builder builder = new MappedStatement .Builder (ms .getConfiguration (), ms .getId () + "_CREATE" , createSqlSource , SqlCommandType .UPDATE );
108+ builder .resource (ms .getResource ());
109+ builder .fetchSize (1 );
110+ builder .statementType (StatementType .STATEMENT );
111+ builder .keyGenerator (ms .getKeyGenerator ());
112+ builder .timeout (ms .getTimeout ());
113+ //count查询返回值int
114+ List <ResultMap > resultMaps = new ArrayList <ResultMap >();
115+ List <ResultMapping > EMPTY_RESULTMAPPING = Collections .emptyList ();
116+ ResultMap resultMap = new ResultMap .Builder (ms .getConfiguration (), ms .getId (), int .class , EMPTY_RESULTMAPPING ).build ();
117+ resultMaps .add (resultMap );
118+ builder .resultMaps (resultMaps );
119+ builder .resultSetType (ms .getResultSetType ());
120+ builder .cache (ms .getCache ());
121+ builder .flushCacheRequired (ms .isFlushCacheRequired ());
122+ builder .useCache (ms .isUseCache ());
123+ args [0 ]=builder .build ();
124+ Object tmpresult = invocation .proceed ();
125+ System .err .println (tmpresult );
126+ args [0 ]=ms ;
127+ return invocation .proceed ();
128+ }
129+ }
130+ throw e ;
131+ }
132+ return result ;
133+
134+ } else if (type == ShardType .One ) {
135+ int tableid = 0 ;
136+ while (true ) {
137+ try {
138+ String newTableName = strategy .buildTableNameByIdx (includeMap .get (daoClass ), tableid );
139+ ShardSqlSource shardSqlSource = new ShardSqlSource (sqlSource , tableName , newTableName );
140+ ClzUtil .setField (ms , "sqlSource" , shardSqlSource );
141+ Object obj = invocation .proceed ();
142+ if (obj != null ) {
143+ return obj ;
144+ }
145+ tableid ++;
146+ } catch (InvocationTargetException e ) {
147+ Throwable newe = CommUtil .getException (e , "com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException" );
148+ if (newe != null ) {
149+ return null ;
150+ }
151+ throw e ;
152+ }
153+ }
154+ } else if (type == ShardType .Count ) {
155+ int tableid = 0 ;
156+ long cnt = 0 ;
157+ List <Long > resultlist = null ;
158+ while (true ) {
159+ try {
160+ String newTableName = strategy .buildTableNameByIdx (tableName ,tableid );
161+ ShardSqlSource shardSqlSource = new ShardSqlSource (sqlSource , tableName , newTableName );
162+ ClzUtil .setField (ms , "sqlSource" , shardSqlSource );
163+ resultlist = (List <Long >) invocation .proceed ();
164+ Long thiscnt = resultlist .get (0 );
165+ cnt += thiscnt ;
166+ tableid ++;
167+ } catch (InvocationTargetException e ) {
168+ Throwable newe = CommUtil .getException (e , "com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException" );
169+ if (newe != null ) {
170+ resultlist .set (0 , cnt );
171+ return resultlist ;
172+ }
173+ throw e ;
174+ }
175+ }
176+ } else if (type == ShardType .ManyById ) {
177+ Object paramobj = args [1 ];
178+ Object value = OgnlCache .getValue (shard .idKey (), paramobj );
179+ Collection <Long > ids = null ;
180+ if (value instanceof Collection ) {
181+ ids = (Collection ) value ;
182+ } else if (value .getClass ().isArray ()) {
183+ Long [] arr = (Long []) value ;
184+ ids = Arrays .asList (arr );
185+ }
186+
187+ Map <String , List <Long >> uidsegmap = new HashMap <>();
188+ for (Long uid : ids ) {
189+ String tblName = strategy .buildTableNameById (tableName , uid );
190+ List <Long > uidseg = uidsegmap .get (tblName );// new
191+ // ArrayList<>();
192+ if (uidseg == null ) {
193+ uidseg = new ArrayList <>();
194+ uidsegmap .put (tblName , uidseg );
195+ }
196+ uidseg .add (uid );
197+ }
198+ List <?> resultlist = new ArrayList <>();
199+ for (Entry <String , List <Long >> entry : uidsegmap .entrySet ()) {
200+ try {
201+ String newTableName = entry .getKey ();
202+
203+ Map <String , Object > newParam = new ParamMap <Object >();
204+ newParam .put (shard .idKey (), entry .getValue ());
205+
206+ ShardInSqlSource ShardInSqlSource = new ShardInSqlSource (sqlSource , tableName , newTableName ,
207+ newParam );
208+ ClzUtil .setField (ms , "sqlSource" , ShardInSqlSource );
209+ resultlist .addAll ((Collection ) invocation .proceed ());
210+ } catch (Exception e ) {
211+ log .error ("shard manybyid error" , e );
212+ }
213+ }
214+ return resultlist ;
215+ } else if (type == ShardType .Many ) {
216+ int tableid = 0 ;
217+ List <?> resultlist = new ArrayList <>();
218+ while (true ) {
219+ try {
220+ String newTableName = strategy .buildTableNameByIdx (tableName ,tableid );
221+ ShardSqlSource shardSqlSource = new ShardSqlSource (sqlSource , tableName , newTableName );
222+ ClzUtil .setField (ms , "sqlSource" , shardSqlSource );
223+ resultlist .addAll ((Collection ) invocation .proceed ());
224+ tableid ++;
225+ } catch (InvocationTargetException e ) {
226+ Throwable newe = CommUtil .getException (e , "com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException" );
227+ if (newe != null ) {
228+ return resultlist ;
229+ }
230+ throw e ;
231+ }
232+ }
233+ }else if (type == ShardType .Max ||type == ShardType .Min ){
234+ int tableid = 0 ;
235+ Set <Long > set =new HashSet <>();
236+ List <Long > resultlist = null ;
237+ while (true ) {
238+ try {
239+ String newTableName = strategy .buildTableNameByIdx (tableName ,tableid );
240+ ShardSqlSource shardSqlSource = new ShardSqlSource (sqlSource , tableName , newTableName );
241+ ClzUtil .setField (ms , "sqlSource" , shardSqlSource );
242+ resultlist = (List <Long >) invocation .proceed ();
243+ if (!resultlist .isEmpty ()) {
244+ set .add (resultlist .get (0 ));
245+ }
246+ tableid ++;
247+ } catch (InvocationTargetException e ) {
248+ Throwable newe = CommUtil .getException (e , "com.mysql.jdbc.exceptions.jdbc4.MySQLSyntaxErrorException" );
249+ if (newe != null ) {
250+ Long value = (type == ShardType .Max ?Collections .max (set ):Collections .min (set ));
251+ resultlist .set (0 , value );
252+ return resultlist ;
253+ }
254+ throw e ;
255+ }
256+ }
257+ }
258+ // 传递给下一个拦截器处理
259+ } finally {
260+ ClzUtil .setField (ms , "sqlSource" , sqlSource );
261+ }
262+ }
263+ return invocation .proceed ();
264+ }
265+
266+ @ Override
267+ public Object plugin (Object target ) {
268+ // 当目标类是StatementHandler类型时,才包装目标类,否者直接返回目标本身,减少目标被代理的
269+ // 次数
270+ if (target instanceof Executor ) {
271+ return Plugin .wrap (target , this );
272+ } else {
273+ return target ;
274+ }
275+ }
276+
277+ @ Override
278+ public void setProperties (Properties properties ) {
279+ try {
280+ String include = properties .getProperty ("include" );
281+ String [] items = include .split ("," );
282+ for (String item : items ) {
283+ String [] kv = item .split ("=" );
284+ String clzName = kv [0 ];
285+ String tableName = kv [1 ];
286+ includeSet .add (clzName );
287+ includeMap .put (Class .forName (clzName ), tableName );
288+ }
289+ Class <ShardStrategy > strategyClz = (Class <ShardStrategy >) Class .forName (properties .getProperty ("strategy" ));
290+ this .strategy = strategyClz .newInstance ();
291+ System .out .println (properties );
292+ } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e ) {
293+ throw new IllegalArgumentException (e );
294+ }
295+ }
296+
297+ @ Target ({ ElementType .FIELD , ElementType .METHOD })
298+ @ Retention (RetentionPolicy .RUNTIME )
299+ @ Documented
300+ public @interface Shard {
301+ public ShardType value ();
302+ public String idKey () default "" ;
303+ }
304+ public static enum ShardType {One ,OneById ,Many ,ManyById ,Count ,Max ,Min };
305+ public static interface ShardStrategy {
306+ public String buildTableNameById (String tableName , long id );
307+
308+ public String buildTableNameByIdx (String tableName , int tableIdx );
309+ }
310+
311+ private static class ShardSqlSource implements SqlSource {
312+ private SqlSource sqlSource ;
313+ private String oldTableName ;
314+ private String newTableName ;
315+
316+ public ShardSqlSource (SqlSource src ,String oldTableName ,String newTableName ) {
317+ this .sqlSource = src ;
318+ this .oldTableName =oldTableName ;
319+ this .newTableName =newTableName ;
320+ }
321+
322+ @ Override
323+ public BoundSql getBoundSql (Object parameterObject ) {
324+ BoundSql boundSql = sqlSource .getBoundSql (parameterObject );
325+ String sql = boundSql .getSql ().replaceAll (oldTableName , newTableName );
326+ ClzUtil .setField (boundSql , "sql" , sql );
327+ return boundSql ;
328+ }
329+ }
330+ private static class ShardInSqlSource implements SqlSource {
331+ private SqlSource sqlSource ;
332+ private String oldTableName ;
333+ private String newTableName ;
334+ private Object newParameterObject ;
335+
336+ public ShardInSqlSource (SqlSource src , String oldTableName , String newTableName , Object newParameterObject ) {
337+ this .sqlSource = src ;
338+ this .oldTableName = oldTableName ;
339+ this .newTableName = newTableName ;
340+ this .newParameterObject = newParameterObject ;
341+ }
342+
343+ @ Override
344+ public BoundSql getBoundSql (Object parameterObject ) {
345+ BoundSql boundSql = sqlSource .getBoundSql (newParameterObject );
346+ String sql = boundSql .getSql ().replaceAll (oldTableName , newTableName );
347+ ClzUtil .setField (boundSql , "sql" , sql );
348+ return boundSql ;
349+ }
350+ }
351+ }
0 commit comments