1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 package org.dbunit.database;
22
23 import java.sql.PreparedStatement;
24 import java.sql.ResultSet;
25 import java.sql.SQLException;
26 import java.util.ArrayList;
27 import java.util.HashMap;
28 import java.util.HashSet;
29 import java.util.Iterator;
30 import java.util.List;
31 import java.util.Map;
32 import java.util.Set;
33 import java.util.SortedSet;
34 import java.util.TreeSet;
35
36 import org.apache.commons.collections.map.ListOrderedMap;
37 import org.dbunit.database.search.ForeignKeyRelationshipEdge;
38 import org.dbunit.dataset.DataSetException;
39 import org.dbunit.dataset.IDataSet;
40 import org.dbunit.dataset.ITable;
41 import org.dbunit.dataset.ITableIterator;
42 import org.dbunit.dataset.ITableMetaData;
43 import org.dbunit.dataset.filter.AbstractTableFilter;
44 import org.dbunit.util.SQLHelper;
45 import org.slf4j.Logger;
46 import org.slf4j.LoggerFactory;
47
48
49
50
51
52
53
54
55
56
57
58
59
60 public class PrimaryKeyFilter extends AbstractTableFilter {
61
62 private final IDatabaseConnection connection;
63
64 private final PkTableMap allowedPKsPerTable;
65 private final PkTableMap allowedPKsInput;
66 private final PkTableMap pksToScanPerTable;
67
68 private final boolean reverseScan;
69
70 protected final Logger logger = LoggerFactory.getLogger(getClass());
71
72
73 private final Map pkColumnPerTable = new HashMap();
74
75 private final Map fkEdgesPerTable = new HashMap();
76 private final Map fkReverseEdgesPerTable = new HashMap();
77
78
79 private final List tableNames = new ArrayList();
80
81
82
83
84
85
86
87
88
89
90
91 public PrimaryKeyFilter(IDatabaseConnection connection, PkTableMap allowedPKs, boolean reverseDependency) {
92 this.connection = connection;
93 this.allowedPKsPerTable = new PkTableMap();
94 this.allowedPKsInput = allowedPKs;
95 this.reverseScan = reverseDependency;
96
97
98 this.pksToScanPerTable = new PkTableMap(allowedPKs);
99 }
100
101 public void nodeAdded(Object node) {
102 this.tableNames.add( node );
103 if ( this.logger.isDebugEnabled() ) {
104 this.logger.debug("nodeAdded: " + node );
105 }
106 }
107
108 public void edgeAdded(ForeignKeyRelationshipEdge edge) {
109 if ( this.logger.isDebugEnabled() ) {
110 this.logger.debug("edgeAdded: " + edge );
111 }
112
113 String from = (String) edge.getFrom();
114 Set edges = (Set) this.fkEdgesPerTable.get(from);
115 if ( edges == null ) {
116 edges = new HashSet();
117 this.fkEdgesPerTable.put( from, edges );
118 }
119 if ( ! edges.contains(edge) ) {
120 edges.add(edge);
121 }
122
123
124 String to = (String) edge.getTo();
125 edges = (Set) this.fkReverseEdgesPerTable.get(to);
126 if ( edges == null ) {
127 edges = new HashSet();
128 this.fkReverseEdgesPerTable.put( to, edges );
129 }
130 if ( ! edges.contains(edge) ) {
131 edges.add(edge);
132 }
133
134
135 updatePkCache(to, edge);
136
137 }
138
139
140
141
142 public boolean isValidName(String tableName) throws DataSetException {
143
144
145 return true;
146 }
147
148 public ITableIterator iterator(IDataSet dataSet, boolean reversed)
149 throws DataSetException {
150 if ( this.logger.isDebugEnabled() ) {
151 this.logger.debug("Filter.iterator()" );
152 }
153 try {
154 searchPKs(dataSet);
155 } catch (SQLException e) {
156 throw new DataSetException( e );
157 }
158 return new FilterIterator(reversed ? dataSet.reverseIterator() : dataSet
159 .iterator());
160 }
161
162 private void searchPKs(IDataSet dataSet) throws DataSetException, SQLException {
163 logger.debug("searchPKs(dataSet={}) - start", dataSet);
164
165 int counter = 0;
166 while ( !this.pksToScanPerTable.isEmpty() ) {
167 counter ++;
168 if ( this.logger.isDebugEnabled() ) {
169 this.logger.debug( "RUN # " + counter );
170 }
171
172 for( int i=this.tableNames.size()-1; i>=0; i-- ) {
173 String tableName = (String) this.tableNames.get(i);
174
175 String pkColumn = dataSet.getTable(tableName).getTableMetaData().getPrimaryKeys()[0].getColumnName();
176 Set tmpSet = this.pksToScanPerTable.get( tableName );
177 if ( tmpSet != null && ! tmpSet.isEmpty() ) {
178 Set pksToScan = new HashSet( tmpSet );
179 if ( this.logger.isDebugEnabled() ) {
180 this.logger.debug( "before search: "+ tableName + "=>" + pksToScan );
181 }
182 scanPKs( tableName, pkColumn, pksToScan );
183 scanReversePKs( tableName, pksToScan );
184 allowPKs( tableName, pksToScan );
185 removePKsToScan( tableName, pksToScan );
186 }
187 }
188 removeScannedTables();
189 }
190 if ( this.logger.isDebugEnabled() ) {
191 this.logger.debug( "Finished searchIds()" );
192 }
193 }
194
195 private void removeScannedTables() {
196 logger.debug("removeScannedTables() - start");
197 this.pksToScanPerTable.retainOnly(this.tableNames);
198 }
199
200 private void allowPKs(String table, Set newAllowedPKs) {
201 logger.debug("allowPKs(table={}, newAllowedPKs={}) - start", table, newAllowedPKs);
202
203
204 Set forcedAllowedPKs = this.allowedPKsInput.get( table );
205 if( forcedAllowedPKs == null || forcedAllowedPKs.isEmpty() ) {
206 allowedPKsPerTable.addAll(table, newAllowedPKs );
207 } else {
208 for(Iterator iterator = newAllowedPKs.iterator(); iterator.hasNext(); ) {
209 Object id = iterator.next();
210 if( forcedAllowedPKs.contains(id) ) {
211 allowedPKsPerTable.add(table, id);
212 }
213 else
214 {
215 if ( this.logger.isDebugEnabled() ) {
216 this.logger.debug( "Discarding id " + id + " of table " + table +
217 " as it was not included in the input!" );
218 }
219 }
220 }
221 }
222 }
223
224 private void scanPKs( String table, String pkColumn, Set allowedIds ) throws SQLException {
225 if (logger.isDebugEnabled())
226 {
227 logger.debug("scanPKs(table={}, pkColumn={}, allowedIds={}) - start",
228 new Object[]{ table, pkColumn, allowedIds });
229 }
230
231 Set fkEdges = (Set) this.fkEdgesPerTable.get( table );
232 if ( fkEdges == null || fkEdges.isEmpty() ) {
233 return;
234 }
235
236 List fkTables = new ArrayList( fkEdges.size() );
237 StringBuffer colsBuffer = new StringBuffer();
238 for(Iterator iterator = fkEdges.iterator(); iterator.hasNext(); ) {
239 ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
240 fkTables.add( edge.getTo() );
241 colsBuffer.append( edge.getFKColumn() );
242 if ( iterator.hasNext() ) {
243 colsBuffer.append( ", " );
244 }
245 }
246
247 String sql = "SELECT " + colsBuffer + " FROM " + table +
248 " WHERE " + pkColumn + " = ? ";
249 if ( this.logger.isDebugEnabled() ) {
250 this.logger.debug( "SQL: " + sql );
251 }
252
253 scanPKs(table, sql, allowedIds, fkTables);
254 }
255
256 private void scanPKs(String table, String sql, Set allowedIds, List fkTables) throws SQLException
257 {
258 PreparedStatement pstmt = null;
259 ResultSet rs = null;
260 try {
261 pstmt = this.connection.getConnection().prepareStatement( sql );
262 for(Iterator iterator = allowedIds.iterator(); iterator.hasNext(); ) {
263 Object pk = iterator.next();
264 if( this.logger.isDebugEnabled() ) {
265 this.logger.debug("Executing sql for ? = " + pk );
266 }
267 pstmt.setObject( 1, pk );
268 rs = pstmt.executeQuery();
269 while( rs.next() ) {
270 for( int i=0; i<fkTables.size(); i++ ) {
271 String newTable = (String) fkTables.get(i);
272 Object fk = rs.getObject(i+1);
273 if( fk != null ) {
274 if( this.logger.isDebugEnabled() ) {
275 this.logger.debug("New ID: " + newTable + "->" + fk);
276 }
277 addPKToScan( newTable, fk );
278 }
279 else {
280 this.logger.warn( "Found null FK for relationship " +
281 table + "=>" + newTable );
282 }
283 }
284 }
285 }
286 } catch (SQLException e) {
287 logger.error("scanPKs()", e);
288 }
289 finally {
290
291 SQLHelper.close( rs, pstmt );
292 }
293 }
294
295 private void scanReversePKs(String table, Set pksToScan) throws SQLException {
296 logger.debug("scanReversePKs(table={}, pksToScan={}) - start", table, pksToScan);
297
298 if ( ! this.reverseScan ) {
299 return;
300 }
301 Set fkReverseEdges = (Set) this.fkReverseEdgesPerTable.get( table );
302 if ( fkReverseEdges == null || fkReverseEdges.isEmpty() ) {
303 return;
304 }
305 Iterator iterator = fkReverseEdges.iterator();
306 while ( iterator.hasNext() ) {
307 ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator.next();
308 addReverseEdge( edge, pksToScan );
309 }
310 }
311
312 private void addReverseEdge(ForeignKeyRelationshipEdge edge, Set idsToScan) throws SQLException {
313 logger.debug("addReverseEdge(edge={}, idsToScan=) - start", edge, idsToScan);
314
315 String fkTable = (String) edge.getFrom();
316 String fkColumn = edge.getFKColumn();
317 String pkColumn = getPKColumn( fkTable );
318
319 String sql = "SELECT " + pkColumn + " FROM " + fkTable + " WHERE " + fkColumn + " = ? ";
320
321 PreparedStatement pstmt = null;
322 ResultSet rs = null;
323 try {
324 if ( this.logger.isDebugEnabled() ) {
325 this.logger.debug( "Preparing SQL query '" + sql + "'" );
326 }
327 pstmt = this.connection.getConnection().prepareStatement( sql );
328 for(Iterator iterator = idsToScan.iterator(); iterator.hasNext(); ) {
329 Object pk = iterator.next();
330 if ( this.logger.isDebugEnabled() ) {
331 this.logger.debug( "executing query '" + sql + "' for ? = " + pk );
332 }
333 pstmt.setObject( 1, pk );
334 rs = pstmt.executeQuery();
335 while( rs.next() ) {
336 Object fk = rs.getObject(1);
337 addPKToScan( fkTable, fk );
338 }
339 }
340 } finally {
341 SQLHelper.close( rs, pstmt );
342 }
343 }
344
345 private void updatePkCache(String table, ForeignKeyRelationshipEdge edge) {
346 logger.debug("updatePkCache(to={}, edge={}) - start", table, edge);
347
348 Object pkTo = this.pkColumnPerTable.get(table);
349 if ( pkTo == null ) {
350 String pkColumn = edge.getPKColumn();
351 this.pkColumnPerTable.put( table, pkColumn );
352 }
353 }
354
355
356 private String getPKColumn( String table ) throws SQLException {
357 logger.debug("getPKColumn(table={}) - start", table);
358
359
360 String pkColumn = (String) this.pkColumnPerTable.get( table );
361 if ( pkColumn == null ) {
362
363 pkColumn = SQLHelper.getPrimaryKeyColumn( this.connection.getConnection(), table );
364 this.pkColumnPerTable.put( table, pkColumn );
365 }
366 return pkColumn;
367 }
368
369
370 private void removePKsToScan(String table, Set ids) {
371 logger.debug("removePKsToScan(table={}, ids={}) - start", table, ids);
372
373 Set pksToScan = this.pksToScanPerTable.get(table);
374 if ( pksToScan != null ) {
375 if ( pksToScan == ids ) {
376 throw new RuntimeException( "INTERNAL ERROR on removeIdsToScan() for table " + table );
377 } else {
378 pksToScan.removeAll( ids );
379 }
380 }
381 }
382
383 private void addPKToScan(String table, Object pk) {
384 logger.debug("addPKToScan(table={}, pk={}) - start", table, pk);
385
386
387 if(this.allowedPKsPerTable.contains(table, pk)) {
388 if ( this.logger.isDebugEnabled() ) {
389 this.logger.debug( "Discarding already scanned id=" + pk + " for table " + table );
390 }
391 return;
392 }
393
394 this.pksToScanPerTable.add(table, pk);
395 }
396
397 public String toString() {
398 StringBuffer sb = new StringBuffer();
399 sb.append("tableNames=").append(tableNames);
400 sb.append(", allowedPKsInput=").append(allowedPKsInput);
401 sb.append(", allowedPKsPerTable=").append(allowedPKsPerTable);
402 sb.append(", fkEdgesPerTable=").append(fkEdgesPerTable);
403 sb.append(", fkReverseEdgesPerTable=").append(fkReverseEdgesPerTable);
404 sb.append(", pkColumnPerTable=").append(pkColumnPerTable);
405 sb.append(", pksToScanPerTable=").append(pksToScanPerTable);
406 sb.append(", reverseScan=").append(reverseScan);
407 sb.append(", connection=").append(connection);
408 return sb.toString();
409 }
410
411
412 private class FilterIterator implements ITableIterator {
413
414 private final ITableIterator _iterator;
415
416 public FilterIterator(ITableIterator iterator) {
417
418 _iterator = iterator;
419 }
420
421
422
423
424 public boolean next() throws DataSetException {
425 if ( logger.isDebugEnabled() ) {
426 logger.debug("Iterator.next()" );
427 }
428 while (_iterator.next()) {
429 if (accept(_iterator.getTableMetaData().getTableName())) {
430 return true;
431 }
432 }
433 return false;
434 }
435
436 public ITableMetaData getTableMetaData() throws DataSetException {
437 if ( logger.isDebugEnabled() ) {
438 logger.debug("Iterator.getTableMetaData()" );
439 }
440 return _iterator.getTableMetaData();
441 }
442
443 public ITable getTable() throws DataSetException {
444 if ( logger.isDebugEnabled() ) {
445 logger.debug("Iterator.getTable()" );
446 }
447 ITable table = _iterator.getTable();
448 String tableName = table.getTableMetaData().getTableName();
449 Set allowedPKs = allowedPKsPerTable.get( tableName );
450 if ( allowedPKs != null ) {
451 return new PrimaryKeyFilteredTableWrapper(table, allowedPKs);
452 }
453 return table;
454 }
455 }
456
457
458
459
460
461
462
463
464
465 public static class PkTableMap
466 {
467 private final ListOrderedMap pksPerTable;
468 private final Logger logger = LoggerFactory.getLogger(PkTableMap.class);
469
470 public PkTableMap()
471 {
472 this.pksPerTable = new ListOrderedMap();
473 }
474
475
476
477
478
479 public PkTableMap(PkTableMap allowedPKs) {
480 this.pksPerTable = new ListOrderedMap();
481 Iterator iterator = allowedPKs.pksPerTable.entrySet().iterator();
482 while ( iterator.hasNext() ) {
483 Map.Entry entry = (Map.Entry) iterator.next();
484 String table = (String)entry.getKey();
485 SortedSet pkObjectSet = (SortedSet) entry.getValue();
486 SortedSet newSet = new TreeSet( pkObjectSet );
487 this.pksPerTable.put( table, newSet );
488 }
489 }
490
491 public int size() {
492 return pksPerTable.size();
493 }
494
495 public boolean isEmpty() {
496 return pksPerTable.isEmpty();
497 }
498
499 public boolean contains(String table, Object pkObject) {
500 Set pksPerTable = this.get(table);
501 return (pksPerTable != null && pksPerTable.contains(pkObject));
502 }
503
504 public void remove(String tableName) {
505 this.pksPerTable.remove(tableName);
506 }
507
508 public void put(String table, SortedSet pkObjects) {
509 this.pksPerTable.put(table, pkObjects);
510 }
511
512 public void add(String tableName, Object pkObject) {
513 Set pksPerTable = getCreateIfNeeded(tableName);
514 pksPerTable.add(pkObject);
515 }
516
517 public void addAll(String tableName, Set pkObjectsToAdd) {
518 Set pksPerTable = this.getCreateIfNeeded(tableName);
519 pksPerTable.addAll(pkObjectsToAdd);
520 }
521
522 public SortedSet get(String tableName) {
523 return (SortedSet) this.pksPerTable.get(tableName);
524 }
525
526 private SortedSet getCreateIfNeeded(String tableName){
527 SortedSet pksPerTable = this.get(tableName);
528
529 if( pksPerTable == null ) {
530 pksPerTable = new TreeSet();
531 this.pksPerTable.put(tableName, pksPerTable);
532 }
533 return pksPerTable;
534 }
535
536 public String[] getTableNames() {
537 return (String[]) this.pksPerTable.keySet().toArray(new String[0]);
538 }
539
540 public void retainOnly(List tableNames) {
541
542 List tablesToRemove = new ArrayList();
543 for(Iterator iterator = this.pksPerTable.entrySet().iterator(); iterator.hasNext(); ) {
544 Map.Entry entry = (Map.Entry) iterator.next();
545 String table = (String) entry.getKey();
546 SortedSet pksToScan = (SortedSet) entry.getValue();
547 boolean removeIt = pksToScan.isEmpty();
548
549 if ( ! tableNames.contains(table) ) {
550 if ( this.logger.isWarnEnabled() ) {
551 this.logger.warn("Discarding ids " + pksToScan + " of table " + table +
552 "as this table has not been passed as input" );
553 }
554 removeIt = true;
555 }
556 if ( removeIt ) {
557 tablesToRemove.add( table );
558 }
559 }
560
561 for(Iterator iterator = tablesToRemove.iterator(); iterator.hasNext(); ) {
562 this.remove( (String)iterator.next() );
563 }
564 }
565
566
567 public String toString() {
568 StringBuffer sb = new StringBuffer();
569 sb.append("pKsPerTable=").append(pksPerTable);
570 return sb.toString();
571 }
572
573 }
574 }